mirror of https://github.com/deepset-ai/haystack.git
synced 2025-08-28 10:26:27 +00:00
feat: Add CohereEmbeddingEncoder to EmbeddingRetriever (#3453)
This commit is contained in:
parent 7b15799853
commit 5ca96357ff

.github/workflows/tests.yml (vendored)
@@ -25,6 +25,7 @@ env:
     --ignore=test/nodes/test_summarizer_translation.py
     --ignore=test/nodes/test_summarizer.py
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}

 jobs:
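The new secret mirrors the existing OPENAI_API_KEY entry so that the Cohere integration tests added further down can run in CI.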

haystack/errors.py

@@ -128,3 +128,11 @@ class OpenAIRateLimitError(OpenAIError):

     def __init__(self, message: Optional[str] = None):
         super().__init__(message=message, status_code=429)
+
+
+class CohereError(NodeError):
+    """Exception for issues that occur in the Cohere APIs"""
+
+    def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None):
+        super().__init__(message=message)
+        self.status_code = status_code
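A minimal sketch of how the new exception behaves, illustrative only and not part of the diff:

    from haystack.errors import CohereError

    try:
        # Simulate a failed Cohere API call
        raise CohereError("invalid api token", status_code=401)
    except CohereError as e:
        # status_code carries the HTTP status returned by the Cohere API
        print(f"Cohere API call failed (status {e.status_code}): {e}")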

haystack/nodes/retriever/_embedding_encoder.py

@@ -13,7 +13,7 @@ from torch.utils.data.sampler import SequentialSampler
 from tqdm.auto import tqdm
 from transformers import AutoModel, AutoTokenizer

-from haystack.errors import OpenAIError, OpenAIRateLimitError
+from haystack.errors import OpenAIError, OpenAIRateLimitError, CohereError
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
 from haystack.modeling.infer import Inferencer
@@ -386,11 +386,12 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):

 class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
-        # pretrained embedding models coming from:
-        self.max_seq_len = retriever.max_seq_len
+        # See https://beta.openai.com/docs/guides/embeddings for more details
+        # OpenAI has a max seq length of 2048 tokens and unknown max batch size
+        self.max_seq_len = min(2048, retriever.max_seq_len)
         self.url = "https://api.openai.com/v1/embeddings"
         self.api_key = retriever.api_key
-        self.batch_size = retriever.batch_size
+        self.batch_size = min(64, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
         model_class: str = next(
             (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage"
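The next(...) expression above picks the first model family whose name appears as a substring of the configured model name, falling back to "babbage". A tiny illustration, using a hypothetical model name:

    # Hypothetical model name, for illustration only
    name = "text-search-curie-query-001"
    model_class = next((m for m in ["ada", "babbage", "davinci", "curie"] if m in name), "babbage")
    assert model_class == "curie"  # "babbage" would be used if no family name matched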
@@ -463,10 +464,73 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
         raise NotImplementedError(f"Saving is not implemented for {self.__class__}")


+class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
+    def __init__(self, retriever: "EmbeddingRetriever"):
+        # See https://docs.cohere.ai/embed-reference/ for more details
+        # Cohere has a max seq length of 4096 tokens and a max batch size of 16
+        self.max_seq_len = min(4096, retriever.max_seq_len)
+        self.url = "https://api.cohere.ai/embed"
+        self.api_key = retriever.api_key
+        self.batch_size = min(16, retriever.batch_size)
+        self.progress_bar = retriever.progress_bar
+        self.model: str = next((m for m in ["small", "medium", "large"] if m in retriever.embedding_model), "large")
+        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+    def _ensure_text_limit(self, text: str) -> str:
+        """
+        Ensure that length of the text is within the maximum length of the model.
+        Cohere embedding models have a limit of 4096 tokens
+        """
+        tokenized_payload = self.tokenizer(text)
+        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])
+
+    @retry_with_exponential_backoff(backoff_in_seconds=10, max_retries=5, errors=(CohereError,))
+    def embed(self, model: str, text: List[str]) -> np.ndarray:
+        payload = {"model": model, "texts": text}
+        headers = {"Authorization": f"BEARER {self.api_key}", "Content-Type": "application/json"}
+        response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
+        res = json.loads(response.text)
+        if response.status_code != 200:
+            raise CohereError(response.text, status_code=response.status_code)
+        generated_embeddings = [e for e in res["embeddings"]]
+        return np.array(generated_embeddings)
+
+    def embed_batch(self, text: List[str]) -> np.ndarray:
+        all_embeddings = []
+        for i in tqdm(
+            range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
+        ):
+            batch = text[i : i + self.batch_size]
+            batch_limited = [self._ensure_text_limit(content) for content in batch]
+            generated_embeddings = self.embed(self.model, batch_limited)
+            all_embeddings.append(generated_embeddings)
+        return np.concatenate(all_embeddings)
+
+    def embed_queries(self, queries: List[str]) -> np.ndarray:
+        return self.embed_batch(queries)
+
+    def embed_documents(self, docs: List[Document]) -> np.ndarray:
+        return self.embed_batch([d.content for d in docs])
+
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ):
+        raise NotImplementedError(f"Training is not implemented for {self.__class__}")
+
+    def save(self, save_dir: Union[Path, str]):
+        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
+
+
 _EMBEDDING_ENCODERS: Dict[str, Callable] = {
     "farm": _DefaultEmbeddingEncoder,
     "transformers": _DefaultEmbeddingEncoder,
     "sentence_transformers": _SentenceTransformersEmbeddingEncoder,
     "retribert": _RetribertEmbeddingEncoder,
     "openai": _OpenAIEmbeddingEncoder,
+    "cohere": _CohereEmbeddingEncoder,
 }
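With the encoder registered, the new path can be exercised end to end. A minimal usage sketch, assuming an in-memory document store and Cohere's "small" model (1024-dimensional embeddings, matching the tests below); the API key is a placeholder:

    from haystack.document_stores import InMemoryDocumentStore
    from haystack.nodes import EmbeddingRetriever
    from haystack.schema import Document

    document_store = InMemoryDocumentStore(embedding_dim=1024)
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="small",  # resolved to model_format="cohere" by _infer_model_format (see below)
        api_key="YOUR_COHERE_API_KEY",  # placeholder
        use_gpu=False,
    )
    document_store.write_documents([Document(content="Haystack is an open source NLP framework.")])
    document_store.update_embeddings(retriever=retriever)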

haystack/nodes/retriever/dense.py

@@ -1499,7 +1499,10 @@ class EmbeddingRetriever(DenseRetriever):
     ):
         """
         :param document_store: An instance of DocumentStore from which to retrieve documents.
-        :param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``
+        :param embedding_model: Local path or name of model in Hugging Face's model hub such
+            as ``'sentence-transformers/all-MiniLM-L6-v2'``. The embedding model could also
+            potentially be an OpenAI model ["ada", "babbage", "davinci", "curie"] or
+            a Cohere model ["small", "medium", "large"].
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
         :param batch_size: Number of documents to encode at once.

@@ -1513,6 +1516,7 @@ class EmbeddingRetriever(DenseRetriever):
            - ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
            - ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
            - ``'openai'``: (will use `_OpenAIEmbeddingEncoder` as embedding encoder)
+           - ``'cohere'``: (will use `_CohereEmbeddingEncoder` as embedding encoder)
         :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
                                  Options:

@@ -1543,8 +1547,8 @@ class EmbeddingRetriever(DenseRetriever):
                                  This approach is also used in the TableTextRetriever paper and is likely to improve
                                  performance if your titles contain meaningful information for retrieval
                                  (topic, entities etc.).
-        :param api_key: The OpenAI API key. Required if one wants to use OpenAI embeddings. For more
-                        details see https://beta.openai.com/account/api-keys
+        :param api_key: The OpenAI API key or the Cohere API key. Required if one wants to use OpenAI/Cohere embeddings.
+                        For more details see https://beta.openai.com/account/api-keys and https://dashboard.cohere.ai/api-keys

         """
         super().__init__()

@@ -1877,6 +1881,8 @@ class EmbeddingRetriever(DenseRetriever):
 def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
     if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
         return "openai"
+    if model_name_or_path in ["small", "medium", "large"]:
+        return "cohere"
     # Check if model name is a local directory with sentence transformers config file in it
     if Path(model_name_or_path).exists():
         if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
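Illustrative behavior of the updated format inference, a sketch rather than part of the diff:

    assert _infer_model_format("small", use_auth_token=None) == "cohere"
    assert _infer_model_format("ada", use_auth_token=None) == "openai"
    # Anything else falls through to the sentence-transformers / local-path checks below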

test/conftest.py

@@ -825,6 +825,13 @@ def get_retriever(retriever_type, document_store):
             use_gpu=False,
             api_key=os.environ.get("OPENAI_API_KEY", ""),
         )
+    elif retriever_type == "cohere":
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="small",
+            use_gpu=False,
+            api_key=os.environ.get("COHERE_API_KEY", ""),
+        )
     elif retriever_type == "dpr_lfqa":
         retriever = DensePassageRetriever(
             document_store=document_store,

test/nodes/test_retriever.py

@@ -266,13 +266,14 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids):

 @pytest.mark.integration
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
-@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
+@pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
 @pytest.mark.embedding_dim(1024)
 @pytest.mark.skipif(
-    not os.environ.get("OPENAI_API_KEY", None),
-    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
+    reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
+    "the OpenAI/Cohere API key to run this test.",
 )
-def test_openai_embedding(document_store, retriever, docs_with_ids):
+def test_basic_embedding(document_store, retriever, docs_with_ids):
     document_store.return_embedding = True
     document_store.write_documents(docs_with_ids)
     document_store.update_embeddings(retriever=retriever)

@@ -286,11 +287,12 @@ def test_openai_embedding(document_store, retriever, docs_with_ids):

 @pytest.mark.integration
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
-@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
+@pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
 @pytest.mark.embedding_dim(1024)
 @pytest.mark.skipif(
-    not os.environ.get("OPENAI_API_KEY", None),
-    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
+    reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
+    "the OpenAI/Cohere API key to run this test.",
 )
 def test_retriever_basic_search(document_store, retriever, docs_with_ids):
     document_store.return_embedding = True