feat:Add dimensions parameter to OpenAI Embedders to fully support th… (#6841)

* feat:Add dimensions parameter to OpenAI Embedders to fully support the new models

* fixed linting

* changed != None to is not None
This commit is contained in:
sahusiddharth 2024-02-05 20:50:46 +05:30 committed by GitHub
parent 0fbb0655f0
commit 3bd6ba93ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 28 additions and 5 deletions

View File

@ -33,6 +33,7 @@ class OpenAIDocumentEmbedder:
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
@ -46,6 +47,7 @@ class OpenAIDocumentEmbedder:
Create an OpenAIDocumentEmbedder component.
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/embeddings).
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
@ -59,6 +61,7 @@ class OpenAIDocumentEmbedder:
"""
self.api_key = api_key
self.model = model
self.dimensions = dimensions
self.api_base_url = api_base_url
self.organization = organization
self.prefix = prefix
@ -84,6 +87,7 @@ class OpenAIDocumentEmbedder:
return default_to_dict(
self,
model=self.model,
dimensions=self.dimensions,
organization=self.organization,
api_base_url=self.api_base_url,
prefix=self.prefix,
@ -131,7 +135,10 @@ class OpenAIDocumentEmbedder:
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self.client.embeddings.create(model=self.model, input=batch)
if self.dimensions is not None:
response = self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=batch)
else:
response = self.client.embeddings.create(model=self.model, input=batch)
embeddings = [el.embedding for el in response.data]
all_embeddings.extend(embeddings)

View File

@ -1,8 +1,8 @@
from typing import List, Optional, Dict, Any
from typing import Any, Dict, List, Optional
from openai import OpenAI
from haystack import component, default_to_dict, default_from_dict
from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
@ -31,6 +31,7 @@ class OpenAITextEmbedder:
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
prefix: str = "",
@ -42,6 +43,7 @@ class OpenAITextEmbedder:
:param api_key: The OpenAI API key.
:param model: The name of the OpenAI model to use. For more details on the available models,
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
:param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/embeddings).
@ -49,6 +51,7 @@ class OpenAITextEmbedder:
:param suffix: A string to add to the end of each text.
"""
self.model = model
self.dimensions = dimensions
self.organization = organization
self.prefix = prefix
self.suffix = suffix
@ -69,6 +72,7 @@ class OpenAITextEmbedder:
organization=self.organization,
prefix=self.prefix,
suffix=self.suffix,
dimensions=self.dimensions,
api_key=self.api_key.to_dict(),
)
@ -92,7 +96,11 @@ class OpenAITextEmbedder:
# replace newlines, which can negatively affect performance.
text_to_embed = text_to_embed.replace("\n", " ")
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
if self.dimensions is not None:
response = self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=text_to_embed)
else:
response = self.client.embeddings.create(model=self.model, input=text_to_embed)
meta = {"model": response.model, "usage": dict(response.usage)}
return {"embedding": response.data[0].embedding, "meta": meta}

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Add a `dimensions` parameter to the OpenAI Embedders to fully support new embedding models such as text-embedding-3-small and text-embedding-3-large, as well as upcoming ones.

View File

@ -71,6 +71,7 @@ class TestOpenAIDocumentEmbedder:
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"api_base_url": None,
"model": "text-embedding-ada-002",
"dimensions": None,
"organization": None,
"prefix": "",
"suffix": "",
@ -101,6 +102,7 @@ class TestOpenAIDocumentEmbedder:
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"api_base_url": None,
"model": "model",
"dimensions": None,
"organization": "my-org",
"prefix": "prefix",
"suffix": "suffix",

View File

@ -1,9 +1,9 @@
import os
from haystack.utils.auth import Secret
import pytest
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.utils.auth import Secret
class TestOpenAITextEmbedder:
@ -44,6 +44,7 @@ class TestOpenAITextEmbedder:
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"dimensions": None,
"model": "text-embedding-ada-002",
"organization": None,
"prefix": "",
@ -66,6 +67,7 @@ class TestOpenAITextEmbedder:
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "model",
"dimensions": None,
"organization": "fake-organization",
"prefix": "prefix",
"suffix": "suffix",