From a2d459dbed175ef2335d74af083c2b1b0a7a7b93 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Tue, 25 Oct 2022 17:00:35 +0200 Subject: [PATCH] fix: warning if doc store similarity function is incompatible with Sentence Transformers model (#3455) * check_docstore_similarity_function * remove import --- .../nodes/retriever/_embedding_encoder.py | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py index a0612d6f4..f4f313d3d 100644 --- a/haystack/nodes/retriever/_embedding_encoder.py +++ b/haystack/nodes/retriever/_embedding_encoder.py @@ -89,6 +89,33 @@ class _BaseEmbeddingEncoder: """ pass + def _check_docstore_similarity_function(self, retriever: "EmbeddingRetriever"): + """ + Check that document_store uses a similarity function + compatible with the embedding model + """ + docstore_similarity = retriever.document_store.similarity + model_name = retriever.embedding_model + + if "sentence-transformers" in model_name.lower(): + model_similarity = None + if "-cos-" in model_name.lower(): + model_similarity = "cosine" + elif "-dot-" in model_name.lower(): + model_similarity = "dot_product" + + if model_similarity is not None and docstore_similarity != model_similarity: + logger.warning( + f"You seem to be using {model_name} model with the {docstore_similarity} function instead of the recommended {model_similarity}. " + f"This can be set when initializing the DocumentStore" + ) + elif "dpr" in model_name.lower() and docstore_similarity != "dot_product": + logger.warning( + f"You seem to be using a DPR model with the {docstore_similarity} function. " + f"We recommend using dot_product instead. " + f"This can be set when initializing the DocumentStore" + ) + class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder): def __init__(self, retriever: "EmbeddingRetriever"): @@ -105,21 +132,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder): num_processes=0, use_auth_token=retriever.use_auth_token, ) - # Check that document_store has the right similarity function - similarity = retriever.document_store.similarity - # If we are using a sentence transformer model - if "sentence" in retriever.embedding_model.lower() and similarity != "cosine": - logger.warning( - f"You seem to be using a Sentence Transformer with the {similarity} function. " - f"We recommend using cosine instead. " - f"This can be set when initializing the DocumentStore" - ) - elif "dpr" in retriever.embedding_model.lower() and similarity != "dot_product": - logger.warning( - f"You seem to be using a DPR model with the {similarity} function. " - f"We recommend using dot_product instead. " - f"This can be set when initializing the DocumentStore" - ) + self._check_docstore_similarity_function(retriever) def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray: # TODO: FARM's `sample_to_features_text` need to fix following warning - @@ -182,13 +195,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder): self.batch_size = retriever.batch_size self.embedding_model.max_seq_length = retriever.max_seq_len self.show_progress_bar = retriever.progress_bar - document_store = retriever.document_store - if document_store.similarity != "cosine": - logger.warning( - f"You are using a Sentence Transformer with the {document_store.similarity} function. " - f"We recommend using cosine instead. " - f"This can be set when initializing the DocumentStore" - ) + self._check_docstore_similarity_function(retriever) def embed(self, texts: Union[List[str], str]) -> np.ndarray: # texts can be a list of strings