From ea334658d6d52066553f008b4afa60f27eff3929 Mon Sep 17 00:00:00 2001
From: venuraja79
Date: Mon, 24 Aug 2020 18:18:36 +0530
Subject: [PATCH] DPR (Dense Retriever) for InMemoryDocumentStore #316 (#332)

---
 haystack/database/memory.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/haystack/database/memory.py b/haystack/database/memory.py
index ea548aff6..a658b6674 100644
--- a/haystack/database/memory.py
+++ b/haystack/database/memory.py
@@ -4,7 +4,10 @@ from collections import defaultdict
 
 from haystack.database.base import BaseDocumentStore, Document, Label
 from haystack.indexing.utils import eval_data_from_file
+from haystack.retriever.base import BaseRetriever
+import logging
 
+logger = logging.getLogger(__name__)
 
 class InMemoryDocumentStore(BaseDocumentStore):
     """
@@ -15,6 +18,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
         self.indexes: Dict[str, Dict] = defaultdict(dict)
         self.index: str = "document"
         self.label_index: str = "label"
+        self.embedding_field: str = "embedding"
+        self.embedding_dim : int = 768
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
@@ -86,16 +91,34 @@ class InMemoryDocumentStore(BaseDocumentStore):
 
         return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
 
-    def update_embeddings(self, retriever):
+    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
         """
         Updates the embeddings in the the document store using the encoding model specified in the retriever.
         This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).
 
         :param retriever: Retriever
+        :param index: Index name to update
         :return: None
         """
-        #TODO
-        raise NotImplementedError("update_embeddings() is not yet implemented for this DocumentStore")
+        if index is None:
+            index = self.index
+
+        if not self.embedding_field:
+            raise RuntimeError("Specify the arg embedding_field when initializing InMemoryDocumentStore()")
+
+        # TODO Index embeddings every X batches to avoid OOM for huge document collections
+        docs = self.get_all_documents(index)
+        logger.info(f"Updating embeddings for {len(docs)} docs ...")
+        embeddings = retriever.embed_passages(docs)  # type: ignore
+        assert len(docs) == len(embeddings)
+
+        if embeddings[0].shape[0] != self.embedding_dim:
+            raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                               f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
+                               "Specify the arg `embedding_dim` when initializing InMemoryDocumentStore()")
+
+        for doc, emb in zip(docs, embeddings):
+            self.indexes[index][doc.id].embedding = emb
 
     def get_document_count(self, index: Optional[str] = None) -> int:
         index = index or self.index
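
Usage note (not part of the patch): a minimal sketch of how the new update_embeddings() might be driven end to end. Only update_embeddings(retriever, index) itself comes from this diff; the DensePassageRetriever import path, the model name, and the sample documents below are assumptions based on the Haystack 0.x layout this patch targets, so check them against your installed version.

    from haystack.database.memory import InMemoryDocumentStore
    from haystack.retriever.dense import DensePassageRetriever  # assumed import path

    # Write a couple of documents, then let the retriever encode them in place.
    document_store = InMemoryDocumentStore()
    document_store.write_documents([
        {"name": "dpr.txt", "text": "Dense Passage Retrieval encodes passages with BERT."},
        {"name": "bm25.txt", "text": "BM25 is a sparse, term-based ranking function."},
    ])

    # Model name is an assumption for illustration; substitute your DPR checkpoint.
    retriever = DensePassageRetriever(document_store=document_store,
                                      embedding_model="dpr-bert-base-nq")

    # Fills Document.embedding for every doc in the default "document" index;
    # raises RuntimeError if the model's output dim differs from embedding_dim (768).
    document_store.update_embeddings(retriever)

Because the patch attaches each embedding directly to the Document objects held in self.indexes, re-running update_embeddings() after changing the retriever simply overwrites the stored vectors; no separate vector index is built or rebuilt.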