Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-27 07:48:43 +00:00)
refactor: InMemoryDocumentStore - manage documents without embedding & fix mypy errors (#4113)
* refactoring and test
* try to replace error with warning
* more expressive and robust get_scores methods
* make get_scores methods internal
This commit is contained in:
parent d86a511cc1
commit 24405f851c
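Note: the user-visible effect of this change is that query_by_embedding() now skips documents that have no embedding and logs a warning instead of failing. A minimal sketch of that behaviour, mirroring the new test at the bottom of this diff (the store setup and the random 768-dim query vector are illustrative assumptions, not part of the commit):

import numpy as np

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

# Write a document but do not call update_embeddings(), so it carries no embedding.
document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="test Document")])

query_embedding = np.random.rand(768).astype(np.float32)

# Logs "Skipping some of your documents that don't have embeddings. ..." and
# returns an empty list, since no stored document has an embedding yet.
docs = document_store.query_by_embedding(query_emb=query_embedding, top_k=1)
print(len(docs))  # 0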
@@ -281,47 +281,60 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         else:
             return None

-    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:  # type: ignore
+    def get_documents_by_id(
+        self,
+        ids: List[str],
+        index: Optional[str] = None,
+        batch_size: Optional[int] = None,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> List[Document]:
         """
         Fetch documents by specifying a list of text id strings.
         """
+        if headers:
+            raise NotImplementedError("InMemoryDocumentStore does not support headers.")
+        if batch_size:
+            logger.warning(
+                "InMemoryDocumentStore does not support batching in `get_documents_by_id` method. This parameter is ignored."
+            )
         index = index or self.index
         documents = [self.indexes[index][id] for id in ids]
         return documents

-    def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores_torch(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         """
         Calculate similarity scores between query embedding and a list of documents using torch.

         :param query_emb: Embedding of the query (e.g. gathered from DPR)
-        :param document_to_search: List of documents to compare `query_emb` against.
+        :param documents_to_search: List of documents to compare `query_emb` against.
         """
-        query_emb = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)  # type: ignore [assignment]
-        if len(query_emb.shape) == 1:
-            query_emb = query_emb.unsqueeze(dim=0)  # type: ignore [attr-defined]
+        query_emb_tensor = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)
+        if query_emb_tensor.ndim == 1:
+            query_emb_tensor = query_emb_tensor.unsqueeze(dim=0)

-        doc_embeds = np.array([doc.embedding for doc in document_to_search])
-        doc_embeds = torch.as_tensor(doc_embeds, dtype=torch.float)  # type: ignore [assignment]
-        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
-            doc_embeds = doc_embeds.unsqueeze(dim=0)  # type: ignore [attr-defined]
-        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
-            return []
+        doc_embeds = np.array([doc.embedding for doc in documents_to_search])
+        doc_embeds_tensor = torch.as_tensor(doc_embeds, dtype=torch.float)
+        if doc_embeds_tensor.ndim == 1:
+            # if there are no embeddings, return an empty list
+            if doc_embeds_tensor.shape[0] == 0:
+                return []
+            doc_embeds_tensor = doc_embeds_tensor.unsqueeze(dim=0)

         if self.similarity == "cosine":
             # cosine similarity is just a normed dot product
-            query_emb_norm = torch.norm(query_emb, dim=1)
-            query_emb = torch.div(query_emb, query_emb_norm)  # type: ignore [assignment,arg-type]
+            query_emb_norm = torch.norm(query_emb_tensor, dim=1)
+            query_emb_tensor = torch.div(query_emb_tensor, query_emb_norm)

-            doc_embeds_norms = torch.norm(doc_embeds, dim=1)
-            doc_embeds = torch.div(doc_embeds.T, doc_embeds_norms).T  # type: ignore [assignment,arg-type]
+            doc_embeds_norms = torch.norm(doc_embeds_tensor, dim=1)
+            doc_embeds_tensor = torch.div(doc_embeds_tensor.T, doc_embeds_norms).T

         curr_pos = 0
-        scores = []  # type: ignore [var-annotated]
-        while curr_pos < len(doc_embeds):
-            doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
-            doc_embeds_slice = doc_embeds_slice.to(self.main_device)  # type: ignore [attr-defined]
+        scores: List[float] = []
+        while curr_pos < len(doc_embeds_tensor):
+            doc_embeds_slice = doc_embeds_tensor[curr_pos : curr_pos + self.scoring_batch_size]
+            doc_embeds_slice = doc_embeds_slice.to(self.main_device)
             with torch.inference_mode():
-                slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()  # type: ignore [arg-type,arg-type]
+                slice_scores = torch.matmul(doc_embeds_slice, query_emb_tensor.T).cpu()
                 slice_scores = slice_scores.squeeze(dim=1)
                 slice_scores = slice_scores.numpy().tolist()

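Aside: the "cosine similarity is just a normed dot product" comment kept in the hunk above can be sanity-checked in isolation. A throwaway snippet with made-up vectors, not part of the change:

import torch

a = torch.tensor([[1.0, 2.0, 3.0]])
b = torch.tensor([[4.0, 5.0, 6.0]])

# Normalizing both vectors and taking their dot product gives the cosine similarity.
normed_dot = (a / a.norm(dim=1, keepdim=True)) @ (b / b.norm(dim=1, keepdim=True)).T
assert torch.allclose(normed_dot.squeeze(), torch.nn.functional.cosine_similarity(a, b).squeeze())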
@@ -330,21 +343,22 @@ class InMemoryDocumentStore(KeywordDocumentStore):

         return scores

-    def get_scores_numpy(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores_numpy(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         """
         Calculate similarity scores between query embedding and a list of documents using numpy.

         :param query_emb: Embedding of the query (e.g. gathered from DPR)
-        :param document_to_search: List of documents to compare `query_emb` against.
+        :param documents_to_search: List of documents to compare `query_emb` against.
         """
-        if len(query_emb.shape) == 1:
-            query_emb = np.expand_dims(query_emb, 0)
+        if query_emb.ndim == 1:
+            query_emb = np.expand_dims(a=query_emb, axis=0)

-        doc_embeds = np.array([doc.embedding for doc in document_to_search])
-        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
-            doc_embeds = doc_embeds.unsqueeze(dim=0)  # type: ignore [attr-defined]
-        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
-            return []
+        doc_embeds = np.array([doc.embedding for doc in documents_to_search])
+        if doc_embeds.ndim == 1:
+            # if there are no embeddings, return an empty list
+            if doc_embeds.shape[0] == 0:
+                return []
+            doc_embeds = np.expand_dims(a=doc_embeds, axis=0)

         if self.similarity == "cosine":
             # cosine similarity is just a normed dot product
@@ -360,11 +374,11 @@ class InMemoryDocumentStore(KeywordDocumentStore):

         return scores

-    def get_scores(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+    def _get_scores(self, query_emb: np.ndarray, documents_to_search: List[Document]) -> List[float]:
         if self.main_device.type == "cuda":
-            scores = self.get_scores_torch(query_emb, document_to_search)
+            scores = self._get_scores_torch(query_emb, documents_to_search)
         else:
-            scores = self.get_scores_numpy(query_emb, document_to_search)
+            scores = self._get_scores_numpy(query_emb, documents_to_search)

         return scores

@@ -460,11 +474,17 @@ class InMemoryDocumentStore(KeywordDocumentStore):
         if query_emb is None:
             return []

-        document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
-        scores = self.get_scores(query_emb, document_to_search)
+        documents = self.get_all_documents(index=index, filters=filters, return_embedding=True)
+        documents_with_embeddings = [doc for doc in documents if doc.embedding is not None]
+        if len(documents) != len(documents_with_embeddings):
+            logger.warning(
+                "Skipping some of your documents that don't have embeddings. "
+                "To generate embeddings, run the document store's update_embeddings() method."
+            )
+        scores = self._get_scores(query_emb, documents_with_embeddings)

         candidate_docs = []
-        for doc, score in zip(document_to_search, scores):
+        for doc, score in zip(documents_with_embeddings, scores):
             curr_meta = deepcopy(doc.meta)
             new_document = Document(
                 id=doc.id, content=doc.content, content_type=doc.content_type, meta=curr_meta, embedding=doc.embedding
@@ -3,6 +3,7 @@ import logging
 import pandas as pd
 import pytest
 from rank_bm25 import BM25
+import numpy as np

 from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.schema import Document
@@ -112,3 +113,15 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
         for docs, query_emb in zip(docs_batch, query_embs):
             assert len(docs) == 5
             assert (docs[0].embedding == query_emb).all()
+
+    @pytest.mark.integration
+    def test_memory_query_by_embedding_docs_wo_embeddings(self, ds, caplog):
+        # write document but don't update embeddings
+        ds.write_documents([Document(content="test Document")])
+
+        query_embedding = np.random.rand(768).astype(np.float32)
+
+        with caplog.at_level(logging.WARNING):
+            docs = ds.query_by_embedding(query_emb=query_embedding, top_k=1)
+        assert "Skipping some of your documents that don't have embeddings" in caplog.text
+        assert len(docs) == 0
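Finally, a small usage sketch of the widened get_documents_by_id() signature from the first hunk (the document id "doc-1" and the header value are made up for illustration):

from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.schema import Document

document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="hello world", id="doc-1")])

# batch_size is accepted for signature compatibility with other document stores,
# but the in-memory store only logs a warning and ignores it.
docs = document_store.get_documents_by_id(ids=["doc-1"], batch_size=32)

# headers are not supported by InMemoryDocumentStore and raise NotImplementedError.
try:
    document_store.get_documents_by_id(ids=["doc-1"], headers={"X-Example": "value"})
except NotImplementedError as err:
    print(err)  # InMemoryDocumentStore does not support headers.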