feat: Use truncate option for Cohere.embed (#3865)

* Use the truncate option for the Cohere request instead of the GPT-2 tokenizer to truncate texts

* Update the max batch size for Cohere, which is 96

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
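
For reference, here is a minimal sketch of the raw request this change produces. The URL, headers, and payload mirror the diff below; the API key, the example text, and the response field are illustrative assumptions, not part of the patch:

import json
import requests

# Sketch of the Cohere embed request after this change (URL, headers, and payload
# mirror the diff below; the API key and the input text are placeholders).
payload = {
    "model": "multilingual-22-12",
    "texts": ["A document that may be longer than 4096 tokens ..."],
    "truncate": "END",  # let the Cohere API drop tokens past the limit instead of truncating client-side
}
headers = {"Authorization": "BEARER YOUR_COHERE_API_KEY", "Content-Type": "application/json"}
response = requests.request("POST", "https://api.cohere.ai/embed", headers=headers, data=json.dumps(payload), timeout=30)
embeddings = json.loads(response.text)["embeddings"]  # assumed response field; one vector per input text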
Sebastian 2023-01-20 09:49:55 +01:00 committed by GitHub
parent 04deb3b535
commit d2bba4935b


@@ -487,11 +487,11 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
 class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
         # See https://docs.cohere.ai/embed-reference/ for more details
-        # Cohere has a max seq length of 4096 tokens and a max batch size of 16
+        # Cohere has a max seq length of 4096 tokens and a max batch size of 96
         self.max_seq_len = min(4096, retriever.max_seq_len)
         self.url = "https://api.cohere.ai/embed"
         self.api_key = retriever.api_key
-        self.batch_size = min(16, retriever.batch_size)
+        self.batch_size = min(96, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
         self.model: str = next(
             (
@@ -501,19 +501,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
             ),
             "multilingual-22-12",
         )
-        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-    def _ensure_text_limit(self, text: str) -> str:
-        """
-        Ensure that length of the text is within the maximum length of the model.
-        Cohere embedding models have a limit of 4096 tokens
-        """
-        tokenized_payload = self.tokenizer(text)
-        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

     @retry_with_exponential_backoff(backoff_in_seconds=10, max_retries=5, errors=(CohereError,))
     def embed(self, model: str, text: List[str]) -> np.ndarray:
-        payload = {"model": model, "texts": text}
+        payload = {"model": model, "texts": text, "truncate": "END"}
         headers = {"Authorization": f"BEARER {self.api_key}", "Content-Type": "application/json"}
         response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
         res = json.loads(response.text)
@@ -528,8 +519,7 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         for i in tqdm(
             range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = text[i : i + self.batch_size]
-            batch_limited = [self._ensure_text_limit(content) for content in batch]
-            generated_embeddings = self.embed(self.model, batch_limited)
+            generated_embeddings = self.embed(self.model, batch)
             all_embeddings.append(generated_embeddings)
         return np.concatenate(all_embeddings)
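
For context, a hedged usage sketch of how this change plays out through Haystack's EmbeddingRetriever. The constructor arguments and the embed_queries call follow the Haystack v1 API of the era; treat the exact signature as an assumption rather than part of this patch:

from haystack.nodes import EmbeddingRetriever

# Hypothetical usage sketch: texts longer than the model limit no longer need
# client-side GPT-2 truncation, and up to 96 texts now go out per request.
retriever = EmbeddingRetriever(
    embedding_model="multilingual-22-12",  # selects the Cohere encoder
    api_key="YOUR_COHERE_API_KEY",
    batch_size=96,  # honored up to Cohere's max of 96 instead of being capped at 16
)
vectors = retriever.embed_queries(["What is the capital of France?"])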