Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-29 16:08:38 +00:00)
feat: Use truncate option for Cohere.embed (#3865)
* Use truncate option for the Cohere request instead of the GPT2 tokenizer to truncate texts
* Update max batch size for Cohere, which is 96

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
parent 04deb3b535
commit d2bba4935b
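In practice, the change swaps client-side truncation (tokenizing locally with GPT-2 and slicing to max_seq_len) for the Cohere API's server-side truncate parameter. Below is a minimal standalone sketch of the new request shape, assuming a hypothetical COHERE_API_KEY environment variable in place of the key the retriever normally carries:

import json
import os

import requests

# "truncate": "END" asks Cohere's embed endpoint to drop tokens beyond the
# model's 4096-token limit on the server side, replacing the local GPT-2
# tokenizer pass that this commit removes.
payload = {
    "model": "multilingual-22-12",
    "texts": ["a very long document ..."],
    "truncate": "END",
}
headers = {
    "Authorization": f"BEARER {os.environ['COHERE_API_KEY']}",  # assumed env var, not from the diff
    "Content-Type": "application/json",
}
response = requests.request("POST", "https://api.cohere.ai/embed", headers=headers, data=json.dumps(payload), timeout=30)
embeddings = json.loads(response.text)["embeddings"]  # one embedding vector per input text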
@@ -487,11 +487,11 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
 class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
         # See https://docs.cohere.ai/embed-reference/ for more details
-        # Cohere has a max seq length of 4096 tokens and a max batch size of 16
+        # Cohere has a max seq length of 4096 tokens and a max batch size of 96
         self.max_seq_len = min(4096, retriever.max_seq_len)
         self.url = "https://api.cohere.ai/embed"
         self.api_key = retriever.api_key
-        self.batch_size = min(16, retriever.batch_size)
+        self.batch_size = min(96, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
         self.model: str = next(
             (
@@ -501,19 +501,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
             ),
             "multilingual-22-12",
         )
-        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-    def _ensure_text_limit(self, text: str) -> str:
-        """
-        Ensure that length of the text is within the maximum length of the model.
-        Cohere embedding models have a limit of 4096 tokens
-        """
-        tokenized_payload = self.tokenizer(text)
-        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

     @retry_with_exponential_backoff(backoff_in_seconds=10, max_retries=5, errors=(CohereError,))
     def embed(self, model: str, text: List[str]) -> np.ndarray:
-        payload = {"model": model, "texts": text}
+        payload = {"model": model, "texts": text, "truncate": "END"}
         headers = {"Authorization": f"BEARER {self.api_key}", "Content-Type": "application/json"}
         response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
         res = json.loads(response.text)
@@ -528,8 +519,7 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
             range(0, len(text), self.batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = text[i : i + self.batch_size]
-            batch_limited = [self._ensure_text_limit(content) for content in batch]
-            generated_embeddings = self.embed(self.model, batch_limited)
+            generated_embeddings = self.embed(self.model, batch)
             all_embeddings.append(generated_embeddings)
         return np.concatenate(all_embeddings)
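With the batch cap raised from 16 to 96, the loop in the last hunk now sends up to 96 texts per request, unmodified, and stacks the per-batch results. A rough sketch of that batching pattern in isolation, where embed_batch is a hypothetical stand-in for the embed call shown above:

import numpy as np

def embed_all(texts, embed_batch, batch_size=96):
    # Chunk the input into slices of at most batch_size texts (96 is the
    # Cohere API's maximum batch size) and concatenate the per-batch
    # embedding arrays into one matrix, one row per input text.
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        all_embeddings.append(embed_batch(texts[i : i + batch_size]))
    return np.concatenate(all_embeddings)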