Skip to content

Commit 0385576

Browse files
feat(openai): add openai embeddings api support (#1345)
1 parent 37468a1 commit 0385576

2 files changed

Lines changed: 164 additions & 18 deletions

File tree

langfuse/openai.py

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,20 @@ class OpenAiDefinition:
177177
sync=False,
178178
min_version="1.66.0",
179179
),
180+
OpenAiDefinition(
181+
module="openai.resources.embeddings",
182+
object="Embeddings",
183+
method="create",
184+
type="embedding",
185+
sync=True,
186+
),
187+
OpenAiDefinition(
188+
module="openai.resources.embeddings",
189+
object="AsyncEmbeddings",
190+
method="create",
191+
type="embedding",
192+
sync=False,
193+
),
180194
]
181195

182196

@@ -340,10 +354,13 @@ def _extract_chat_response(kwargs: Any) -> Any:
340354

341355

342356
def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> Any:
343-
name = kwargs.get("name", "OpenAI-generation")
357+
default_name = (
358+
"OpenAI-embedding" if resource.type == "embedding" else "OpenAI-generation"
359+
)
360+
name = kwargs.get("name", default_name)
344361

345362
if name is None:
346-
name = "OpenAI-generation"
363+
name = default_name
347364

348365
if name is not None and not isinstance(name, str):
349366
raise TypeError("name must be a string")
@@ -395,6 +412,8 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
395412
prompt = kwargs.get("input", None)
396413
elif resource.type == "chat":
397414
prompt = _extract_chat_prompt(kwargs)
415+
elif resource.type == "embedding":
416+
prompt = kwargs.get("input", None)
398417

399418
parsed_temperature = (
400419
kwargs.get("temperature", 1)
@@ -440,23 +459,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> A
440459

441460
parsed_n = kwargs.get("n", 1) if not isinstance(kwargs.get("n", 1), NotGiven) else 1
442461

443-
modelParameters = {
444-
"temperature": parsed_temperature,
445-
"max_tokens": parsed_max_tokens, # casing?
446-
"top_p": parsed_top_p,
447-
"frequency_penalty": parsed_frequency_penalty,
448-
"presence_penalty": parsed_presence_penalty,
449-
}
462+
if resource.type == "embedding":
463+
parsed_dimensions = (
464+
kwargs.get("dimensions", None)
465+
if not isinstance(kwargs.get("dimensions", None), NotGiven)
466+
else None
467+
)
468+
parsed_encoding_format = (
469+
kwargs.get("encoding_format", "float")
470+
if not isinstance(kwargs.get("encoding_format", "float"), NotGiven)
471+
else "float"
472+
)
473+
474+
modelParameters = {}
475+
if parsed_dimensions is not None:
476+
modelParameters["dimensions"] = parsed_dimensions
477+
if parsed_encoding_format != "float":
478+
modelParameters["encoding_format"] = parsed_encoding_format
479+
else:
480+
modelParameters = {
481+
"temperature": parsed_temperature,
482+
"max_tokens": parsed_max_tokens,
483+
"top_p": parsed_top_p,
484+
"frequency_penalty": parsed_frequency_penalty,
485+
"presence_penalty": parsed_presence_penalty,
486+
}
450487

451-
if parsed_max_completion_tokens is not None:
452-
modelParameters.pop("max_tokens", None)
453-
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens
488+
if parsed_max_completion_tokens is not None:
489+
modelParameters.pop("max_tokens", None)
490+
modelParameters["max_completion_tokens"] = parsed_max_completion_tokens
454491

455-
if parsed_n is not None and parsed_n > 1:
456-
modelParameters["n"] = parsed_n
492+
if parsed_n is not None and parsed_n > 1:
493+
modelParameters["n"] = parsed_n
457494

458-
if parsed_seed is not None:
459-
modelParameters["seed"] = parsed_seed
495+
if parsed_seed is not None:
496+
modelParameters["seed"] = parsed_seed
460497

461498
langfuse_prompt = kwargs.get("langfuse_prompt", None)
462499

@@ -729,6 +766,20 @@ def _get_langfuse_data_from_default_response(
729766
else choice.get("message", None)
730767
)
731768

769+
elif resource.type == "embedding":
770+
data = response.get("data", [])
771+
if len(data) > 0:
772+
first_embedding = data[0]
773+
embedding_vector = (
774+
first_embedding.embedding
775+
if hasattr(first_embedding, "embedding")
776+
else first_embedding.get("embedding", [])
777+
)
778+
completion = {
779+
"dimensions": len(embedding_vector) if embedding_vector else 0,
780+
"count": len(data),
781+
}
782+
732783
usage = _parse_usage(response.get("usage", None))
733784

734785
return (model, completion, usage)
@@ -757,8 +808,12 @@ def _wrap(
757808
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
758809
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])
759810

811+
observation_type = (
812+
"embedding" if open_ai_resource.type == "embedding" else "generation"
813+
)
814+
760815
generation = langfuse_client.start_observation(
761-
as_type="generation",
816+
as_type=observation_type,
762817
name=langfuse_data["name"],
763818
input=langfuse_data.get("input", None),
764819
metadata=langfuse_data.get("metadata", None),
@@ -824,8 +879,12 @@ async def _wrap_async(
824879
langfuse_data = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse_args)
825880
langfuse_client = get_client(public_key=langfuse_args["langfuse_public_key"])
826881

882+
observation_type = (
883+
"embedding" if open_ai_resource.type == "embedding" else "generation"
884+
)
885+
827886
generation = langfuse_client.start_observation(
828-
as_type="generation",
887+
as_type=observation_type,
829888
name=langfuse_data["name"],
830889
input=langfuse_data.get("input", None),
831890
metadata=langfuse_data.get("metadata", None),

tests/test_openai.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,3 +1514,90 @@ def test_response_api_reasoning(openai):
15141514
assert generationData.usage.total is not None
15151515
assert generationData.output is not None
15161516
assert generationData.metadata is not None
1517+
1518+
1519+
def test_openai_embeddings(openai):
    """Sync embeddings.create is traced as an EMBEDDING observation.

    Creates an embedding with a unique observation name, flushes the
    Langfuse client, then fetches the observation via the public API and
    checks name, metadata, input passthrough, model, timing, usage, and
    the summarized output ({"dimensions": ..., "count": ...}) produced by
    the integration for embedding responses.
    """
    embedding_name = create_uuid()
    openai.OpenAI().embeddings.create(
        name=embedding_name,
        model="text-embedding-ada-002",
        input="The quick brown fox jumps over the lazy dog",
        metadata={"test_key": "test_value"},
    )

    langfuse.flush()
    # Give the ingestion pipeline a moment before querying the API.
    sleep(1)

    embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")

    assert len(embedding.data) != 0
    embedding_data = embedding.data[0]
    assert embedding_data.name == embedding_name
    assert embedding_data.metadata["test_key"] == "test_value"
    assert embedding_data.input == "The quick brown fox jumps over the lazy dog"
    assert embedding_data.type == "EMBEDDING"
    assert embedding_data.model == "text-embedding-ada-002"
    assert embedding_data.start_time is not None
    assert embedding_data.end_time is not None
    assert embedding_data.start_time < embedding_data.end_time
    assert embedding_data.usage.input is not None
    assert embedding_data.usage.total is not None
    # Output is the integration's summary dict, not the raw vector.
    assert embedding_data.output is not None
    assert "dimensions" in embedding_data.output
    assert "count" in embedding_data.output
    assert embedding_data.output["count"] == 1
1550+
1551+
def test_openai_embeddings_multiple_inputs(openai):
    """Batch embeddings (list input) are traced with a matching item count.

    Sends a list of inputs to embeddings.create and verifies the traced
    observation preserves the list as input and reports
    output["count"] == len(inputs), i.e. one embedding per input item.
    """
    embedding_name = create_uuid()
    inputs = ["The quick brown fox", "jumps over the lazy dog", "Hello world"]

    openai.OpenAI().embeddings.create(
        name=embedding_name,
        model="text-embedding-ada-002",
        input=inputs,
        metadata={"batch_size": len(inputs)},
    )

    langfuse.flush()
    # Give the ingestion pipeline a moment before querying the API.
    sleep(1)

    embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")

    assert len(embedding.data) != 0
    embedding_data = embedding.data[0]
    assert embedding_data.name == embedding_name
    assert embedding_data.input == inputs
    assert embedding_data.type == "EMBEDDING"
    assert embedding_data.model == "text-embedding-ada-002"
    assert embedding_data.usage.input is not None
    assert embedding_data.usage.total is not None
    assert embedding_data.output["count"] == len(inputs)
1576+
1577+
1578+
@pytest.mark.asyncio
async def test_async_openai_embeddings(openai):
    """AsyncOpenAI embeddings.create is traced identically to the sync path.

    Awaits an embedding call on the async client and verifies the
    resulting EMBEDDING observation carries the expected name, input,
    model, metadata, and usage figures.
    """
    client = openai.AsyncOpenAI()
    embedding_name = create_uuid()

    await client.embeddings.create(
        name=embedding_name,
        model="text-embedding-ada-002",
        input="Async embedding test",
        metadata={"async": True},
    )

    langfuse.flush()
    # Give the ingestion pipeline a moment before querying the API.
    sleep(1)

    embedding = get_api().observations.get_many(name=embedding_name, type="EMBEDDING")

    assert len(embedding.data) != 0
    embedding_data = embedding.data[0]
    assert embedding_data.name == embedding_name
    assert embedding_data.input == "Async embedding test"
    assert embedding_data.type == "EMBEDDING"
    assert embedding_data.model == "text-embedding-ada-002"
    assert embedding_data.metadata["async"] is True
    assert embedding_data.usage.input is not None
    assert embedding_data.usage.total is not None

0 commit comments

Comments
 (0)