mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
feat:Add dimensions parameter to OpenAI Embedders to fully support th… (#6841)
* feat:Add dimensions parameter to OpenAI Embedders to fully support the new models * fixed linting * changed != None to is not None
This commit is contained in:
parent
0fbb0655f0
commit
3bd6ba93ca
@ -33,6 +33,7 @@ class OpenAIDocumentEmbedder:
|
||||
self,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "text-embedding-ada-002",
|
||||
dimensions: Optional[int] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
prefix: str = "",
|
||||
@ -46,6 +47,7 @@ class OpenAIDocumentEmbedder:
|
||||
Create a OpenAIDocumentEmbedder component.
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the model to use.
|
||||
:param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
@ -59,6 +61,7 @@ class OpenAIDocumentEmbedder:
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.api_base_url = api_base_url
|
||||
self.organization = organization
|
||||
self.prefix = prefix
|
||||
@ -84,6 +87,7 @@ class OpenAIDocumentEmbedder:
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model,
|
||||
dimensions=self.dimensions,
|
||||
organization=self.organization,
|
||||
api_base_url=self.api_base_url,
|
||||
prefix=self.prefix,
|
||||
@ -131,7 +135,10 @@ class OpenAIDocumentEmbedder:
|
||||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
||||
):
|
||||
batch = texts_to_embed[i : i + batch_size]
|
||||
response = self.client.embeddings.create(model=self.model, input=batch)
|
||||
if self.dimensions is not None:
|
||||
response = self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=batch)
|
||||
else:
|
||||
response = self.client.embeddings.create(model=self.model, input=batch)
|
||||
embeddings = [el.embedding for el in response.data]
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from haystack import component, default_to_dict, default_from_dict
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ class OpenAITextEmbedder:
|
||||
self,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "text-embedding-ada-002",
|
||||
dimensions: Optional[int] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
prefix: str = "",
|
||||
@ -42,6 +43,7 @@ class OpenAITextEmbedder:
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the OpenAI model to use. For more details on the available models,
|
||||
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
|
||||
:param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
|
||||
@ -49,6 +51,7 @@ class OpenAITextEmbedder:
|
||||
:param suffix: A string to add to the end of each text.
|
||||
"""
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.organization = organization
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
@ -69,6 +72,7 @@ class OpenAITextEmbedder:
|
||||
organization=self.organization,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
dimensions=self.dimensions,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@ -92,7 +96,11 @@ class OpenAITextEmbedder:
|
||||
# replace newlines, which can negatively affect performance.
|
||||
text_to_embed = text_to_embed.replace("\n", " ")
|
||||
|
||||
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
|
||||
if self.dimensions is not None:
|
||||
response = self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=text_to_embed)
|
||||
else:
|
||||
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
|
||||
|
||||
meta = {"model": response.model, "usage": dict(response.usage)}
|
||||
|
||||
return {"embedding": response.data[0].embedding, "meta": meta}
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
add dimensions parameter to OpenAI Embedders to fully support new embedding models like text-embedding-3-small, text-embedding-3-large and upcoming ones
|
||||
@ -71,6 +71,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"api_base_url": None,
|
||||
"model": "text-embedding-ada-002",
|
||||
"dimensions": None,
|
||||
"organization": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
@ -101,6 +102,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"api_base_url": None,
|
||||
"model": "model",
|
||||
"dimensions": None,
|
||||
"organization": "my-org",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import os
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
class TestOpenAITextEmbedder:
|
||||
@ -44,6 +44,7 @@ class TestOpenAITextEmbedder:
|
||||
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"dimensions": None,
|
||||
"model": "text-embedding-ada-002",
|
||||
"organization": None,
|
||||
"prefix": "",
|
||||
@ -66,6 +67,7 @@ class TestOpenAITextEmbedder:
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "model",
|
||||
"dimensions": None,
|
||||
"organization": "fake-organization",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user