diff --git a/haystack/json-schemas/haystack-pipeline-1.10.0rc0.schema.json b/haystack/json-schemas/haystack-pipeline-1.10.0rc0.schema.json index 5c05004f0..30587956f 100644 --- a/haystack/json-schemas/haystack-pipeline-1.10.0rc0.schema.json +++ b/haystack/json-schemas/haystack-pipeline-1.10.0rc0.schema.json @@ -2975,6 +2975,17 @@ "items": { "type": "string" } + }, + "api_key": { + "title": "Api Key", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] } }, "required": [ diff --git a/haystack/json-schemas/haystack-pipeline-main.schema.json b/haystack/json-schemas/haystack-pipeline-main.schema.json index 1dd72b7be..c23eb2326 100644 --- a/haystack/json-schemas/haystack-pipeline-main.schema.json +++ b/haystack/json-schemas/haystack-pipeline-main.schema.json @@ -2975,6 +2975,17 @@ "items": { "type": "string" } + }, + "api_key": { + "title": "Api Key", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] } }, "required": [ diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py index a7c30f4c0..f2d4dba34 100644 --- a/haystack/nodes/retriever/_embedding_encoder.py +++ b/haystack/nodes/retriever/_embedding_encoder.py @@ -1,9 +1,11 @@ +import json import logging from abc import abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union import numpy as np +import requests import torch from sentence_transformers import InputExample from torch.utils.data import DataLoader @@ -11,6 +13,7 @@ from torch.utils.data.sampler import SequentialSampler from tqdm.auto import tqdm from transformers import AutoModel, AutoTokenizer +from haystack.errors import OpenAIError from haystack.modeling.data_handler.dataloader import NamedDataLoader from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename from haystack.modeling.infer import Inferencer @@ -20,7 +23,6 @@ from haystack.schema import Document if TYPE_CHECKING: from 
class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
    """
    Embedding encoder that delegates embedding computation to OpenAI's
    ``/v1/embeddings`` endpoint, using the paired ``text-search-*-query-001`` /
    ``text-search-*-doc-001`` search engines.
    """

    def __init__(self, retriever: "EmbeddingRetriever"):
        # Available pretrained embedding models are described at
        # https://beta.openai.com/docs/guides/embeddings
        self.max_seq_len = retriever.max_seq_len
        self.url = "https://api.openai.com/v1/embeddings"
        self.api_key = retriever.api_key
        self.batch_size = retriever.batch_size
        self.progress_bar = retriever.progress_bar
        # Pick the model family mentioned in the configured model name;
        # fall back to "babbage" when no known family name matches.
        model_class: str = next(
            (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage"
        )
        # OpenAI search models come in query/doc pairs: queries and documents
        # must be embedded with the matching engine of the same family.
        self.query_model_encoder_engine = f"text-search-{model_class}-query-001"
        self.doc_model_encoder_engine = f"text-search-{model_class}-doc-001"
        # The GPT-2 BPE tokenizer is used to truncate inputs client-side
        # before sending them to the API.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def _ensure_text_limit(self, text: str) -> str:
        """
        Ensure that the length of the text is within the maximum length of the model.
        OpenAI embedding models have a limit of 2048 tokens.
        """
        tokenized_payload = self.tokenizer(text)
        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

    def embed(self, model: str, text: List[str]) -> np.ndarray:
        """
        Request embeddings for `text` from the given OpenAI engine.

        :param model: Name of the OpenAI engine, e.g. "text-search-babbage-query-001".
        :param text: Batch of strings to embed.
        :return: 2D array of embeddings, in the same order as `text`.
        :raises OpenAIError: If the API responds with a non-200 status code.
        """
        payload = {"model": model, "input": text}
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)

        # Check the status code BEFORE parsing the body: error responses are
        # not guaranteed to contain valid JSON, and parsing them first would
        # raise a JSONDecodeError instead of the intended OpenAIError.
        if response.status_code != 200:
            raise OpenAIError(
                f"OpenAI returned an error.\n"
                f"Status code: {response.status_code}\n"
                f"Response body: {response.text}"
            )
        res = json.loads(response.text)

        # The API may return results out of order; restore the input order
        # using the "index" field before stacking the embeddings.
        unordered_embeddings = [(ans["index"], ans["embedding"]) for ans in res["data"]]
        ordered_embeddings = sorted(unordered_embeddings, key=lambda x: x[0])
        generated_embeddings = [emb[1] for emb in ordered_embeddings]
        return np.array(generated_embeddings)

    def embed_batch(self, model: str, text: List[str]) -> np.ndarray:
        """
        Embed `text` in batches of `self.batch_size`, truncating each entry
        to the model's token limit first.
        """
        all_embeddings = []
        for i in tqdm(
            range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = text[i : i + self.batch_size]
            batch_limited = [self._ensure_text_limit(content) for content in batch]
            generated_embeddings = self.embed(model, batch_limited)
            all_embeddings.append(generated_embeddings)
        return np.concatenate(all_embeddings)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Embed queries with the query-side search engine."""
        return self.embed_batch(self.query_model_encoder_engine, queries)

    def embed_documents(self, docs: List[Document]) -> np.ndarray:
        """Embed document contents with the doc-side search engine."""
        return self.embed_batch(self.doc_model_encoder_engine, [d.content for d in docs])

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Union[int, None] = None,  # was annotated bare `int` with a None default
        batch_size: int = 16,
    ):
        """Not supported: OpenAI models cannot be fine-tuned through this encoder."""
        raise NotImplementedError(f"Training is not implemented for {self.__class__}")

    def save(self, save_dir: Union[Path, str]):
        """Not supported: there are no local weights to save for OpenAI models."""
        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
For more + details see https://beta.openai.com/account/api-keys + """ super().__init__() @@ -1561,6 +1566,7 @@ class EmbeddingRetriever(DenseRetriever): self.progress_bar = progress_bar self.use_auth_token = use_auth_token self.scale_score = scale_score + self.api_key = api_key self.model_format = ( self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token) if model_format is None @@ -1869,6 +1875,8 @@ class EmbeddingRetriever(DenseRetriever): @staticmethod def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str: + if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]): + return "openai" # Check if model name is a local directory with sentence transformers config file in it if Path(model_name_or_path).exists(): if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists(): diff --git a/test/conftest.py b/test/conftest.py index 1ac5fe7e4..872eb5d88 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -818,6 +818,13 @@ def get_retriever(retriever_type, document_store): retriever = EmbeddingRetriever( document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False ) + elif retriever_type == "openai": + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="ada", + use_gpu=False, + api_key=os.environ.get("OPENAI_API_KEY", ""), + ) elif retriever_type == "dpr_lfqa": retriever = DensePassageRetriever( document_store=document_store, diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 076a70dce..4e25a8a05 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -1,4 +1,5 @@ import logging +import os from math import isclose import numpy as np @@ -10,6 +11,7 @@ from elasticsearch import Elasticsearch from haystack.document_stores import WeaviateDocumentStore from haystack.nodes.retriever.base import BaseRetriever +from haystack.pipelines import 
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_embedding(document_store, retriever, docs_with_ids):
    # Index the fixture documents and compute their OpenAI embeddings.
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    # Every stored document must carry a 1024-dimensional embedding.
    stored_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    for stored_doc in stored_docs:
        assert len(stored_doc.embedding) == 1024


@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_search(document_store, retriever, docs_with_ids):
    # Index the fixture documents and compute their OpenAI embeddings.
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    # A top-1 search for "Madrid" must return exactly one matching document.
    pipeline = DocumentSearchPipeline(retriever)
    result = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    retrieved = result["documents"]
    assert len(retrieved) == 1
    assert "Madrid" in retrieved[0].content
indirect=True) @pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)