mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-04 03:39:31 +00:00 
			
		
		
		
	fix: warning if doc store similarity function is incompatible with Sentence Transformers model (#3455)
* check_docstore_similarity_function * remove import
This commit is contained in:
		
							parent
							
								
									54ec13eaf7
								
							
						
					
					
						commit
						a2d459dbed
					
				@ -89,6 +89,33 @@ class _BaseEmbeddingEncoder:
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def _check_docstore_similarity_function(self, retriever: "EmbeddingRetriever"):
        """
        Warn when the document store's similarity function differs from the
        one suggested by the name of the retriever's embedding model.

        Rules applied (all matching is done on the lower-cased model name):

        - Names containing "sentence-transformers": a "-cos-" marker suggests
          "cosine" and a "-dot-" marker suggests "dot_product". If neither
          marker is present, nothing is checked.
        - Otherwise, names containing "dpr" suggest "dot_product".

        A mismatch only produces a log warning; nothing is raised or changed.

        :param retriever: Object providing `document_store.similarity` and
                          `embedding_model` (the model name).
        """
        store_similarity = retriever.document_store.similarity
        model_name = retriever.embedding_model
        # Lower-case once; every check below is case-insensitive.
        lowered_name = model_name.lower()

        if "sentence-transformers" not in lowered_name:
            # The DPR rule only applies to names outside the sentence-transformers family.
            if "dpr" in lowered_name and store_similarity != "dot_product":
                logger.warning(
                    f"You seem to be using a DPR model with the {store_similarity} function. "
                    "We recommend using dot_product instead. "
                    "This can be set when initializing the DocumentStore"
                )
            return

        # Derive the suggested similarity from the marker embedded in the model name.
        if "-cos-" in lowered_name:
            expected_similarity = "cosine"
        elif "-dot-" in lowered_name:
            expected_similarity = "dot_product"
        else:
            # No marker in the name: there is no recommendation to compare against.
            return

        if store_similarity != expected_similarity:
            logger.warning(
                f"You seem to be using {model_name} model with the {store_similarity} function instead of the recommended {expected_similarity}. "
                "This can be set when initializing the DocumentStore"
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
 | 
			
		||||
    def __init__(self, retriever: "EmbeddingRetriever"):
 | 
			
		||||
@ -105,21 +132,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
 | 
			
		||||
            num_processes=0,
 | 
			
		||||
            use_auth_token=retriever.use_auth_token,
 | 
			
		||||
        )
 | 
			
		||||
        # Check that document_store has the right similarity function
 | 
			
		||||
        similarity = retriever.document_store.similarity
 | 
			
		||||
        # If we are using a sentence transformer model
 | 
			
		||||
        if "sentence" in retriever.embedding_model.lower() and similarity != "cosine":
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"You seem to be using a Sentence Transformer with the {similarity} function. "
 | 
			
		||||
                f"We recommend using cosine instead. "
 | 
			
		||||
                f"This can be set when initializing the DocumentStore"
 | 
			
		||||
            )
 | 
			
		||||
        elif "dpr" in retriever.embedding_model.lower() and similarity != "dot_product":
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"You seem to be using a DPR model with the {similarity} function. "
 | 
			
		||||
                f"We recommend using dot_product instead. "
 | 
			
		||||
                f"This can be set when initializing the DocumentStore"
 | 
			
		||||
            )
 | 
			
		||||
        self._check_docstore_similarity_function(retriever)
 | 
			
		||||
 | 
			
		||||
    def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
 | 
			
		||||
        # TODO: FARM's `sample_to_features_text` need to fix following warning -
 | 
			
		||||
@ -182,13 +195,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
 | 
			
		||||
        self.batch_size = retriever.batch_size
 | 
			
		||||
        self.embedding_model.max_seq_length = retriever.max_seq_len
 | 
			
		||||
        self.show_progress_bar = retriever.progress_bar
 | 
			
		||||
        document_store = retriever.document_store
 | 
			
		||||
        if document_store.similarity != "cosine":
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                f"You are using a Sentence Transformer with the {document_store.similarity} function. "
 | 
			
		||||
                f"We recommend using cosine instead. "
 | 
			
		||||
                f"This can be set when initializing the DocumentStore"
 | 
			
		||||
            )
 | 
			
		||||
        self._check_docstore_similarity_function(retriever)
 | 
			
		||||
 | 
			
		||||
    def embed(self, texts: Union[List[str], str]) -> np.ndarray:
 | 
			
		||||
        # texts can be a list of strings
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user