mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 10:49:30 +00:00
fix: warning if doc store similarity function is incompatible with Sentence Transformers model (#3455)
* check_docstore_similarity_function * remove import
This commit is contained in:
parent
54ec13eaf7
commit
a2d459dbed
@ -89,6 +89,33 @@ class _BaseEmbeddingEncoder:
|
||||
"""
|
||||
pass
|
||||
|
||||
def _check_docstore_similarity_function(self, retriever: "EmbeddingRetriever"):
|
||||
"""
|
||||
Check that document_store uses a similarity function
|
||||
compatible with the embedding model
|
||||
"""
|
||||
docstore_similarity = retriever.document_store.similarity
|
||||
model_name = retriever.embedding_model
|
||||
|
||||
if "sentence-transformers" in model_name.lower():
|
||||
model_similarity = None
|
||||
if "-cos-" in model_name.lower():
|
||||
model_similarity = "cosine"
|
||||
elif "-dot-" in model_name.lower():
|
||||
model_similarity = "dot_product"
|
||||
|
||||
if model_similarity is not None and docstore_similarity != model_similarity:
|
||||
logger.warning(
|
||||
f"You seem to be using {model_name} model with the {docstore_similarity} function instead of the recommended {model_similarity}. "
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
elif "dpr" in model_name.lower() and docstore_similarity != "dot_product":
|
||||
logger.warning(
|
||||
f"You seem to be using a DPR model with the {docstore_similarity} function. "
|
||||
f"We recommend using dot_product instead. "
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
|
||||
|
||||
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
def __init__(self, retriever: "EmbeddingRetriever"):
|
||||
@ -105,21 +132,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
num_processes=0,
|
||||
use_auth_token=retriever.use_auth_token,
|
||||
)
|
||||
# Check that document_store has the right similarity function
|
||||
similarity = retriever.document_store.similarity
|
||||
# If we are using a sentence transformer model
|
||||
if "sentence" in retriever.embedding_model.lower() and similarity != "cosine":
|
||||
logger.warning(
|
||||
f"You seem to be using a Sentence Transformer with the {similarity} function. "
|
||||
f"We recommend using cosine instead. "
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
elif "dpr" in retriever.embedding_model.lower() and similarity != "dot_product":
|
||||
logger.warning(
|
||||
f"You seem to be using a DPR model with the {similarity} function. "
|
||||
f"We recommend using dot_product instead. "
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
self._check_docstore_similarity_function(retriever)
|
||||
|
||||
def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
|
||||
# TODO: FARM's `sample_to_features_text` need to fix following warning -
|
||||
@ -182,13 +195,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
self.batch_size = retriever.batch_size
|
||||
self.embedding_model.max_seq_length = retriever.max_seq_len
|
||||
self.show_progress_bar = retriever.progress_bar
|
||||
document_store = retriever.document_store
|
||||
if document_store.similarity != "cosine":
|
||||
logger.warning(
|
||||
f"You are using a Sentence Transformer with the {document_store.similarity} function. "
|
||||
f"We recommend using cosine instead. "
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
self._check_docstore_similarity_function(retriever)
|
||||
|
||||
def embed(self, texts: Union[List[str], str]) -> np.ndarray:
|
||||
# texts can be a list of strings
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user