DPR (Dense Retriever) for InMemoryDocumentStore #316 (#332)

This commit is contained in:
venuraja79 2020-08-24 18:18:36 +05:30 committed by GitHub
parent 3a42eb663e
commit ea334658d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -4,7 +4,10 @@ from collections import defaultdict
from haystack.database.base import BaseDocumentStore, Document, Label
from haystack.indexing.utils import eval_data_from_file
from haystack.retriever.base import BaseRetriever
import logging
logger = logging.getLogger(__name__)
class InMemoryDocumentStore(BaseDocumentStore):
"""
@@ -15,6 +18,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
self.indexes: Dict[str, Dict] = defaultdict(dict)
self.index: str = "document"
self.label_index: str = "label"
self.embedding_field: str = "embedding"
self.embedding_dim : int = 768
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
@@ -86,16 +91,34 @@ class InMemoryDocumentStore(BaseDocumentStore):
return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
    """
    Updates the embeddings in the document store using the encoding model specified in the retriever.

    This can be useful if you want to add or change the embeddings for your documents
    (e.g. after changing the retriever config).

    :param retriever: Retriever whose ``embed_passages()`` computes one embedding per document.
    :param index: Index name to update; defaults to ``self.index`` when None.
    :return: None
    :raises RuntimeError: if no embedding field is configured, or if the model's
        embedding dimension does not match ``self.embedding_dim``.
    """
    if index is None:
        index = self.index

    if not self.embedding_field:
        raise RuntimeError("Specify the arg embedding_field when initializing InMemoryDocumentStore()")

    # TODO Index embeddings every X batches to avoid OOM for huge document collections
    docs = self.get_all_documents(index)
    if not docs:
        # Nothing to embed; also avoids IndexError on embeddings[0] below.
        logger.info("No documents found in index '%s'; nothing to update.", index)
        return

    logger.info("Updating embeddings for %d docs ...", len(docs))
    embeddings = retriever.embed_passages(docs)  # type: ignore
    assert len(docs) == len(embeddings)

    # Fail fast if the retriever's model produces vectors of a different size
    # than this store was configured for.
    if embeddings[0].shape[0] != self.embedding_dim:
        raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
                           f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
                           " Specify the arg `embedding_dim` when initializing InMemoryDocumentStore()")

    # Attach each embedding to its stored document in place.
    for doc, emb in zip(docs, embeddings):
        self.indexes[index][doc.id].embedding = emb
def get_document_count(self, index: Optional[str] = None) -> int:
index = index or self.index