mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 02:16:32 +00:00
feat: Add OpenAIEmbeddingEncoder to EmbeddingRetriever (#3356)
This commit is contained in:
parent
5ebe3cb33d
commit
159cd5a666
@ -2975,6 +2975,17 @@
|
|||||||
"items": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"api_key": {
|
||||||
|
"title": "Api Key",
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -2975,6 +2975,17 @@
|
|||||||
"items": {
|
"items": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"api_key": {
|
||||||
|
"title": "Api Key",
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import InputExample
|
from sentence_transformers import InputExample
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -11,6 +13,7 @@ from torch.utils.data.sampler import SequentialSampler
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
from haystack.errors import OpenAIError
|
||||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||||
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
|
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
|
||||||
from haystack.modeling.infer import Inferencer
|
from haystack.modeling.infer import Inferencer
|
||||||
@ -20,7 +23,6 @@ from haystack.schema import Document
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from haystack.nodes.retriever import EmbeddingRetriever
|
from haystack.nodes.retriever import EmbeddingRetriever
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -374,9 +376,82 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
    """
    Embedding encoder that obtains embeddings by POSTing text to the OpenAI embeddings endpoint.

    Queries and documents are embedded with two different engines
    (`text-search-<model>-query-001` and `text-search-<model>-doc-001`).
    Training and saving are not supported by this encoder.
    """

    def __init__(self, retriever: "EmbeddingRetriever"):
        """
        :param retriever: The EmbeddingRetriever this encoder belongs to. Its `max_seq_len`, `api_key`,
                          `batch_size`, `progress_bar` and `embedding_model` attributes configure the encoder.
        """
        self.max_seq_len = retriever.max_seq_len
        self.url = "https://api.openai.com/v1/embeddings"
        self.api_key = retriever.api_key
        self.batch_size = retriever.batch_size
        self.progress_bar = retriever.progress_bar
        # Use the first known model family whose name occurs in the configured embedding model;
        # fall back to "babbage" when none of them matches.
        model_class: str = next(
            (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage"
        )
        self.query_model_encoder_engine = f"text-search-{model_class}-query-001"
        self.doc_model_encoder_engine = f"text-search-{model_class}-doc-001"
        # The GPT-2 tokenizer is used locally by `_ensure_text_limit` to truncate texts before sending them.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def _ensure_text_limit(self, text: str) -> str:
        """
        Ensure that length of the text is within the maximum length of the model.
        OpenAI embedding models have a limit of 2048 tokens

        :param text: The text to truncate.
        :return: The text re-decoded from at most `self.max_seq_len` of its tokens.
        """
        tokenized_payload = self.tokenizer(text)
        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

    def embed(self, model: str, text: List[str]) -> np.ndarray:
        """
        Request embeddings for a list of texts in a single API call.

        :param model: Name of the OpenAI engine to use.
        :param text: The texts to embed.
        :return: Array with one embedding per input text, in the order of the input texts.
        :raises OpenAIError: If the API answers with a status code other than 200.
        """
        payload = {"model": model, "input": text}
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)

        # Check the status before parsing the body: an error response is not guaranteed to be valid JSON,
        # and parsing it first would hide the real failure behind a JSON decoding error.
        if response.status_code != 200:
            raise OpenAIError(
                f"OpenAI returned an error.\n"
                f"Status code: {response.status_code}\n"
                f"Response body: {response.text}"
            )

        res = json.loads(response.text)

        # Each returned item carries the index of the input it belongs to; sort by it to restore input order.
        ordered_answers = sorted(res["data"], key=lambda ans: ans["index"])
        generated_embeddings = [ans["embedding"] for ans in ordered_answers]
        return np.array(generated_embeddings)

    def embed_batch(self, model: str, text: List[str]) -> np.ndarray:
        """
        Embed a list of texts in chunks of `self.batch_size`, truncating each text to the token limit first.

        :param model: Name of the OpenAI engine to use.
        :param text: The texts to embed.
        :return: The embeddings of all batches concatenated along the first axis.
        """
        all_embeddings = []
        for i in tqdm(
            range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = text[i : i + self.batch_size]
            batch_limited = [self._ensure_text_limit(content) for content in batch]
            generated_embeddings = self.embed(model, batch_limited)
            all_embeddings.append(generated_embeddings)
        return np.concatenate(all_embeddings)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """
        Embed query strings with the query engine.

        :param queries: The queries to embed.
        :return: Array with one embedding per query.
        """
        return self.embed_batch(self.query_model_encoder_engine, queries)

    def embed_documents(self, docs: List[Document]) -> np.ndarray:
        """
        Embed the `content` of each document with the document engine.

        :param docs: The documents to embed.
        :return: Array with one embedding per document.
        """
        return self.embed_batch(self.doc_model_encoder_engine, [d.content for d in docs])

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        """
        Not supported for this encoder.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError(f"Training is not implemented for {self.__class__}")

    def save(self, save_dir: Union[Path, str]):
        """
        Not supported for this encoder.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
|
||||||
|
|
||||||
|
|
||||||
# Registry mapping a model format name (for example the value returned by
# `EmbeddingRetriever._infer_model_format`) to the encoder class that handles it.
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
    "farm": _DefaultEmbeddingEncoder,
    "transformers": _DefaultEmbeddingEncoder,
    "sentence_transformers": _SentenceTransformersEmbeddingEncoder,
    "retribert": _RetribertEmbeddingEncoder,
    "openai": _OpenAIEmbeddingEncoder,
}
|
||||||
|
@ -1495,6 +1495,7 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
use_auth_token: Optional[Union[str, bool]] = None,
|
use_auth_token: Optional[Union[str, bool]] = None,
|
||||||
scale_score: bool = True,
|
scale_score: bool = True,
|
||||||
embed_meta_fields: List[str] = [],
|
embed_meta_fields: List[str] = [],
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||||||
@ -1511,6 +1512,7 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
|
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
|
||||||
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
|
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
|
||||||
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
|
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
|
||||||
|
- ``'openai'``: (will use `_OpenAIEmbeddingEncoder` as embedding encoder)
|
||||||
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
|
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
|
||||||
Options:
|
Options:
|
||||||
|
|
||||||
@ -1541,6 +1543,9 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
This approach is also used in the TableTextRetriever paper and is likely to improve
|
This approach is also used in the TableTextRetriever paper and is likely to improve
|
||||||
performance if your titles contain meaningful information for retrieval
|
performance if your titles contain meaningful information for retrieval
|
||||||
(topic, entities etc.).
|
(topic, entities etc.).
|
||||||
|
:param api_key: The OpenAI API key. Required if one wants to use OpenAI embeddings. For more
|
||||||
|
details see https://beta.openai.com/account/api-keys
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -1561,6 +1566,7 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
self.use_auth_token = use_auth_token
|
self.use_auth_token = use_auth_token
|
||||||
self.scale_score = scale_score
|
self.scale_score = scale_score
|
||||||
|
self.api_key = api_key
|
||||||
self.model_format = (
|
self.model_format = (
|
||||||
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
|
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
|
||||||
if model_format is None
|
if model_format is None
|
||||||
@ -1869,6 +1875,8 @@ class EmbeddingRetriever(DenseRetriever):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
|
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
|
||||||
|
if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
|
||||||
|
return "openai"
|
||||||
# Check if model name is a local directory with sentence transformers config file in it
|
# Check if model name is a local directory with sentence transformers config file in it
|
||||||
if Path(model_name_or_path).exists():
|
if Path(model_name_or_path).exists():
|
||||||
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
|
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
|
||||||
|
@ -818,6 +818,13 @@ def get_retriever(retriever_type, document_store):
|
|||||||
retriever = EmbeddingRetriever(
|
retriever = EmbeddingRetriever(
|
||||||
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
|
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
|
||||||
)
|
)
|
||||||
|
elif retriever_type == "openai":
|
||||||
|
retriever = EmbeddingRetriever(
|
||||||
|
document_store=document_store,
|
||||||
|
embedding_model="ada",
|
||||||
|
use_gpu=False,
|
||||||
|
api_key=os.environ.get("OPENAI_API_KEY", ""),
|
||||||
|
)
|
||||||
elif retriever_type == "dpr_lfqa":
|
elif retriever_type == "dpr_lfqa":
|
||||||
retriever = DensePassageRetriever(
|
retriever = DensePassageRetriever(
|
||||||
document_store=document_store,
|
document_store=document_store,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from math import isclose
|
from math import isclose
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -10,6 +11,7 @@ from elasticsearch import Elasticsearch
|
|||||||
|
|
||||||
from haystack.document_stores import WeaviateDocumentStore
|
from haystack.document_stores import WeaviateDocumentStore
|
||||||
from haystack.nodes.retriever.base import BaseRetriever
|
from haystack.nodes.retriever.base import BaseRetriever
|
||||||
|
from haystack.pipelines import DocumentSearchPipeline
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||||
from haystack.document_stores.faiss import FAISSDocumentStore
|
from haystack.document_stores.faiss import FAISSDocumentStore
|
||||||
@ -216,6 +218,45 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids):
|
|||||||
assert isclose(embedding[0], expected_value, rel_tol=0.001)
|
assert isclose(embedding[0], expected_value, rel_tol=0.001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_embedding(document_store, retriever, docs_with_ids):
    """Check that documents embedded through the OpenAI retriever get 1024-dimensional embeddings."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()

    # Guard against a vacuous pass: without this, an empty result would skip the loop below entirely.
    assert docs, "Expected the document store to return the written documents"

    for doc in docs:
        assert len(doc.embedding) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_search(document_store, retriever, docs_with_ids):
    """Run a one-hit document search for "Madrid" and check that the returned document mentions it."""
    # Index the fixture documents and compute their embeddings with the retriever under test.
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    search_pipeline = DocumentSearchPipeline(retriever)
    search_params = {"Retriever": {"top_k": 1}}
    result = search_pipeline.run(query="Madrid", params=search_params)

    hits = result["documents"]
    assert len(hits) == 1
    assert "Madrid" in hits[0].content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
|
||||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user