diff --git a/haystack/database/memory.py b/haystack/database/memory.py index ea548aff6..a658b6674 100644 --- a/haystack/database/memory.py +++ b/haystack/database/memory.py @@ -4,7 +4,10 @@ from collections import defaultdict from haystack.database.base import BaseDocumentStore, Document, Label from haystack.indexing.utils import eval_data_from_file +from haystack.retriever.base import BaseRetriever +import logging +logger = logging.getLogger(__name__) class InMemoryDocumentStore(BaseDocumentStore): """ @@ -15,6 +18,8 @@ class InMemoryDocumentStore(BaseDocumentStore): self.indexes: Dict[str, Dict] = defaultdict(dict) self.index: str = "document" self.label_index: str = "label" + self.embedding_field: str = "embedding" + self.embedding_dim : int = 768 def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None): """ @@ -86,16 +91,34 @@ class InMemoryDocumentStore(BaseDocumentStore): return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k] - def update_embeddings(self, retriever): + def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None): """ Updates the embeddings in the the document store using the encoding model specified in the retriever. This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config). :param retriever: Retriever + :param index: Index name to update :return: None """ - #TODO - raise NotImplementedError("update_embeddings() is not yet implemented for this DocumentStore") + if index is None: + index = self.index + + if not self.embedding_field: + raise RuntimeError("Specify the arg embedding_field when initializing InMemoryDocumentStore()") + + # TODO Index embeddings every X batches to avoid OOM for huge document collections + docs = self.get_all_documents(index) + logger.info(f"Updating embeddings for {len(docs)} docs ...") + embeddings = retriever.embed_passages(docs) # type: ignore + assert len(docs) == len(embeddings) + + if embeddings[0].shape[0] != self.embedding_dim: + raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})" + f" doesn't match embedding dim. in documentstore ({self.embedding_dim})." + "Specify the arg `embedding_dim` when initializing InMemoryDocumentStore()") + + for doc, emb in zip(docs, embeddings): + self.indexes[index][doc.id].embedding = emb def get_document_count(self, index: Optional[str] = None) -> int: index = index or self.index