refactor: InMemoryDocumentStore - manage documents without embedding & fix mypy errors (#4113)

* refactoring and test

* try to replace error with warning

* more expressive and robust get_scores methods

* make get_scores methods internal
Stefano Fiorucci 2023-02-14 17:43:11 +01:00 committed by GitHub
parent d86a511cc1
commit 24405f851c
2 changed files with 69 additions and 36 deletions
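
In practice, the headline change is that query_by_embedding no longer assumes every stored document has an embedding. A minimal sketch of the new behaviour (assuming the store's default 768-dim embeddings; it mirrors the integration test added at the bottom of this diff):

import numpy as np
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

store = InMemoryDocumentStore(embedding_dim=768)
store.write_documents([Document(content="test Document")])  # written without an embedding

# Documents lacking embeddings are now skipped with a warning instead of breaking the query.
docs = store.query_by_embedding(query_emb=np.random.rand(768).astype(np.float32), top_k=1)
assert docs == []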


@@ -281,47 +281,60 @@ class InMemoryDocumentStore(KeywordDocumentStore):
else:
return None
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]: # type: ignore
def get_documents_by_id(
self,
ids: List[str],
index: Optional[str] = None,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
Fetch documents by specifying a list of text id strings.
"""
if headers:
raise NotImplementedError("InMemoryDocumentStore does not support headers.")
if batch_size:
logger.warning(
"InMemoryDocumentStore does not support batching in `get_documents_by_id` method. This parameter is ignored."
)
index = index or self.index
documents = [self.indexes[index][id] for id in ids]
return documents
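
The widened signature brings the in-memory store in line with the base-class interface (the old `# type: ignore` is no longer needed): headers are rejected outright, while batch_size is accepted but only logs a warning. A short usage sketch; the document id below is made up for illustration:

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

store = InMemoryDocumentStore()
store.write_documents([Document(content="hello", id="doc-1")])

docs = store.get_documents_by_id(ids=["doc-1"])                   # unchanged behaviour
docs = store.get_documents_by_id(ids=["doc-1"], batch_size=32)    # batch_size ignored, warning logged
# store.get_documents_by_id(ids=["doc-1"], headers={"X": "y"})    # would raise NotImplementedError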
def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
def _get_scores_torch(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
"""
Calculate similarity scores between query embedding and a list of documents using torch.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param document_to_search: List of documents to compare `query_emb` against.
:param documents_to_search: List of documents to compare `query_emb` against.
"""
query_emb = torch.tensor(query_emb, dtype=torch.float).to(self.main_device) # type: ignore [assignment]
if len(query_emb.shape) == 1:
query_emb = query_emb.unsqueeze(dim=0) # type: ignore [attr-defined]
query_emb_tensor = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)
if query_emb_tensor.ndim == 1:
query_emb_tensor = query_emb_tensor.unsqueeze(dim=0)
doc_embeds = np.array([doc.embedding for doc in document_to_search])
doc_embeds = torch.as_tensor(doc_embeds, dtype=torch.float) # type: ignore [assignment]
if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
doc_embeds = doc_embeds.unsqueeze(dim=0) # type: ignore [attr-defined]
elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
doc_embeds = np.array([doc.embedding for doc in documents_to_search])
doc_embeds_tensor = torch.as_tensor(doc_embeds, dtype=torch.float)
if doc_embeds_tensor.ndim == 1:
# if there are no embeddings, return an empty list
if doc_embeds_tensor.shape[0] == 0:
return []
doc_embeds_tensor = doc_embeds_tensor.unsqueeze(dim=0)
if self.similarity == "cosine":
# cosine similarity is just a normed dot product
query_emb_norm = torch.norm(query_emb, dim=1)
query_emb = torch.div(query_emb, query_emb_norm) # type: ignore [assignment,arg-type]
query_emb_norm = torch.norm(query_emb_tensor, dim=1)
query_emb_tensor = torch.div(query_emb_tensor, query_emb_norm)
doc_embeds_norms = torch.norm(doc_embeds, dim=1)
doc_embeds = torch.div(doc_embeds.T, doc_embeds_norms).T # type: ignore [assignment,arg-type]
doc_embeds_norms = torch.norm(doc_embeds_tensor, dim=1)
doc_embeds_tensor = torch.div(doc_embeds_tensor.T, doc_embeds_norms).T
curr_pos = 0
scores = [] # type: ignore [var-annotated]
while curr_pos < len(doc_embeds):
doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
doc_embeds_slice = doc_embeds_slice.to(self.main_device) # type: ignore [attr-defined]
scores: List[float] = []
while curr_pos < len(doc_embeds_tensor):
doc_embeds_slice = doc_embeds_tensor[curr_pos : curr_pos + self.scoring_batch_size]
doc_embeds_slice = doc_embeds_slice.to(self.main_device)
with torch.inference_mode():
slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu() # type: ignore [arg-type,arg-type]
slice_scores = torch.matmul(doc_embeds_slice, query_emb_tensor.T).cpu()
slice_scores = slice_scores.squeeze(dim=1)
slice_scores = slice_scores.numpy().tolist()
@@ -330,21 +343,22 @@ class InMemoryDocumentStore(KeywordDocumentStore):
return scores
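
The comment "cosine similarity is just a normed dot product" is the core of the torch path: L2-normalise the query and document embeddings, then a single matmul yields cosine scores. A standalone illustration (not part of the diff), checked against torch's built-in cosine_similarity:

import torch

query_emb = torch.rand(1, 768)    # one query embedding
doc_embeds = torch.rand(4, 768)   # four document embeddings

query_norm = query_emb / torch.norm(query_emb, dim=1, keepdim=True)
docs_norm = doc_embeds / torch.norm(doc_embeds, dim=1, keepdim=True)
scores = torch.matmul(docs_norm, query_norm.T).squeeze(dim=1)  # shape (4,)

reference = torch.nn.functional.cosine_similarity(doc_embeds, query_emb)
assert torch.allclose(scores, reference, atol=1e-6)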
def get_scores_numpy(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
def _get_scores_numpy(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
"""
Calculate similarity scores between query embedding and a list of documents using numpy.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param document_to_search: List of documents to compare `query_emb` against.
:param documents_to_search: List of documents to compare `query_emb` against.
"""
if len(query_emb.shape) == 1:
query_emb = np.expand_dims(query_emb, 0)
if query_emb.ndim == 1:
query_emb = np.expand_dims(a=query_emb, axis=0)
doc_embeds = np.array([doc.embedding for doc in document_to_search])
if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
doc_embeds = doc_embeds.unsqueeze(dim=0) # type: ignore [attr-defined]
elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
doc_embeds = np.array([doc.embedding for doc in documents_to_search])
if doc_embeds.ndim == 1:
# if there are no embeddings, return an empty list
if doc_embeds.shape[0] == 0:
return []
doc_embeds = np.expand_dims(a=doc_embeds, axis=0)
if self.similarity == "cosine":
# cosine similarity is just a normed dot product
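
The numpy path follows the same recipe without torch, and now handles the awkward shapes explicitly: a 1-D query embedding is expanded to (1, dim), and an empty embedding array short-circuits to an empty score list. A small sketch of the happy path, assuming dot-product similarity:

import numpy as np

query_emb = np.random.rand(768).astype(np.float32)   # 1-D, as retrievers typically return it
doc_embeds = np.array([np.random.rand(768).astype(np.float32) for _ in range(3)])

if query_emb.ndim == 1:
    query_emb = np.expand_dims(query_emb, axis=0)     # -> shape (1, 768)

# Dot-product scores; for cosine similarity both arrays would be L2-normalised first.
scores = np.dot(doc_embeds, query_emb.T).squeeze(axis=1).tolist()
print(len(scores))  # 3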
@@ -360,11 +374,11 @@ class InMemoryDocumentStore(KeywordDocumentStore):
return scores
def get_scores(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
def _get_scores(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
if self.main_device.type == "cuda":
scores = self.get_scores_torch(query_emb, document_to_search)
scores = self._get_scores_torch(query_emb, documents_to_search)
else:
scores = self.get_scores_numpy(query_emb, document_to_search)
scores = self._get_scores_numpy(query_emb, documents_to_search)
return scores
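
_get_scores simply routes to the torch implementation on GPU and to the numpy one on CPU. A stripped-down sketch of that dispatch (main_device is an attribute of the real store; it is recreated locally here just for illustration):

import numpy as np
import torch

main_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_scores(query_emb: np.ndarray, doc_embeds: np.ndarray) -> list:
    # query_emb is 1-D (dim,), doc_embeds is 2-D (n_docs, dim)
    if main_device.type == "cuda":
        q = torch.tensor(query_emb, dtype=torch.float, device=main_device)
        d = torch.tensor(doc_embeds, dtype=torch.float, device=main_device)
        return torch.matmul(d, q.unsqueeze(1)).squeeze(1).cpu().tolist()  # GPU path
    return np.dot(doc_embeds, query_emb).tolist()                          # CPU path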
@@ -460,11 +474,17 @@ class InMemoryDocumentStore(KeywordDocumentStore):
if query_emb is None:
return []
document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
scores = self.get_scores(query_emb, document_to_search)
documents = self.get_all_documents(index=index, filters=filters, return_embedding=True)
documents_with_embeddings = [doc for doc in documents if doc.embedding is not None]
if len(documents) != len(documents_with_embeddings):
logger.warning(
"Skipping some of your documents that don't have embeddings. "
"To generate embeddings, run the document store's update_embeddings() method."
)
scores = self._get_scores(query_emb, documents_with_embeddings)
candidate_docs = []
for doc, score in zip(document_to_search, scores):
for doc, score in zip(documents_with_embeddings, scores):
curr_meta = deepcopy(doc.meta)
new_document = Document(
id=doc.id, content=doc.content, content_type=doc.content_type, meta=curr_meta, embedding=doc.embedding
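
The new warning points users at update_embeddings() as the fix. A hedged sketch of that path; the retriever class, model name, and dimension below are placeholders rather than part of this PR, and any retriever producing embeddings of the store's configured dimension would do:

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document

store = InMemoryDocumentStore(embedding_dim=384)  # assumed dim matching the example model
store.write_documents([Document(content="test Document")])

retriever = EmbeddingRetriever(
    document_store=store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
)
store.update_embeddings(retriever)  # fills in doc.embedding for every stored document

docs = retriever.retrieve(query="test", top_k=1)  # no warning: all docs now have embeddings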


@@ -3,6 +3,7 @@ import logging
import pandas as pd
import pytest
from rank_bm25 import BM25
import numpy as np
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document
@@ -112,3 +113,15 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
for docs, query_emb in zip(docs_batch, query_embs):
assert len(docs) == 5
assert (docs[0].embedding == query_emb).all()
@pytest.mark.integration
def test_memory_query_by_embedding_docs_wo_embeddings(self, ds, caplog):
# write document but don't update embeddings
ds.write_documents([Document(content="test Document")])
query_embedding = np.random.rand(768).astype(np.float32)
with caplog.at_level(logging.WARNING):
docs = ds.query_by_embedding(query_emb=query_embedding, top_k=1)
assert "Skipping some of your documents that don't have embeddings" in caplog.text
assert len(docs) == 0