feat: added dimensions parameter to Azure OpenAI Embedders (#7449)

* added dimensions parameter to AzureOpenAIEmbedders

* created releasenote

* update release note

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
Nicola Procopio 2024-04-02 14:04:16 +02:00 committed by GitHub
parent 6e289698e9
commit 42c5b7af32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 30 additions and 2 deletions

View File

@ -34,6 +34,7 @@ class AzureOpenAIDocumentEmbedder:
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
@ -53,6 +54,8 @@ class AzureOpenAIDocumentEmbedder:
The version of the API to use.
:param azure_deployment:
The deployment of the model, usually matches the model name.
:param dimensions:
The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param api_key:
The API key used for authentication.
:param azure_ad_token:
@ -90,6 +93,7 @@ class AzureOpenAIDocumentEmbedder:
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.prefix = prefix
self.suffix = suffix
@ -124,6 +128,7 @@ class AzureOpenAIDocumentEmbedder:
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
dimensions=self.dimensions,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
@ -175,7 +180,12 @@ class AzureOpenAIDocumentEmbedder:
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
batch = texts_to_embed[i : i + batch_size]
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)
if self.dimensions is not None:
response = self._client.embeddings.create(
model=self.azure_deployment, dimensions=self.dimensions, input=batch
)
else:
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)
# Append embeddings to the list
all_embeddings.extend(el.embedding for el in response.data)

View File

@ -33,6 +33,7 @@ class AzureOpenAITextEmbedder:
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
@ -48,6 +49,8 @@ class AzureOpenAITextEmbedder:
The version of the API to use.
:param azure_deployment:
The deployment of the model, usually matches the model name.
:param dimensions:
The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param api_key:
The API key used for authentication.
:param azure_ad_token:
@ -80,6 +83,7 @@ class AzureOpenAITextEmbedder:
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.prefix = prefix
self.suffix = suffix
@ -110,6 +114,7 @@ class AzureOpenAITextEmbedder:
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
dimensions=self.dimensions,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
@ -156,7 +161,12 @@ class AzureOpenAITextEmbedder:
# finally, replace newlines as recommended by OpenAI docs
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")
response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)
if self.dimensions is not None:
response = self._client.embeddings.create(
model=self.azure_deployment, dimensions=self.dimensions, input=processed_text
)
else:
response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)
return {
"embedding": response.data[0].embedding,

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Add a dimensions parameter to the Azure OpenAI Embedders (AzureOpenAITextEmbedder and AzureOpenAIDocumentEmbedder) to fully support new embedding models such as text-embedding-3-small and text-embedding-3-large, as well as upcoming ones.

View File

@ -11,6 +11,7 @@ class TestAzureOpenAIDocumentEmbedder:
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.dimensions is None
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
@ -30,6 +31,7 @@ class TestAzureOpenAIDocumentEmbedder:
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"organization": None,
"prefix": "",

View File

@ -12,6 +12,7 @@ class TestAzureOpenAITextEmbedder:
assert embedder._client.api_key == "fake-api-key"
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.dimensions is None
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
@ -26,6 +27,7 @@ class TestAzureOpenAITextEmbedder:
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"api_version": "2023-05-15",