mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 07:29:06 +00:00
parent
3a42eb663e
commit
ea334658d6
@ -4,7 +4,10 @@ from collections import defaultdict
|
||||
|
||||
from haystack.database.base import BaseDocumentStore, Document, Label
|
||||
from haystack.indexing.utils import eval_data_from_file
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InMemoryDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
@ -15,6 +18,8 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
self.indexes: Dict[str, Dict] = defaultdict(dict)
|
||||
self.index: str = "document"
|
||||
self.label_index: str = "label"
|
||||
self.embedding_field: str = "embedding"
|
||||
self.embedding_dim : int = 768
|
||||
|
||||
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
|
||||
"""
|
||||
@ -86,16 +91,34 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
|
||||
return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
|
||||
|
||||
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
    """
    Re-compute and store embeddings for every document in an index, using the
    encoding model of the given retriever. Useful when documents were added
    without embeddings, or after the retriever configuration changed.

    :param retriever: Retriever whose ``embed_passages()`` produces the embeddings
    :param index: Name of the index to update; defaults to ``self.index``
    :return: None
    """
    target_index = index if index is not None else self.index

    # Without a configured embedding field there is nowhere to store the vectors.
    if not self.embedding_field:
        raise RuntimeError("Specify the arg embedding_field when initializing InMemoryDocumentStore()")

    # TODO Index embeddings every X batches to avoid OOM for huge document collections
    documents = self.get_all_documents(target_index)
    logger.info(f"Updating embeddings for {len(documents)} docs ...")
    passage_embeddings = retriever.embed_passages(documents)  # type: ignore
    assert len(documents) == len(passage_embeddings)

    # Guard against a retriever whose output dimensionality disagrees with the store.
    model_dim = passage_embeddings[0].shape[0]
    if model_dim != self.embedding_dim:
        raise RuntimeError(
            f"Embedding dim. of model ({model_dim})"
            f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
            "Specify the arg `embedding_dim` when initializing InMemoryDocumentStore()"
        )

    # Attach each vector to its document, matched up by document id.
    for document, vector in zip(documents, passage_embeddings):
        self.indexes[target_index][document.id].embedding = vector
|
||||
|
||||
def get_document_count(self, index: Optional[str] = None) -> int:
|
||||
index = index or self.index
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user