fix: warning if doc store similarity function is incompatible with Sentence Transformers model (#3455)

* check_docstore_similarity_function

* remove import
This commit is contained in:
Stefano Fiorucci 2022-10-25 17:00:35 +02:00 committed by GitHub
parent 54ec13eaf7
commit a2d459dbed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -89,6 +89,33 @@ class _BaseEmbeddingEncoder:
"""
pass
def _check_docstore_similarity_function(self, retriever: "EmbeddingRetriever"):
    """
    Warn when the document store's similarity function does not match
    the one suggested by the embedding model's name.

    The check is purely name-based and never raises on a mismatch; it only
    logs a warning:

    - If the model name contains "sentence-transformers", the expected
      similarity is "cosine" when the name contains "-cos-" and
      "dot_product" when it contains "-dot-". Names with neither marker
      are not checked.
    - Otherwise, if the model name contains "dpr", the expected
      similarity is "dot_product".

    :param retriever: Object exposing `document_store.similarity` and
                      `embedding_model` (the model name as a string).
    """
    docstore_similarity = retriever.document_store.similarity
    model_name = retriever.embedding_model
    # All substring checks are case-insensitive, so normalise the name once.
    model_key = model_name.lower()

    if "sentence-transformers" not in model_key:
        # DPR models are only considered when the name is not a
        # sentence-transformers one.
        if "dpr" in model_key and docstore_similarity != "dot_product":
            logger.warning(
                f"You seem to be using a DPR model with the {docstore_similarity} function. "
                f"We recommend using dot_product instead. "
                f"This can be set when initializing the DocumentStore"
            )
        return

    # Sentence Transformers models: the recommended function is encoded
    # in the model name ("-cos-" / "-dot-"), with "-cos-" taking precedence.
    if "-cos-" in model_key:
        model_similarity = "cosine"
    elif "-dot-" in model_key:
        model_similarity = "dot_product"
    else:
        # No hint in the name, so there is nothing to compare against.
        return

    if docstore_similarity == model_similarity:
        return

    logger.warning(
        f"You seem to be using {model_name} model with the {docstore_similarity} function instead of the recommended {model_similarity}. "
        f"This can be set when initializing the DocumentStore"
    )
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@ -105,21 +132,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
num_processes=0,
use_auth_token=retriever.use_auth_token,
)
# Check that document_store has the right similarity function
similarity = retriever.document_store.similarity
# If we are using a sentence transformer model
if "sentence" in retriever.embedding_model.lower() and similarity != "cosine":
logger.warning(
f"You seem to be using a Sentence Transformer with the {similarity} function. "
f"We recommend using cosine instead. "
f"This can be set when initializing the DocumentStore"
)
elif "dpr" in retriever.embedding_model.lower() and similarity != "dot_product":
logger.warning(
f"You seem to be using a DPR model with the {similarity} function. "
f"We recommend using dot_product instead. "
f"This can be set when initializing the DocumentStore"
)
self._check_docstore_similarity_function(retriever)
def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
# TODO: FARM's `sample_to_features_text` need to fix following warning -
@ -182,13 +195,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
self.batch_size = retriever.batch_size
self.embedding_model.max_seq_length = retriever.max_seq_len
self.show_progress_bar = retriever.progress_bar
document_store = retriever.document_store
if document_store.similarity != "cosine":
logger.warning(
f"You are using a Sentence Transformer with the {document_store.similarity} function. "
f"We recommend using cosine instead. "
f"This can be set when initializing the DocumentStore"
)
self._check_docstore_similarity_function(retriever)
def embed(self, texts: Union[List[str], str]) -> np.ndarray:
# texts can be a list of strings