diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 4ed663c4a..45255136b 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -477,7 +477,14 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         self.api_key = retriever.api_key
         self.batch_size = min(16, retriever.batch_size)
         self.progress_bar = retriever.progress_bar
-        self.model: str = next((m for m in ["small", "medium", "large"] if m in retriever.embedding_model), "large")
+        self.model: str = next(
+            (
+                m
+                for m in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]
+                if m in retriever.embedding_model
+            ),
+            "multilingual-22-12",
+        )
         self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

     def _ensure_text_limit(self, text: str) -> str:
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index bd99a7ccd..125fda861 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1819,7 +1819,7 @@ class EmbeddingRetriever(DenseRetriever):
     def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
         if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
             return "openai"
-        if model_name_or_path in ["small", "medium", "large"]:
+        if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
            return "cohere"
         # Check if model name is a local directory with sentence transformers config file in it
         if Path(model_name_or_path).exists():
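
For context, the two changed code paths reduce to the detection logic sketched below. This is a standalone illustration, not the actual Haystack classes; the helper names `pick_cohere_model` and `infer_model_format` and the final `"sentence_transformers"` fallback are illustrative assumptions.

from typing import List

COHERE_MODELS: List[str] = ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]
OPENAI_MODELS: List[str] = ["ada", "babbage", "davinci", "curie"]


def pick_cohere_model(embedding_model: str) -> str:
    # Mirrors the _CohereEmbeddingEncoder change: take the first known Cohere model
    # name that appears as a substring of the configured embedding model, and fall
    # back to "multilingual-22-12" when none matches.
    return next((m for m in COHERE_MODELS if m in embedding_model), "multilingual-22-12")


def infer_model_format(model_name_or_path: str) -> str:
    # Mirrors the EmbeddingRetriever._infer_model_format change: OpenAI models are
    # matched by substring, Cohere models by exact name (now including the two new entries).
    if any(m in model_name_or_path for m in OPENAI_MODELS):
        return "openai"
    if model_name_or_path in COHERE_MODELS:
        return "cohere"
    return "sentence_transformers"  # placeholder for the remaining detection branches


assert infer_model_format("multilingual-22-12") == "cohere"
assert pick_cohere_model("multilingual-22-12") == "multilingual-22-12"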