mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-17 02:48:30 +00:00
feat: Add Azure OpenAI embeddings support (#4332)
* feate: add Azure OpenAI as embedding option * feat: Add Azure OpenAI embeddings support * refactor: check api key * refactor: better type checking for Azure * refactor: enable parallelism + separate and update tests * refactor: string reformat * refactor: explicit typing * refactor: update refs and remove unused code
This commit is contained in:
parent
c7dddfeaea
commit
1548c5ba0f
@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from multiprocessing import cpu_count
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
@ -9,7 +11,7 @@ from tqdm.auto import tqdm
|
||||
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC
|
||||
from haystack.nodes.retriever._base_embedding_encoder import _BaseEmbeddingEncoder
|
||||
from haystack.schema import Document
|
||||
from haystack.utils.openai_utils import USE_TIKTOKEN, load_openai_tokenizer, openai_request, count_openai_tokens
|
||||
from haystack.utils.openai_utils import USE_TIKTOKEN, count_openai_tokens, load_openai_tokenizer, openai_request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import EmbeddingRetriever
|
||||
@ -21,8 +23,19 @@ OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
|
||||
|
||||
class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
def __init__(self, retriever: "EmbeddingRetriever"):
|
||||
# See https://beta.openai.com/docs/guides/embeddings for more details
|
||||
self.url = "https://api.openai.com/v1/embeddings"
|
||||
# See https://platform.openai.com/docs/guides/embeddings and
|
||||
# https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/embeddings?tabs=console for more details
|
||||
self.using_azure = (
|
||||
retriever.azure_deployment_name is not None
|
||||
and retriever.azure_base_url is not None
|
||||
and retriever.api_version is not None
|
||||
)
|
||||
|
||||
if self.using_azure:
|
||||
self.url = f"{retriever.azure_base_url}/openai/deployments/{retriever.azure_deployment_name}/embeddings?api-version={retriever.api_version}"
|
||||
else:
|
||||
self.url = "https://api.openai.com/v1/embeddings"
|
||||
|
||||
self.api_key = retriever.api_key
|
||||
self.batch_size = min(64, retriever.batch_size)
|
||||
self.progress_bar = retriever.progress_bar
|
||||
@ -81,13 +94,37 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
return decoded_string
|
||||
|
||||
def embed(self, model: str, text: List[str]) -> np.ndarray:
|
||||
payload = {"model": model, "input": text}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)
|
||||
if self.api_key is None:
|
||||
raise ValueError(
|
||||
f"{'Azure ' if self.using_azure else ''}OpenAI API key is not set. You can set it via the `api_key` parameter of the EmbeddingRetriever."
|
||||
)
|
||||
|
||||
generated_embeddings: List[Any] = []
|
||||
|
||||
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
||||
|
||||
def azure_get_embedding(input: str):
|
||||
headers["api-key"] = str(self.api_key)
|
||||
azure_payload: Dict[str, str] = {"input": input}
|
||||
res = openai_request(url=self.url, headers=headers, payload=azure_payload, timeout=OPENAI_TIMEOUT)
|
||||
return res["data"][0]["embedding"]
|
||||
|
||||
if self.using_azure:
|
||||
thread_count = cpu_count() if len(text) > cpu_count() else len(text)
|
||||
with ThreadPoolExecutor(max_workers=thread_count) as executor:
|
||||
results: Iterator[Dict[str, Any]] = executor.map(azure_get_embedding, text)
|
||||
generated_embeddings.extend(results)
|
||||
else:
|
||||
payload: Dict[str, Union[List[str], str]] = {"model": model, "input": text}
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
res = openai_request(url=self.url, headers=headers, payload=payload, timeout=OPENAI_TIMEOUT)
|
||||
|
||||
unordered_embeddings = [(ans["index"], ans["embedding"]) for ans in res["data"]]
|
||||
ordered_embeddings = sorted(unordered_embeddings, key=lambda x: x[0])
|
||||
|
||||
generated_embeddings = [emb[1] for emb in ordered_embeddings]
|
||||
|
||||
unordered_embeddings = [(ans["index"], ans["embedding"]) for ans in res["data"]]
|
||||
ordered_embeddings = sorted(unordered_embeddings, key=lambda x: x[0])
|
||||
generated_embeddings = [emb[1] for emb in ordered_embeddings]
|
||||
return np.array(generated_embeddings)
|
||||
|
||||
def embed_batch(self, model: str, text: List[str]) -> np.ndarray:
|
||||
|
@ -1465,6 +1465,9 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
scale_score: bool = True,
|
||||
embed_meta_fields: Optional[List[str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
azure_api_version: str = "2022-12-01",
|
||||
azure_base_url: Optional[str] = None,
|
||||
azure_deployment_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||||
@ -1519,7 +1522,11 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
If no value is provided, a default empty list will be created.
|
||||
:param api_key: The OpenAI API key or the Cohere API key. Required if one wants to use OpenAI/Cohere embeddings.
|
||||
For more details see https://beta.openai.com/account/api-keys and https://dashboard.cohere.ai/api-keys
|
||||
|
||||
:param api_version: The version of the Azure OpenAI API to use. The default is `2022-12-01` version.
|
||||
:param azure_base_url: The base URL for the Azure OpenAI API. If not supplied, Azure OpenAI API will not be used.
|
||||
This parameter is an OpenAI Azure endpoint, usually in the form `https://<your-endpoint>.openai.azure.com'
|
||||
:param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API
|
||||
will not be used.
|
||||
"""
|
||||
if embed_meta_fields is None:
|
||||
embed_meta_fields = []
|
||||
@ -1543,6 +1550,9 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
self.use_auth_token = use_auth_token
|
||||
self.scale_score = scale_score
|
||||
self.api_key = api_key
|
||||
self.api_version = azure_api_version
|
||||
self.azure_base_url = azure_base_url
|
||||
self.azure_deployment_name = azure_deployment_name
|
||||
self.model_format = (
|
||||
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
|
||||
if model_format is None
|
||||
|
@ -695,9 +695,18 @@ def get_retriever(retriever_type, document_store):
|
||||
elif retriever_type == "openai":
|
||||
retriever = EmbeddingRetriever(
|
||||
document_store=document_store,
|
||||
embedding_model="ada",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
use_gpu=False,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
elif retriever_type == "azure":
|
||||
retriever = EmbeddingRetriever(
|
||||
document_store=document_store,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
use_gpu=False,
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
azure_base_url=os.getenv("AZURE_OPENAI_BASE_URL"),
|
||||
azure_deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED"),
|
||||
)
|
||||
elif retriever_type == "cohere":
|
||||
retriever = EmbeddingRetriever(
|
||||
@ -985,10 +994,12 @@ def haystack_openai_config(request):
|
||||
if not api_key:
|
||||
return {}
|
||||
else:
|
||||
return {"api_key": api_key}
|
||||
return {"api_key": api_key, "embedding_model": "text-embedding-ada-002"}
|
||||
elif request.param == "azure":
|
||||
return haystack_azure_conf()
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_model(request):
|
||||
|
@ -378,14 +378,13 @@ def test_openai_embedding_retriever_selection():
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1024)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
|
||||
reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
|
||||
"the OpenAI/Cohere API key to run this test.",
|
||||
not os.environ.get("COHERE_API_KEY", None),
|
||||
reason="Please export an env var called COHERE_API_KEY containing " "the Cohere API key to run this test.",
|
||||
)
|
||||
def test_basic_embedding(document_store, retriever, docs_with_ids):
|
||||
def test_basic_cohere_embedding(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
@ -399,14 +398,105 @@ def test_basic_embedding(document_store, retriever, docs_with_ids):
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1536)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason=("Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test."),
|
||||
)
|
||||
def test_basic_openai_embedding(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
docs = document_store.get_all_documents()
|
||||
docs = sorted(docs, key=lambda d: d.id)
|
||||
|
||||
for doc in docs:
|
||||
assert len(doc.embedding) == 1536
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1536)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None)
|
||||
and not os.environ.get("AZURE_OPENAI_BASE_URL", None)
|
||||
and not os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED", None),
|
||||
reason=(
|
||||
"Please export env variables called AZURE_OPENAI_API_KEY containing "
|
||||
"the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
|
||||
"the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
|
||||
"the Azure OpenAI deployment name to run this test."
|
||||
),
|
||||
)
|
||||
def test_basic_azure_embedding(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
docs = document_store.get_all_documents()
|
||||
docs = sorted(docs, key=lambda d: d.id)
|
||||
|
||||
for doc in docs:
|
||||
assert len(doc.embedding) == 1536
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1024)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
|
||||
reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
|
||||
"the OpenAI/Cohere API key to run this test.",
|
||||
not os.environ.get("COHERE_API_KEY", None),
|
||||
reason="Please export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
|
||||
)
|
||||
def test_retriever_basic_search(document_store, retriever, docs_with_ids):
|
||||
def test_retriever_basic_cohere_search(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
p_retrieval = DocumentSearchPipeline(retriever)
|
||||
res = p_retrieval.run(query="Madrid", params={"Retriever": {"top_k": 1}})
|
||||
assert len(res["documents"]) == 1
|
||||
assert "Madrid" in res["documents"][0].content
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1536)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason="Please export env called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||
)
|
||||
def test_retriever_basic_openai_search(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
p_retrieval = DocumentSearchPipeline(retriever)
|
||||
res = p_retrieval.run(query="Madrid", params={"Retriever": {"top_k": 1}})
|
||||
assert len(res["documents"]) == 1
|
||||
assert "Madrid" in res["documents"][0].content
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
|
||||
@pytest.mark.embedding_dim(1536)
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None)
|
||||
and not os.environ.get("AZURE_OPENAI_BASE_URL", None)
|
||||
and not os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED", None),
|
||||
reason=(
|
||||
"Please export env variables called AZURE_OPENAI_API_KEY containing "
|
||||
"the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
|
||||
"the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
|
||||
"the Azure OpenAI deployment name to run this test."
|
||||
),
|
||||
)
|
||||
def test_retriever_basic_azure_search(document_store, retriever, docs_with_ids):
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(docs_with_ids)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
Loading…
x
Reference in New Issue
Block a user