feat: Add OpenAIEmbeddingEncoder to EmbeddingRetriever (#3356)

This commit is contained in:
Vladimir Blagojevic 2022-10-14 15:01:03 +02:00 committed by GitHub
parent 5ebe3cb33d
commit 159cd5a666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 154 additions and 1 deletions

View File

@ -2975,6 +2975,17 @@
"items": { "items": {
"type": "string" "type": "string"
} }
},
"api_key": {
"title": "Api Key",
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
} }
}, },
"required": [ "required": [

View File

@ -2975,6 +2975,17 @@
"items": { "items": {
"type": "string" "type": "string"
} }
},
"api_key": {
"title": "Api Key",
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
} }
}, },
"required": [ "required": [

View File

@ -1,9 +1,11 @@
import json
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
import numpy as np import numpy as np
import requests
import torch import torch
from sentence_transformers import InputExample from sentence_transformers import InputExample
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -11,6 +13,7 @@ from torch.utils.data.sampler import SequentialSampler
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from haystack.errors import OpenAIError
from haystack.modeling.data_handler.dataloader import NamedDataLoader from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
from haystack.modeling.infer import Inferencer from haystack.modeling.infer import Inferencer
@ -20,7 +23,6 @@ from haystack.schema import Document
if TYPE_CHECKING: if TYPE_CHECKING:
from haystack.nodes.retriever import EmbeddingRetriever from haystack.nodes.retriever import EmbeddingRetriever
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -374,9 +376,82 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
) )
class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
    def __init__(self, retriever: "EmbeddingRetriever"):
        """
        Embedding encoder that delegates embedding computation to the OpenAI
        Embeddings REST API (https://api.openai.com/v1/embeddings).

        :param retriever: The owning EmbeddingRetriever; supplies the API key,
            batch size, max sequence length, progress-bar flag, and the model
            name used to pick the query/doc engine pair.
        """
        self.max_seq_len = retriever.max_seq_len
        self.url = "https://api.openai.com/v1/embeddings"
        self.api_key = retriever.api_key
        self.batch_size = retriever.batch_size
        self.progress_bar = retriever.progress_bar
        # Map the user-supplied model name to one of the OpenAI model families;
        # fall back to "babbage" if none of the known names matches.
        model_class: str = next(
            (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage"
        )
        # OpenAI uses distinct engines for queries and documents.
        self.query_model_encoder_engine = f"text-search-{model_class}-query-001"
        self.doc_model_encoder_engine = f"text-search-{model_class}-doc-001"
        # GPT-2 BPE tokenizer is used client-side only to truncate payloads.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def _ensure_text_limit(self, text: str) -> str:
        """
        Truncate `text` to at most `self.max_seq_len` tokens so the payload
        stays within the model's context limit (OpenAI embedding models accept
        a bounded number of tokens per input; presumably 2048 — confirm
        against current OpenAI docs).

        :param text: The raw text to truncate.
        :return: The text re-decoded from its first `max_seq_len` tokens.
        """
        tokenized_payload = self.tokenizer(text)
        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

    def embed(self, model: str, text: List[str]) -> np.ndarray:
        """
        Embed a batch of strings with one OpenAI API call.

        :param model: The OpenAI engine name (query or doc variant).
        :param text: The strings to embed (already truncated to the token limit).
        :return: A 2D array of embeddings, ordered to match `text`.
        :raises OpenAIError: If the API responds with a non-200 status code.
        """
        payload = {"model": model, "input": text}
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        response = requests.post(self.url, headers=headers, data=json.dumps(payload), timeout=30)
        # Check the status BEFORE parsing: error bodies are not guaranteed to be
        # JSON, and parsing first would raise JSONDecodeError instead of OpenAIError.
        if response.status_code != 200:
            raise OpenAIError(
                f"OpenAI returned an error.\n"
                f"Status code: {response.status_code}\n"
                f"Response body: {response.text}"
            )
        res = json.loads(response.text)
        # The API does not guarantee response order; restore the input order
        # using the "index" field of each result.
        unordered_embeddings = [(ans["index"], ans["embedding"]) for ans in res["data"]]
        ordered_embeddings = sorted(unordered_embeddings, key=lambda x: x[0])
        generated_embeddings = [emb[1] for emb in ordered_embeddings]
        return np.array(generated_embeddings)

    def embed_batch(self, model: str, text: List[str]) -> np.ndarray:
        """
        Embed `text` in chunks of `self.batch_size`, concatenating the results.

        :param model: The OpenAI engine name to use for every chunk.
        :param text: All strings to embed.
        :return: A 2D array with one embedding row per input string.
        """
        all_embeddings = []
        for i in tqdm(
            range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = text[i : i + self.batch_size]
            batch_limited = [self._ensure_text_limit(content) for content in batch]
            generated_embeddings = self.embed(model, batch_limited)
            all_embeddings.append(generated_embeddings)
        return np.concatenate(all_embeddings)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Embed queries with the query-side engine."""
        return self.embed_batch(self.query_model_encoder_engine, queries)

    def embed_documents(self, docs: List[Document]) -> np.ndarray:
        """Embed document contents with the doc-side engine."""
        return self.embed_batch(self.doc_model_encoder_engine, [d.content for d in docs])

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Union[int, None] = None,
        batch_size: int = 16,
    ):
        """Not supported: OpenAI models cannot be fine-tuned through this encoder."""
        raise NotImplementedError(f"Training is not implemented for {self.__class__}")

    def save(self, save_dir: Union[Path, str]):
        """Not supported: there are no local weights to save for an API-backed encoder."""
        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
# Registry mapping each `model_format` string accepted by EmbeddingRetriever
# to the encoder class that implements it.
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
    "farm": _DefaultEmbeddingEncoder,
    "transformers": _DefaultEmbeddingEncoder,
    "sentence_transformers": _SentenceTransformersEmbeddingEncoder,
    "retribert": _RetribertEmbeddingEncoder,
    "openai": _OpenAIEmbeddingEncoder,
}

View File

@ -1495,6 +1495,7 @@ class EmbeddingRetriever(DenseRetriever):
use_auth_token: Optional[Union[str, bool]] = None, use_auth_token: Optional[Union[str, bool]] = None,
scale_score: bool = True, scale_score: bool = True,
embed_meta_fields: List[str] = [], embed_meta_fields: List[str] = [],
api_key: Optional[str] = None,
): ):
""" """
:param document_store: An instance of DocumentStore from which to retrieve documents. :param document_store: An instance of DocumentStore from which to retrieve documents.
@ -1511,6 +1512,7 @@ class EmbeddingRetriever(DenseRetriever):
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder) - ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder) - ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder) - ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
- ``'openai'`` (will use `_OpenAIEmbeddingEncoder` as embedding encoder)
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only). :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
Options: Options:
@ -1541,6 +1543,9 @@ class EmbeddingRetriever(DenseRetriever):
This approach is also used in the TableTextRetriever paper and is likely to improve This approach is also used in the TableTextRetriever paper and is likely to improve
performance if your titles contain meaningful information for retrieval performance if your titles contain meaningful information for retrieval
(topic, entities etc.). (topic, entities etc.).
:param api_key: The OpenAI API key. Required if one wants to use OpenAI embeddings. For more
details see https://beta.openai.com/account/api-keys
""" """
super().__init__() super().__init__()
@ -1561,6 +1566,7 @@ class EmbeddingRetriever(DenseRetriever):
self.progress_bar = progress_bar self.progress_bar = progress_bar
self.use_auth_token = use_auth_token self.use_auth_token = use_auth_token
self.scale_score = scale_score self.scale_score = scale_score
self.api_key = api_key
self.model_format = ( self.model_format = (
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token) self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
if model_format is None if model_format is None
@ -1869,6 +1875,8 @@ class EmbeddingRetriever(DenseRetriever):
@staticmethod @staticmethod
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str: def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
return "openai"
# Check if model name is a local directory with sentence transformers config file in it # Check if model name is a local directory with sentence transformers config file in it
if Path(model_name_or_path).exists(): if Path(model_name_or_path).exists():
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists(): if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():

View File

@ -818,6 +818,13 @@ def get_retriever(retriever_type, document_store):
retriever = EmbeddingRetriever( retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
) )
elif retriever_type == "openai":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="ada",
use_gpu=False,
api_key=os.environ.get("OPENAI_API_KEY", ""),
)
elif retriever_type == "dpr_lfqa": elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever( retriever = DensePassageRetriever(
document_store=document_store, document_store=document_store,

View File

@ -1,4 +1,5 @@
import logging import logging
import os
from math import isclose from math import isclose
import numpy as np import numpy as np
@ -10,6 +11,7 @@ from elasticsearch import Elasticsearch
from haystack.document_stores import WeaviateDocumentStore from haystack.document_stores import WeaviateDocumentStore
from haystack.nodes.retriever.base import BaseRetriever from haystack.nodes.retriever.base import BaseRetriever
from haystack.pipelines import DocumentSearchPipeline
from haystack.schema import Document from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore from haystack.document_stores.faiss import FAISSDocumentStore
@ -216,6 +218,45 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids):
assert isclose(embedding[0], expected_value, rel_tol=0.001) assert isclose(embedding[0], expected_value, rel_tol=0.001)
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_embedding(document_store, retriever, docs_with_ids):
    # Index the fixture documents and compute their embeddings via OpenAI.
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)
    # Every stored document must carry a 1024-dimensional embedding.
    stored_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    assert all(len(stored_doc.embedding) == 1024 for stored_doc in stored_docs)
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_search(document_store, retriever, docs_with_ids):
    # Index the fixture documents with OpenAI embeddings, then run an
    # end-to-end dense search through a DocumentSearchPipeline.
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)
    search_pipeline = DocumentSearchPipeline(retriever)
    result = search_pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    retrieved = result["documents"]
    # Exactly one hit, and it must actually mention the query term.
    assert len(retrieved) == 1
    assert "Madrid" in retrieved[0].content
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True) @pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True) @pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)