mirror of https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00

Scale dot product into probabilities (#667)

* scale dot product
* Add tip in documentation
* Add recommendation boxes
* WIP: Use similarity attribute in all doc stores
* Implement similarity for InMemoryDS
* Add FAISS support
* Clean printout
* Update documentation
* Implement document field map

This commit is contained in:
parent a0e146dde6
commit d8154939fc

@@ -122,6 +122,16 @@ Indexing using DPR is comparatively expensive in terms of required computation s
 The embeddings that are created in this step can be stored in FAISS, a database optimized for vector similarity.
 DPR can also work with the ElasticsearchDocumentStore or the InMemoryDocumentStore.
 
+<div class="recommendation">
+
+**Tip**
+
+When using DPR, it is recommended that you use the dot product similarity function, since that is how it is trained.
+To do so, simply provide `similarity="dot_product"` when initializing the DocumentStore,
+as is done in the code example below.
+
+</div>
+
 There are two design decisions that have made DPR particularly performant.

@@ -136,7 +146,7 @@ If you’d like to learn how to set up a DPR based system, have a look at our tu
 ### Initialisation
 
 ```python
-document_store = FAISSDocumentStore()
+document_store = FAISSDocumentStore(similarity="dot_product")
 ...
 retriever = DensePassageRetriever(
     document_store=document_store,

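Why the tip matters: DPR encoders are trained with an inner-product objective, so vector magnitude carries ranking information that cosine similarity normalizes away. A minimal numpy sketch with illustrative vectors (not real DPR embeddings) showing that the two functions can disagree on ranking:

```python
import numpy as np

# Two candidate passage embeddings: doc_b points in a slightly less aligned
# direction than doc_a, but has a much larger magnitude (illustrative values).
query = np.array([1.0, 0.0])
doc_a = np.array([0.9, 0.1])
doc_b = np.array([8.0, 2.0])

def dot_product(a, b):
    return float(np.dot(a, b))

def cosine_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(dot_product(query, doc_a), cosine_sim(query, doc_a))  # 0.9, ~0.994
print(dot_product(query, doc_b), cosine_sim(query, doc_b))  # 8.0, ~0.970
# Dot product ranks doc_b first; cosine ranks doc_a first.
```
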
@@ -161,10 +171,20 @@ They are particular suited to cases where your query input is similar in style t
 i.e. when you are searching for the most similar documents.
 This is not inherently suited to query-based search, where the length, language and format of the query usually differ significantly from the searched-for text.
 
+<div class="recommendation">
+
+**Tip**
+
+When using Sentence Transformer models, we recommend that you use the cosine similarity function.
+To do so, simply provide `similarity="cosine"` when initializing the DocumentStore,
+as is done in the code example below.
+
+</div>
+
 ### Initialisation
 
 ```python
-document_store = ElasticsearchDocumentStore()
+document_store = ElasticsearchDocumentStore(similarity="cosine")
 ...
 retriever = EmbeddingRetriever(document_store=document_store,
                                embedding_model="deepset/sentence_bert")

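Cosine similarity, by contrast, normalizes magnitude away, which suits Sentence Transformer embeddings. A quick sanity check (illustrative vectors) that `1 - cosine distance` from scipy equals the dot product of the L2-normalized vectors:

```python
import numpy as np
from scipy.spatial.distance import cosine

a = np.array([0.3, 0.7, 0.1])
b = np.array([0.6, 0.2, 0.5])

# 1 - cosine distance == dot product of the L2-normalized vectors
score = 1 - cosine(a, b)
normalized = float(np.dot(a / np.linalg.norm(a), b / np.linalg.norm(b)))
assert np.isclose(score, normalized)
print(score)  # same value either way, ~0.598 here
```
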
@@ -12,6 +12,7 @@ class BaseDocumentStore(ABC):
     """
     index: Optional[str]
     label_index: Optional[str]
+    similarity: Optional[str]
 
     @abstractmethod
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):

@@ -117,6 +117,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
 
         self.update_existing_documents = update_existing_documents
         self.refresh_type = refresh_type
+        self.similarity = similarity
         if similarity == "cosine":
             self.similarity_fn_name = "cosineSimilarity"
         elif similarity == "dot_product":

@@ -596,7 +597,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if score:
             if adapt_score_for_embedding:
                 score -= 1000
-                probability = (score + 1) / 2  # scaling probability from cosine similarity
+                if self.similarity == "cosine":
+                    probability = (score + 1) / 2  # scaling probability from cosine similarity
+                elif self.similarity == "dot_product":
+                    probability = float(expit(np.asarray(score / 100)))  # scaling probability from dot product
             else:
                 probability = float(expit(np.asarray(score / 8)))  # scaling probability from TFIDF/BM25
         else:

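The three branches map raw scores with very different ranges into [0, 1]: cosine similarity is bounded to [-1, 1], so a linear shift suffices, while dot-product and BM25 scores are unbounded and get a scaled sigmoid (`expit`) instead. A small sketch with representative, illustrative score values:

```python
import numpy as np
from scipy.special import expit

cosine_score = 0.83  # cosine similarity is bounded to [-1, 1]
dot_score = 75.0     # DPR-style inner products are typically on the order of 100
bm25_score = 12.4    # BM25 scores are unbounded and usually much smaller

print((cosine_score + 1) / 2)                     # 0.915 -- linear shift to [0, 1]
print(float(expit(np.asarray(dot_score / 100))))  # ~0.68 -- sigmoid squashes to (0, 1)
print(float(expit(np.asarray(bm25_score / 8))))   # ~0.82
```
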
@@ -9,6 +9,8 @@ from haystack import Document
 from haystack.document_store.sql import SQLDocumentStore
 from haystack.retriever.base import BaseRetriever
 
+from scipy.special import expit
+
 if platform != 'win32' and platform != 'cygwin':
     import faiss
 else:

@@ -39,6 +41,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         return_embedding: bool = False,
         update_existing_documents: bool = False,
         index: str = "document",
+        similarity: str = "dot_product",
         **kwargs,
     ):
         """

@@ -82,6 +85,11 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
+        if similarity == "dot_product":
+            self.similarity = similarity
+        else:
+            raise ValueError("The FAISS document store can currently only support dot_product similarity. "
+                             "Please set similarity=\"dot_product\"")
         super().__init__(
             url=sql_url,
             update_existing_documents=update_existing_documents,

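Since FAISS support at this point only covers inner-product indexes, any other similarity setting fails fast. A usage sketch, assuming the constructor shown above:

```python
from haystack.document_store.faiss import FAISSDocumentStore

document_store = FAISSDocumentStore(similarity="dot_product")  # accepted

try:
    FAISSDocumentStore(similarity="cosine")
except ValueError as e:
    print(e)  # "The FAISS document store can currently only support dot_product similarity. ..."
```
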
@@ -116,7 +124,8 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         # doc + metadata index
         index = index or self.index
-        document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         add_vectors = False if document_objects[0].embedding is None else True
 

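The field map lets `Document.from_dict` pull the embedding from a store-specific key instead of a hard-coded attribute name. A hedged sketch of the renaming idea only (the helper and dict keys below are illustrative, not haystack's implementation):

```python
# Rename store-specific dict keys to the Document attributes they feed.
def apply_field_map(doc_dict: dict, field_map: dict) -> dict:
    return {field_map.get(k, k): v for k, v in doc_dict.items()}

raw = {"text": "some passage", "question_emb": [0.1, 0.2, 0.3]}
print(apply_field_map(raw, {"question_emb": "embedding"}))
# {'text': 'some passage', 'embedding': [0.1, 0.2, 0.3]}
```
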
@@ -142,6 +151,11 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)
 
+    def _create_document_field_map(self) -> Dict:
+        return {
+            self.index: "embedding",
+        }
+
     def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
         """
         Updates the embeddings in the document store using the encoding model specified in the retriever.

@@ -273,7 +287,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
         for doc in documents:
             doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
-            doc.probability = (doc.score + 1) / 2
+            doc.probability = float(expit(np.asarray(doc.score / 100)))
             if return_embedding is True:
                 doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
 

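The old linear formula assumed cosine-range scores; applied to raw inner products it can leave [0, 1] entirely, which the sigmoid cannot. An illustrative check:

```python
import numpy as np
from scipy.special import expit

score = 78.3  # a typical DPR inner-product score (illustrative)
print((score + 1) / 2)                        # 39.65 -- not a valid probability
print(float(expit(np.asarray(score / 100))))  # ~0.69 -- bounded in (0, 1)
```
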
@@ -8,6 +8,8 @@ from haystack import Document, Label
 from haystack.preprocessor.utils import eval_data_from_file
 from haystack.retriever.base import BaseRetriever
 
+from scipy.spatial.distance import cosine
+
 import logging
 logger = logging.getLogger(__name__)
 

@@ -17,13 +19,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
     In-memory document store
     """
 
-    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False):
+    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product"):
         self.indexes: Dict[str, Dict] = defaultdict(dict)
         self.index: str = "document"
         self.label_index: str = "label"
         self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
         self.embedding_dim: int = 768
         self.return_embedding: bool = return_embedding
+        self.similarity: str = similarity
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """

@@ -41,11 +44,17 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """
         index = index or self.index
 
-        documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        documents_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         for document in documents_objects:
             self.indexes[index][document.id] = document
 
+    def _create_document_field_map(self):
+        return {
+            self.embedding_field: "embedding",
+        }
+
     def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
         """Write annotation labels into document store."""
         index = index or self.label_index

@@ -106,18 +115,24 @@ class InMemoryDocumentStore(BaseDocumentStore):
 
         candidate_docs = []
         for idx, doc in self.indexes[index].items():
+            curr_meta = deepcopy(doc.meta)
             new_document = Document(
                 id=doc.id,
                 text=doc.text,
-                meta=deepcopy(doc.meta)
+                meta=curr_meta,
+                embedding=doc.embedding
             )
             new_document.embedding = doc.embedding if return_embedding is True else None
-            score = dot(query_emb, doc.embedding) / (
-                norm(query_emb) * norm(doc.embedding)
-            )
+
+            if self.similarity == "dot_product":
+                score = dot(query_emb, doc.embedding) / (
+                    norm(query_emb) * norm(doc.embedding)
+                )
+            elif self.similarity == "cosine":
+                # cosine similarity score = 1 - cosine distance
+                score = 1 - cosine(query_emb, doc.embedding)
             new_document.score = score
             new_document.probability = (score + 1) / 2
 
             candidate_docs.append(new_document)
 
         return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k]

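A minimal end-to-end sketch of the store above, assuming the InMemoryDocumentStore API at this commit (random vectors stand in for real embeddings):

```python
import numpy as np
from haystack.document_store.memory import InMemoryDocumentStore

document_store = InMemoryDocumentStore(similarity="cosine")
document_store.write_documents([
    {"text": "What is Haystack?", "embedding": np.random.rand(768).astype("float32")},
    {"text": "How do I use DPR?", "embedding": np.random.rand(768).astype("float32")},
])

query_emb = np.random.rand(768).astype("float32")
top_doc = document_store.query_by_embedding(query_emb, top_k=1)[0]
print(top_doc.score, top_doc.probability)  # probability = (score + 1) / 2
```
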
@@ -87,6 +87,8 @@ class SQLDocumentStore(BaseDocumentStore):
         self.index = index
         self.label_index = label_index
         self.update_existing_documents = update_existing_documents
+        if getattr(self, "similarity", None) is None:
+            self.similarity = None
 
     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         """Fetch a document by specifying its text id string"""

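The `getattr` guard keeps this base `__init__` from clobbering a value a subclass set first; FAISSDocumentStore above assigns `self.similarity` before delegating to `super().__init__()`. The pattern in isolation:

```python
class Base:
    def __init__(self):
        # Only default the attribute if no subclass has set it already.
        if getattr(self, "similarity", None) is None:
            self.similarity = None

class Child(Base):
    def __init__(self):
        self.similarity = "dot_product"  # set before delegating upward
        super().__init__()

print(Child().similarity)  # "dot_product", not overwritten with None
```
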
@@ -6,6 +6,7 @@ from pathlib import Path
 from tqdm import tqdm
 
 from haystack.document_store.base import BaseDocumentStore
+from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack import Document
 from haystack.retriever.base import BaseRetriever
 

@@ -86,6 +87,11 @@ class DensePassageRetriever(BaseRetriever):
         self.max_seq_len_passage = max_seq_len_passage
         self.max_seq_len_query = max_seq_len_query
 
+        if document_store.similarity != "dot_product":
+            logger.warning(f"You are using a Dense Passage Retriever model with the {document_store.similarity} function. "
+                           "We recommend you use dot_product instead. "
+                           "This can be set when initializing the DocumentStore")
+
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
         else:

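A hedged sketch of when this warning fires (uses the retriever's default DPR models, so it would download weights on first run):

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = InMemoryDocumentStore(similarity="cosine")
# Logs: "You are using a Dense Passage Retriever model with the cosine function. ..."
retriever = DensePassageRetriever(document_store=document_store)
```
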
@@ -399,6 +405,18 @@ class EmbeddingRetriever(BaseRetriever):
                 embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
                 extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
             )
+            # Check that document_store has the right similarity function
+            similarity = document_store.similarity
+            # If we are using a sentence transformer model
+            if "sentence" in embedding_model.lower() and similarity != "cosine":
+                logger.warning(f"You seem to be using a Sentence Transformer with the {similarity} function. "
+                               f"We recommend using cosine instead. "
+                               f"This can be set when initializing the DocumentStore")
+            elif "dpr" in embedding_model.lower() and similarity != "dot_product":
+                logger.warning(f"You seem to be using a DPR model with the {similarity} function. "
+                               f"We recommend using dot_product instead. "
+                               f"This can be set when initializing the DocumentStore")
+
 
         elif model_format == "sentence_transformers":
             try:

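The check is a plain substring match on the model name; a tiny sketch (model names are illustrative):

```python
for name in ["deepset/sentence_bert", "facebook/dpr-question_encoder-single-nq-base"]:
    if "sentence" in name.lower():
        print(name, "-> cosine recommended")
    elif "dpr" in name.lower():
        print(name, "-> dot_product recommended")
```
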
@@ -414,6 +432,11 @@ class EmbeddingRetriever(BaseRetriever):
             else:
                 device = "cpu"
             self.embedding_model = SentenceTransformer(embedding_model, device=device)
+            if document_store.similarity != "cosine":
+                logger.warning(
+                    f"You are using a Sentence Transformer with the {document_store.similarity} function. "
+                    f"We recommend using cosine instead. "
+                    f"This can be set when initializing the DocumentStore")
         else:
             raise NotImplementedError
 

@@ -44,7 +44,8 @@ document_store = ElasticsearchDocumentStore(host="localhost", username="", passw
                                             index="document",
                                             embedding_field="question_emb",
                                             embedding_dim=768,
-                                            excluded_meta_data=["question_emb"])
+                                            excluded_meta_data=["question_emb"],
+                                            similarity="cosine")
 
 ### Create a Retriever using embeddings
 # Instead of retrieving via Elasticsearch's plain BM25, we want to use vector similarity of the questions (user question vs. FAQ ones).