Update cohere embedding models (#3704)

This commit is contained in:
Vladimir Blagojevic 2022-12-16 16:49:59 +01:00 committed by GitHub
parent 4afdbc33b2
commit 42926596e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 2 deletions

View File

@@ -477,7 +477,14 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
self.api_key = retriever.api_key
self.batch_size = min(16, retriever.batch_size)
self.progress_bar = retriever.progress_bar
-self.model: str = next((m for m in ["small", "medium", "large"] if m in retriever.embedding_model), "large")
+self.model: str = next(
+(
+m
+for m in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]
+if m in retriever.embedding_model
+),
+"multilingual-22-12",
+)
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
def _ensure_text_limit(self, text: str) -> str:

View File

@@ -1819,7 +1819,7 @@ class EmbeddingRetriever(DenseRetriever):
def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str:
if any(m in model_name_or_path for m in ["ada", "babbage", "davinci", "curie"]):
return "openai"
-if model_name_or_path in ["small", "medium", "large"]:
+if model_name_or_path in ["small", "medium", "large", "multilingual-22-12", "finance-sentiment"]:
return "cohere"
# Check if model name is a local directory with sentence transformers config file in it
if Path(model_name_or_path).exists():