Scale dot product into probabilities (#667)

* scale dot product

* Add tip in documentation

* Add recommendation boxes

* WIP: Use similarity attribute in all doc stores

* Implement similarity for InMemoryDS

* Add FAISS support

* Clean printout

* Update documentation

* Implement document field map
Branden Chan 2020-12-11 12:10:24 +01:00 committed by GitHub
parent a0e146dde6
commit d8154939fc
8 changed files with 93 additions and 13 deletions


@@ -122,6 +122,16 @@ Indexing using DPR is comparatively expensive in terms of required computation
 The embeddings that are created in this step can be stored in FAISS, a database optimized for vector similarity.
 DPR can also work with the ElasticsearchDocumentStore or the InMemoryDocumentStore.
 
+<div class="recommendation">
+
+**Tip**
+
+When using DPR, it is recommended that you use the dot product similarity function, since that is how it was trained.
+To do so, simply provide `similarity="dot_product"` when initializing the DocumentStore,
+as is done in the code example below.
+
+</div>
+
 There are two design decisions that have made DPR particularly performant.
@@ -136,7 +146,7 @@ If you'd like to learn how to set up a DPR-based system, have a look at our tutorial
 
 ### Initialisation
 
 ```python
-document_store = FAISSDocumentStore()
+document_store = FAISSDocumentStore(similarity="dot_product")
 ...
 retriever = DensePassageRetriever(
     document_store=document_store,
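
For reference, a fuller version of this initialisation might look like the sketch below. The query/passage model names are the standard ones from the Haystack DPR examples and are an assumption here, not part of this diff:

```python
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever

# DPR is trained with dot product, so the store should score with it too
document_store = FAISSDocumentStore(similarity="dot_product")

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",  # assumed default
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",     # assumed default
)
```
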
@@ -161,10 +171,20 @@ They are particularly suited to cases where your query input is similar in style to the documents in your database,
 i.e., when you are searching for the most similar documents.
 This is not inherently suited to query-based search, where the length, language, and format of the query usually differ significantly from the searched-for text.
 
+<div class="recommendation">
+
+**Tip**
+
+When using Sentence Transformer models, we recommend that you use a cosine similarity function.
+To do so, simply provide `similarity="cosine"` when initializing the DocumentStore,
+as is done in the code example below.
+
+</div>
+
 ### Initialisation
 
 ```python
-document_store = ElasticsearchDocumentStore()
+document_store = ElasticsearchDocumentStore(similarity="cosine")
 ...
 retriever = EmbeddingRetriever(document_store=document_store,
                                embedding_model="deepset/sentence_bert")


@@ -12,6 +12,7 @@ class BaseDocumentStore(ABC):
     """
     index: Optional[str]
     label_index: Optional[str]
+    similarity: Optional[str]
 
     @abstractmethod
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):


@@ -117,6 +117,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.update_existing_documents = update_existing_documents
         self.refresh_type = refresh_type
+        self.similarity = similarity
         if similarity == "cosine":
             self.similarity_fn_name = "cosineSimilarity"
         elif similarity == "dot_product":
@@ -596,7 +597,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if score:
             if adapt_score_for_embedding:
                 score -= 1000
-                probability = (score + 1) / 2  # scaling probability from cosine similarity
+                if self.similarity == "cosine":
+                    probability = (score + 1) / 2  # scaling probability from cosine similarity
+                elif self.similarity == "dot_product":
+                    probability = float(expit(np.asarray(score / 100)))  # scaling probability from dot product
             else:
                 probability = float(expit(np.asarray(score / 8)))  # scaling probability from TFIDF/BM25
         else:
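
The three scaling branches above are easier to compare side by side. Here is a minimal, self-contained sketch of the same mappings (the function name and example scores are illustrative, not part of the diff; the Elasticsearch-specific `score -= 1000` offset is omitted):

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid: 1 / (1 + e^-x)

def scale_to_probability(score: float, similarity: str) -> float:
    """Map a raw retrieval score into [0, 1], mirroring the branches above."""
    if similarity == "cosine":
        # cosine similarity lies in [-1, 1], so a linear shift suffices
        return (score + 1) / 2
    elif similarity == "dot_product":
        # dot products are unbounded, so squash them with a sigmoid
        return float(expit(np.asarray(score / 100)))
    else:
        # TF-IDF/BM25 scores are non-negative and unbounded
        return float(expit(np.asarray(score / 8)))

print(scale_to_probability(0.8, "cosine"))         # 0.9
print(scale_to_probability(120.0, "dot_product"))  # ~0.77
print(scale_to_probability(10.0, "bm25"))          # ~0.78
```
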


@@ -9,6 +9,8 @@ from haystack import Document
 from haystack.document_store.sql import SQLDocumentStore
 from haystack.retriever.base import BaseRetriever
+from scipy.special import expit
+
 if platform != 'win32' and platform != 'cygwin':
     import faiss
 else:
@@ -39,6 +41,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         return_embedding: bool = False,
         update_existing_documents: bool = False,
         index: str = "document",
+        similarity: str = "dot_product",
         **kwargs,
     ):
         """
@@ -82,6 +85,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
 
+        if similarity == "dot_product":
+            self.similarity = similarity
+        else:
+            raise ValueError("The FAISS document store can currently only support dot_product similarity. "
+                             "Please set similarity=\"dot_product\"")
         super().__init__(
             url=sql_url,
             update_existing_documents=update_existing_documents,
@@ -116,7 +124,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         # doc + metadata index
         index = index or self.index
-        document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         add_vectors = False if document_objects[0].embedding is None else True
@@ -142,6 +151,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)
 
+    def _create_document_field_map(self) -> Dict:
+        return {
+            self.index: "embedding",
+        }
+
     def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
         """
         Updates the embeddings in the document store using the encoding model specified in the retriever.
@@ -273,7 +287,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
         for doc in documents:
             doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
-            doc.probability = (doc.score + 1) / 2
+            doc.probability = float(expit(np.asarray(doc.score / 100)))
 
             if return_embedding is True:
                 doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
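
The field map introduced in this file lets `Document.from_dict` translate a store-specific field name into the canonical `embedding` attribute when documents are written. A rough sketch of the idea (simplified; the real `Document.from_dict` does more than this, and the field names are illustrative):

```python
# Illustration only: a field map renames store fields to Document attributes
field_map = {"question_emb": "embedding"}  # store field -> Document attribute

raw = {"text": "How do I reset my password?", "question_emb": [0.1, 0.2, 0.3]}
normalized = {field_map.get(key, key): value for key, value in raw.items()}
# normalized == {"text": "How do I reset my password?", "embedding": [0.1, 0.2, 0.3]}
```
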


@@ -8,6 +8,8 @@ from haystack import Document, Label
 from haystack.preprocessor.utils import eval_data_from_file
 from haystack.retriever.base import BaseRetriever
+from scipy.spatial.distance import cosine
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -17,13 +19,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
     In-memory document store
     """
 
-    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False):
+    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity: str = "dot_product"):
         self.indexes: Dict[str, Dict] = defaultdict(dict)
         self.index: str = "document"
         self.label_index: str = "label"
         self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
         self.embedding_dim: int = 768
         self.return_embedding: bool = return_embedding
+        self.similarity: str = similarity
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
@@ -41,11 +44,17 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """
         index = index or self.index
-        documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        field_map = self._create_document_field_map()
+        documents_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
 
         for document in documents_objects:
             self.indexes[index][document.id] = document
 
+    def _create_document_field_map(self):
+        return {
+            self.embedding_field: "embedding",
+        }
+
     def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
         """Write annotation labels into document store."""
         index = index or self.label_index
@@ -106,18 +115,24 @@ class InMemoryDocumentStore(BaseDocumentStore):
         candidate_docs = []
         for idx, doc in self.indexes[index].items():
+            curr_meta = deepcopy(doc.meta)
             new_document = Document(
                 id=doc.id,
                 text=doc.text,
-                meta=deepcopy(doc.meta)
+                meta=curr_meta,
+                embedding=doc.embedding
             )
             new_document.embedding = doc.embedding if return_embedding is True else None
 
-            score = dot(query_emb, doc.embedding) / (
-                norm(query_emb) * norm(doc.embedding)
-            )
+            if self.similarity == "dot_product":
+                score = dot(query_emb, doc.embedding) / (
+                    norm(query_emb) * norm(doc.embedding)
+                )
+            elif self.similarity == "cosine":
+                # cosine similarity score = 1 - cosine distance
+                score = 1 - cosine(query_emb, doc.embedding)
             new_document.score = score
             new_document.probability = (score + 1) / 2
             candidate_docs.append(new_document)
 
         return sorted(candidate_docs, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)[0:top_k]
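
Note that the `dot_product` branch above divides by both vector norms, so like the `cosine` branch it yields values in [-1, 1]; that bound is what makes the `(score + 1) / 2` probability mapping safe for both settings. A small numeric sketch (the vectors are illustrative):

```python
import numpy as np
from numpy import dot
from numpy.linalg import norm
from scipy.spatial.distance import cosine

query_emb = np.array([1.0, 0.0, 1.0])
doc_emb = np.array([0.5, 0.5, 1.0])

# "dot_product" branch: dot product scaled by both norms
score_dot = dot(query_emb, doc_emb) / (norm(query_emb) * norm(doc_emb))

# "cosine" branch: 1 - cosine distance
score_cos = 1 - cosine(query_emb, doc_emb)

print(score_dot, score_cos)  # both ~0.866; the two branches agree numerically
print((score_dot + 1) / 2)   # probability ~0.93
```
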


@@ -87,6 +87,8 @@ class SQLDocumentStore(BaseDocumentStore):
         self.index = index
         self.label_index = label_index
         self.update_existing_documents = update_existing_documents
+        if getattr(self, "similarity", None) is None:
+            self.similarity = None
 
     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         """Fetch a document by specifying its text id string"""


@@ -6,6 +6,7 @@ from pathlib import Path
 from tqdm import tqdm
 
 from haystack.document_store.base import BaseDocumentStore
+from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack import Document
 from haystack.retriever.base import BaseRetriever
@@ -86,6 +87,11 @@ class DensePassageRetriever(BaseRetriever):
         self.max_seq_len_passage = max_seq_len_passage
         self.max_seq_len_query = max_seq_len_query
 
+        if document_store.similarity != "dot_product":
+            logger.warning(f"You are using a Dense Passage Retriever model with the {document_store.similarity} similarity function. "
+                           "We recommend you use dot_product instead. "
+                           "This can be set when initializing the DocumentStore.")
+
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
         else:
@@ -399,6 +405,18 @@ class EmbeddingRetriever(BaseRetriever):
                 embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
                 extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
             )
+            # Check that document_store has the right similarity function
+            similarity = document_store.similarity
+            # If we are using a sentence transformer model
+            if "sentence" in embedding_model.lower() and similarity != "cosine":
+                logger.warning(f"You seem to be using a Sentence Transformer with the {similarity} similarity function. "
+                               f"We recommend using cosine instead. "
+                               f"This can be set when initializing the DocumentStore.")
+            elif "dpr" in embedding_model.lower() and similarity != "dot_product":
+                logger.warning(f"You seem to be using a DPR model with the {similarity} similarity function. "
+                               f"We recommend using dot_product instead. "
+                               f"This can be set when initializing the DocumentStore.")
+
         elif model_format == "sentence_transformers":
             try:
@@ -414,6 +432,11 @@ class EmbeddingRetriever(BaseRetriever):
             else:
                 device = "cpu"
             self.embedding_model = SentenceTransformer(embedding_model, device=device)
+            if document_store.similarity != "cosine":
+                logger.warning(
+                    f"You are using a Sentence Transformer with the {document_store.similarity} similarity function. "
+                    f"We recommend using cosine instead. "
+                    f"This can be set when initializing the DocumentStore.")
         else:
             raise NotImplementedError
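
A sketch of when these warnings fire, assuming the ElasticsearchDocumentStore still defaults to `similarity="dot_product"` (the default is not shown in this diff):

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.dense import EmbeddingRetriever

# Mismatch: a Sentence Transformer model against a dot_product store
document_store = ElasticsearchDocumentStore()  # assumed default: similarity="dot_product"
retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model="deepset/sentence_bert")
# -> warning: "You seem to be using a Sentence Transformer with the dot_product similarity function. ..."

# Fixed: declare the matching similarity when creating the store
document_store = ElasticsearchDocumentStore(similarity="cosine")
retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model="deepset/sentence_bert")
```
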


@@ -44,7 +44,8 @@ document_store = ElasticsearchDocumentStore(host="localhost", username="", password="",
                                             index="document",
                                             embedding_field="question_emb",
                                             embedding_dim=768,
-                                            excluded_meta_data=["question_emb"])
+                                            excluded_meta_data=["question_emb"],
+                                            similarity="cosine")
 
 ### Create a Retriever using embeddings
 # Instead of retrieving via Elasticsearch's plain BM25, we want to use vector similarity of the questions (user question vs. FAQ ones).
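
The tutorial's next step (not shown in this hunk) would initialise the retriever against this store; a sketch consistent with the EmbeddingRetriever example earlier in this diff:

```python
from haystack.retriever.dense import EmbeddingRetriever

retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model="deepset/sentence_bert")
```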