Mirror of https://github.com/deepset-ai/haystack.git
refactor: InMemoryDocumentStore - manage documents without embedding & fix mypy errors (#4113)
* refactoring and test
* try to replace error with warning
* more expressive and robust get_scores methods
* make get_scores methods internal
This commit is contained in:
parent d86a511cc1
commit 24405f851c
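The last bullet makes the scoring helpers internal. Most callers go through query_by_embedding and are unaffected; code that called the helpers directly only needs the underscore-prefixed names. A minimal sketch, assuming a tiny store that already holds one embedded document (embedding_dim=4 and the embedding values are illustrative):

import numpy as np

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

# assumed setup: one document with an embedding already attached
document_store = InMemoryDocumentStore(embedding_dim=4)
document_store.write_documents([Document(content="hello", embedding=np.ones(4, dtype=np.float32))])

documents = document_store.get_all_documents(return_embedding=True)
query_emb = np.ones(4, dtype=np.float32)

# formerly document_store.get_scores(query_emb, documents); now an internal helper
scores = document_store._get_scores(query_emb, documents)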
@@ -281,47 +281,60 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         else:
             return None
 
-    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:  # type: ignore
+    def get_documents_by_id(
+        self,
+        ids: List[str],
+        index: Optional[str] = None,
+        batch_size: Optional[int] = None,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> List[Document]:
         """
         Fetch documents by specifying a list of text id strings.
         """
+        if headers:
+            raise NotImplementedError("InMemoryDocumentStore does not support headers.")
+        if batch_size:
+            logger.warning(
+                "InMemoryDocumentStore does not support batching in `get_documents_by_id` method. This parameter is ignored."
+            )
         index = index or self.index
         documents = [self.indexes[index][id] for id in ids]
         return documents
 
-    def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores_torch(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         """
         Calculate similarity scores between query embedding and a list of documents using torch.
 
         :param query_emb: Embedding of the query (e.g. gathered from DPR)
-        :param document_to_search: List of documents to compare `query_emb` against.
+        :param documents_to_search: List of documents to compare `query_emb` against.
         """
-        query_emb = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)  # type: ignore [assignment]
-        if len(query_emb.shape) == 1:
-            query_emb = query_emb.unsqueeze(dim=0)  # type: ignore [attr-defined]
+        query_emb_tensor = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)
+        if query_emb_tensor.ndim == 1:
+            query_emb_tensor = query_emb_tensor.unsqueeze(dim=0)
 
-        doc_embeds = np.array([doc.embedding for doc in document_to_search])
-        doc_embeds = torch.as_tensor(doc_embeds, dtype=torch.float)  # type: ignore [assignment]
-        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
-            doc_embeds = doc_embeds.unsqueeze(dim=0)  # type: ignore [attr-defined]
-        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
-            return []
+        doc_embeds = np.array([doc.embedding for doc in documents_to_search])
+        doc_embeds_tensor = torch.as_tensor(doc_embeds, dtype=torch.float)
+        if doc_embeds_tensor.ndim == 1:
+            # if there are no embeddings, return an empty list
+            if doc_embeds_tensor.shape[0] == 0:
+                return []
+            doc_embeds_tensor = doc_embeds_tensor.unsqueeze(dim=0)
 
         if self.similarity == "cosine":
             # cosine similarity is just a normed dot product
-            query_emb_norm = torch.norm(query_emb, dim=1)
-            query_emb = torch.div(query_emb, query_emb_norm)  # type: ignore [assignment,arg-type]
+            query_emb_norm = torch.norm(query_emb_tensor, dim=1)
+            query_emb_tensor = torch.div(query_emb_tensor, query_emb_norm)
 
-            doc_embeds_norms = torch.norm(doc_embeds, dim=1)
-            doc_embeds = torch.div(doc_embeds.T, doc_embeds_norms).T  # type: ignore [assignment,arg-type]
+            doc_embeds_norms = torch.norm(doc_embeds_tensor, dim=1)
+            doc_embeds_tensor = torch.div(doc_embeds_tensor.T, doc_embeds_norms).T
 
         curr_pos = 0
-        scores = []  # type: ignore [var-annotated]
-        while curr_pos < len(doc_embeds):
-            doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
-            doc_embeds_slice = doc_embeds_slice.to(self.main_device)  # type: ignore [attr-defined]
+        scores: List[float] = []
+        while curr_pos < len(doc_embeds_tensor):
+            doc_embeds_slice = doc_embeds_tensor[curr_pos : curr_pos + self.scoring_batch_size]
+            doc_embeds_slice = doc_embeds_slice.to(self.main_device)
             with torch.inference_mode():
-                slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()  # type: ignore [arg-type,arg-type]
+                slice_scores = torch.matmul(doc_embeds_slice, query_emb_tensor.T).cpu()
             slice_scores = slice_scores.squeeze(dim=1)
             slice_scores = slice_scores.numpy().tolist()
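In the torch path above, cosine similarity reduces to a dot product of L2-normalized vectors, computed in slices of scoring_batch_size to bound memory. A standalone sketch of that batched scoring, assuming a 2D (n_docs, dim) embedding matrix; the function name and the default batch size are illustrative, not the Haystack API:

from typing import List

import numpy as np
import torch


def batched_cosine_scores(query_emb: np.ndarray, doc_embeds: np.ndarray, batch_size: int = 500_000) -> List[float]:
    query = torch.as_tensor(query_emb, dtype=torch.float)
    if query.ndim == 1:
        query = query.unsqueeze(dim=0)  # (1, dim)

    docs = torch.as_tensor(doc_embeds, dtype=torch.float)  # assumed shape: (n_docs, dim)
    if docs.shape[0] == 0:
        return []  # no embeddings -> no scores

    # cosine similarity is a dot product of L2-normalized vectors
    query = query / torch.norm(query, dim=1, keepdim=True)
    docs = docs / torch.norm(docs, dim=1, keepdim=True)

    scores: List[float] = []
    for start in range(0, len(docs), batch_size):
        with torch.inference_mode():
            slice_scores = torch.matmul(docs[start : start + batch_size], query.T).squeeze(dim=1)
        scores.extend(slice_scores.cpu().numpy().tolist())
    return scores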
@@ -330,21 +343,22 @@ class InMemoryDocumentStore(KeywordDocumentStore):
 
         return scores
 
-    def get_scores_numpy(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores_numpy(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         """
         Calculate similarity scores between query embedding and a list of documents using numpy.
 
         :param query_emb: Embedding of the query (e.g. gathered from DPR)
-        :param document_to_search: List of documents to compare `query_emb` against.
+        :param documents_to_search: List of documents to compare `query_emb` against.
         """
-        if len(query_emb.shape) == 1:
-            query_emb = np.expand_dims(query_emb, 0)
+        if query_emb.ndim == 1:
+            query_emb = np.expand_dims(a=query_emb, axis=0)
 
-        doc_embeds = np.array([doc.embedding for doc in document_to_search])
-        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
-            doc_embeds = doc_embeds.unsqueeze(dim=0)  # type: ignore [attr-defined]
-        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
-            return []
+        doc_embeds = np.array([doc.embedding for doc in documents_to_search])
+        if doc_embeds.ndim == 1:
+            # if there are no embeddings, return an empty list
+            if doc_embeds.shape[0] == 0:
+                return []
+            doc_embeds = np.expand_dims(a=doc_embeds, axis=0)
 
         if self.similarity == "cosine":
             # cosine similarity is just a normed dot product
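The numpy path mirrors the same logic on CPU: expand a 1D query to 2D, return early when there are no embeddings, then normalize before the dot product. A hedged numpy-only equivalent under the same assumptions (illustrative helper with a 2D doc_embeds input, not the actual method):

from typing import List

import numpy as np


def cosine_scores_numpy(query_emb: np.ndarray, doc_embeds: np.ndarray) -> List[float]:
    if query_emb.ndim == 1:
        query_emb = np.expand_dims(a=query_emb, axis=0)  # (1, dim)
    if doc_embeds.shape[0] == 0:
        return []  # no embeddings -> no scores

    # cosine similarity is a dot product of L2-normalized vectors
    query_norm = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    doc_norms = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
    return np.dot(doc_norms, query_norm.T).squeeze(axis=1).tolist()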
@@ -360,11 +374,11 @@ class InMemoryDocumentStore(KeywordDocumentStore):
 
         return scores
 
-    def get_scores(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         if self.main_device.type == "cuda":
-            scores = self.get_scores_torch(query_emb, document_to_search)
+            scores = self._get_scores_torch(query_emb, documents_to_search)
         else:
-            scores = self.get_scores_numpy(query_emb, document_to_search)
+            scores = self._get_scores_numpy(query_emb, documents_to_search)
 
         return scores
 
@@ -460,11 +474,17 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         if query_emb is None:
             return []
 
-        document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
-        scores = self.get_scores(query_emb, document_to_search)
+        documents = self.get_all_documents(index=index, filters=filters, return_embedding=True)
+        documents_with_embeddings = [doc for doc in documents if doc.embedding is not None]
+        if len(documents) != len(documents_with_embeddings):
+            logger.warning(
+                "Skipping some of your documents that don't have embeddings. "
+                "To generate embeddings, run the document store's update_embeddings() method."
+            )
+        scores = self._get_scores(query_emb, documents_with_embeddings)
 
         candidate_docs = []
-        for doc, score in zip(document_to_search, scores):
+        for doc, score in zip(documents_with_embeddings, scores):
             curr_meta = deepcopy(doc.meta)
             new_document = Document(
                 id=doc.id, content=doc.content, content_type=doc.content_type, meta=curr_meta, embedding=doc.embedding
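With this change, documents written without embeddings no longer break dense retrieval: they are filtered out, a warning is logged, and scoring runs on the remaining documents. A usage sketch of the new behavior, mirroring the test added below; the store instantiation is assumed and the random query matches the default embedding_dim of 768:

import numpy as np

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

document_store = InMemoryDocumentStore()  # assumed default setup
document_store.write_documents([Document(content="test Document")])  # update_embeddings() never run

query_embedding = np.random.rand(768).astype(np.float32)

# logs "Skipping some of your documents that don't have embeddings ..." and returns no results
results = document_store.query_by_embedding(query_emb=query_embedding, top_k=1)
assert len(results) == 0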
@@ -3,6 +3,7 @@ import logging
 import pandas as pd
 import pytest
 from rank_bm25 import BM25
+import numpy as np
 
 from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.schema import Document
@@ -112,3 +113,15 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
         for docs, query_emb in zip(docs_batch, query_embs):
             assert len(docs) == 5
             assert (docs[0].embedding == query_emb).all()
+
+    @pytest.mark.integration
+    def test_memory_query_by_embedding_docs_wo_embeddings(self, ds, caplog):
+        # write document but don't update embeddings
+        ds.write_documents([Document(content="test Document")])
+
+        query_embedding = np.random.rand(768).astype(np.float32)
+
+        with caplog.at_level(logging.WARNING):
+            docs = ds.query_by_embedding(query_emb=query_embedding, top_k=1)
+            assert "Skipping some of your documents that don't have embeddings" in caplog.text
+            assert len(docs) == 0