mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 10:49:30 +00:00
Update cohere embedding models (#3704)
This commit is contained in:
parent
4afdbc33b2
commit
42926596e4
@ -477,7 +477,14 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
self.api_key = retriever.api_key
|
||||
self.batch_size = min(16, retriever.batch_size)
|
||||
self.progress_bar = retriever.progress_bar
|
||||
self.model: str = next((m for m in ["small", "medium", "large"] if m in retriever.embedding_model), "large")
|
||||
self.model: str = next(
|
||||
(
|
||||
m
|
||||
for m in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]
|
||||
if m in retriever.embedding_model
|
||||
),
|
||||
"multilingual-22-12",
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
def _ensure_text_limit(self, text: str) -> str:
|
||||
|
||||
@ -1819,7 +1819,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
|
||||
if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
|
||||
return "openai"
|
||||
if model_name_or_path in ["small", "medium", "large"]:
|
||||
if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
|
||||
return "cohere"
|
||||
# Check if model name is a local directory with sentence transformers config file in it
|
||||
if Path(model_name_or_path).exists():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user