feat: Add CohereEmbeddingEncoder to EmbeddingRetriever (#3453)

This commit is contained in:
Vladimir Blagojevic 2022-10-25 17:52:29 +02:00 committed by GitHub
parent 7b15799853
commit 5ca96357ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 102 additions and 14 deletions

View File

@ -25,6 +25,7 @@ env:
--ignore=test/nodes/test_summarizer_translation.py --ignore=test/nodes/test_summarizer_translation.py
--ignore=test/nodes/test_summarizer.py --ignore=test/nodes/test_summarizer.py
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
jobs: jobs:

View File

@ -128,3 +128,11 @@ class OpenAIRateLimitError(OpenAIError):
def __init__(self, message: Optional[str] = None): def __init__(self, message: Optional[str] = None):
super().__init__(message=message, status_code=429) super().__init__(message=message, status_code=429)
class CohereError(NodeError):
    """
    Exception for issues that occur in the Cohere APIs.

    :param message: Optional description of the failure; forwarded to the parent error class.
    :param status_code: Optional status code reported for the failed call (for example the
                        HTTP status of the API response). Stored on the instance as
                        ``status_code``.
    """

    def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None):
        # The parent class only receives the message; the status code is kept here.
        super().__init__(message=message)
        self.status_code = status_code

View File

@ -13,7 +13,7 @@ from torch.utils.data.sampler import SequentialSampler
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from haystack.errors import OpenAIError, OpenAIRateLimitError from haystack.errors import OpenAIError, OpenAIRateLimitError, CohereError
from haystack.modeling.data_handler.dataloader import NamedDataLoader from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
from haystack.modeling.infer import Inferencer from haystack.modeling.infer import Inferencer
@ -386,11 +386,12 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder): class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"): def __init__(self, retriever: "EmbeddingRetriever"):
# pretrained embedding models coming from: # See https://beta.openai.com/docs/guides/embeddings for more details
self.max_seq_len = retriever.max_seq_len # OpenAI has a max seq length of 2048 tokens and unknown max batch size
self.max_seq_len = min(2048, retriever.max_seq_len)
self.url = "https://api.openai.com/v1/embeddings" self.url = "https://api.openai.com/v1/embeddings"
self.api_key = retriever.api_key self.api_key = retriever.api_key
self.batch_size = retriever.batch_size self.batch_size = min(64, retriever.batch_size)
self.progress_bar = retriever.progress_bar self.progress_bar = retriever.progress_bar
model_class: str = next( model_class: str = next(
(m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage" (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage"
@ -463,10 +464,73 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
raise NotImplementedError(f"Saving is not implemented for {self.__class__}") raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
    """
    Embedding encoder that obtains embeddings from the Cohere embed endpoint.

    Texts are truncated to ``max_seq_len`` tokens (counted with the ``gpt2`` tokenizer) and
    sent to the API in batches of at most ``batch_size`` texts.
    """

    def __init__(self, retriever: "EmbeddingRetriever"):
        """
        :param retriever: The retriever providing ``max_seq_len``, ``batch_size``, ``api_key``,
                          ``progress_bar`` and ``embedding_model``.
        """
        # See https://docs.cohere.ai/embed-reference/ for more details
        # Cohere has a max seq length of 4096 tokens and a max batch size of 16
        self.max_seq_len = min(4096, retriever.max_seq_len)
        self.url = "https://api.cohere.ai/embed"
        self.api_key = retriever.api_key
        self.batch_size = min(16, retriever.batch_size)
        self.progress_bar = retriever.progress_bar
        # Pick the first known model size contained in the model name; default to "large".
        self.model: str = next((m for m in ["small", "medium", "large"] if m in retriever.embedding_model), "large")
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def _ensure_text_limit(self, text: str) -> str:
        """
        Ensure that length of the text is within the maximum length of the model.
        Cohere embedding models have a limit of 4096 tokens

        :param text: The text to truncate.
        :return: The text decoded from at most ``max_seq_len`` of its tokens.
        """
        tokenized_payload = self.tokenizer(text)
        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

    @retry_with_exponential_backoff(backoff_in_seconds=10, max_retries=5, errors=(CohereError,))
    def embed(self, model: str, text: List[str]) -> np.ndarray:
        """
        Request embeddings for a single batch of texts from the Cohere API.

        :param model: Name of the Cohere model to use.
        :param text: The texts to embed.
        :return: Array built from the ``embeddings`` field of the API response.
        :raises CohereError: If the API responds with a status code other than 200.
        """
        payload = {"model": model, "texts": text}
        headers = {"Authorization": f"BEARER {self.api_key}", "Content-Type": "application/json"}
        response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
        # Check the status before parsing the body: an error response is not guaranteed to
        # contain JSON, and a JSON decoding error would mask the CohereError that the retry
        # decorator is configured for.
        if response.status_code != 200:
            raise CohereError(response.text, status_code=response.status_code)
        res = json.loads(response.text)
        return np.array(res["embeddings"])

    def embed_batch(self, text: List[str]) -> np.ndarray:
        """
        Embed an arbitrary number of texts by splitting them into API-sized batches.

        :param text: The texts to embed.
        :return: The per-batch embedding arrays concatenated along the first axis.
        """
        # NOTE(review): an empty `text` list leaves `all_embeddings` empty, and
        # np.concatenate raises ValueError for an empty sequence.
        all_embeddings = []
        for i in tqdm(
            range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = text[i : i + self.batch_size]
            batch_limited = [self._ensure_text_limit(content) for content in batch]
            generated_embeddings = self.embed(self.model, batch_limited)
            all_embeddings.append(generated_embeddings)
        return np.concatenate(all_embeddings)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """
        :param queries: The query strings to embed.
        :return: Embeddings of the queries, as returned by ``embed_batch``.
        """
        return self.embed_batch(queries)

    def embed_documents(self, docs: List[Document]) -> np.ndarray:
        """
        :param docs: The documents whose ``content`` is embedded.
        :return: Embeddings of the documents, as returned by ``embed_batch``.
        """
        return self.embed_batch([d.content for d in docs])

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        """Training is not supported by this encoder; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"Training is not implemented for {self.__class__}")

    def save(self, save_dir: Union[Path, str]):
        """Saving is not supported by this encoder; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")
_EMBEDDING_ENCODERS: Dict[str, Callable] = { _EMBEDDING_ENCODERS: Dict[str, Callable] = {
"farm": _DefaultEmbeddingEncoder, "farm": _DefaultEmbeddingEncoder,
"transformers": _DefaultEmbeddingEncoder, "transformers": _DefaultEmbeddingEncoder,
"sentence_transformers": _SentenceTransformersEmbeddingEncoder, "sentence_transformers": _SentenceTransformersEmbeddingEncoder,
"retribert": _RetribertEmbeddingEncoder, "retribert": _RetribertEmbeddingEncoder,
"openai": _OpenAIEmbeddingEncoder, "openai": _OpenAIEmbeddingEncoder,
"cohere": _CohereEmbeddingEncoder,
} }

View File

@ -1499,7 +1499,10 @@ class EmbeddingRetriever(DenseRetriever):
): ):
""" """
:param document_store: An instance of DocumentStore from which to retrieve documents. :param document_store: An instance of DocumentStore from which to retrieve documents.
:param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'`` :param embedding_model: Local path or name of model in Hugging Face's model hub such
as ``'sentence-transformers/all-MiniLM-L6-v2'``. The embedding model can also
be an OpenAI model ["ada", "babbage", "davinci", "curie"] or
a Cohere model ["small", "medium", "large"].
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
:param batch_size: Number of documents to encode at once. :param batch_size: Number of documents to encode at once.
@ -1513,6 +1516,7 @@ class EmbeddingRetriever(DenseRetriever):
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder) - ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder) - ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
- ``'openai'``: (will use `_OpenAIEmbeddingEncoder` as embedding encoder) - ``'openai'``: (will use `_OpenAIEmbeddingEncoder` as embedding encoder)
- ``'cohere'``: (will use `_CohereEmbeddingEncoder` as embedding encoder)
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only). :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
Options: Options:
@ -1543,8 +1547,8 @@ class EmbeddingRetriever(DenseRetriever):
This approach is also used in the TableTextRetriever paper and is likely to improve This approach is also used in the TableTextRetriever paper and is likely to improve
performance if your titles contain meaningful information for retrieval performance if your titles contain meaningful information for retrieval
(topic, entities etc.). (topic, entities etc.).
:param api_key: The OpenAI API key. Required if one wants to use OpenAI embeddings. For more :param api_key: The OpenAI API key or the Cohere API key. Required if one wants to use OpenAI/Cohere embeddings.
details see https://beta.openai.com/account/api-keys For more details see https://beta.openai.com/account/api-keys and https://dashboard.cohere.ai/api-keys
""" """
super().__init__() super().__init__()
@ -1877,6 +1881,8 @@ class EmbeddingRetriever(DenseRetriever):
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str: def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]): if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
return "openai" return "openai"
if model_name_or_path in ["small", "medium", "large"]:
return "cohere"
# Check if model name is a local directory with sentence transformers config file in it # Check if model name is a local directory with sentence transformers config file in it
if Path(model_name_or_path).exists(): if Path(model_name_or_path).exists():
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists(): if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():

View File

@ -825,6 +825,13 @@ def get_retriever(retriever_type, document_store):
use_gpu=False, use_gpu=False,
api_key=os.environ.get("OPENAI_API_KEY", ""), api_key=os.environ.get("OPENAI_API_KEY", ""),
) )
elif retriever_type == "cohere":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="small",
use_gpu=False,
api_key=os.environ.get("COHERE_API_KEY", ""),
)
elif retriever_type == "dpr_lfqa": elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever( retriever = DensePassageRetriever(
document_store=document_store, document_store=document_store,

View File

@ -266,13 +266,14 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids):
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True) @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True) @pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
@pytest.mark.embedding_dim(1024) @pytest.mark.embedding_dim(1024)
@pytest.mark.skipif( @pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None), not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
"the OpenAI/Cohere API key to run this test.",
) )
def test_openai_embedding(document_store, retriever, docs_with_ids): def test_basic_embedding(document_store, retriever, docs_with_ids):
document_store.return_embedding = True document_store.return_embedding = True
document_store.write_documents(docs_with_ids) document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever) document_store.update_embeddings(retriever=retriever)
@ -286,11 +287,12 @@ def test_openai_embedding(document_store, retriever, docs_with_ids):
@pytest.mark.integration @pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True) @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True) @pytest.mark.parametrize("retriever", ["openai", "cohere"], indirect=True)
@pytest.mark.embedding_dim(1024) @pytest.mark.embedding_dim(1024)
@pytest.mark.skipif( @pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None), not os.environ.get("OPENAI_API_KEY", None) and not os.environ.get("COHERE_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", reason="Please export an env var called OPENAI_API_KEY/COHERE_API_KEY containing "
"the OpenAI/Cohere API key to run this test.",
) )
def test_retriever_basic_search(document_store, retriever, docs_with_ids): def test_retriever_basic_search(document_store, retriever, docs_with_ids):
document_store.return_embedding = True document_store.return_embedding = True