Fix indexing of metadata for FAISS/SQL Document Store (#310)

This commit is contained in:
Tanay Soni 2020-08-13 12:25:32 +02:00 committed by GitHub
parent 397dcf9d92
commit 089fecf99e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 8 deletions

View File

@ -412,7 +412,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
# We put all additional data of the doc into meta_data and return it in the API
meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}
meta_data["name"] = meta_data.pop(self.name_field, None)
name = meta_data.pop(self.name_field, None)
if name:
meta_data["name"] = name
document = Document(
id=hit["_id"],

View File

@ -68,14 +68,14 @@ class SQLDocumentStore(BaseDocumentStore):
self.index = index
self.label_index = "label"
def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
index = index or self.index
document_row = self.session.query(DocumentORM).filter_by(index=index, id=id).first()
document = document_row or self._convert_sql_row_to_document(document_row)
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
documents = self.get_documents_by_id([id], index)
document = documents[0] if documents else None
return document
def get_documents_by_id(self, ids: List[str]) -> List[Document]:
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids)).all()
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
index = index or self.index
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all()
documents = [self._convert_sql_row_to_document(row) for row in results]
return documents
@ -138,7 +138,8 @@ class SQLDocumentStore(BaseDocumentStore):
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
index = index or self.index
for doc in document_objects:
meta_orms = [MetaORM(name=key, value=value) for key, value in doc.meta.items()]
meta_fields = doc.meta or {}
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index)
self.session.add(doc_orm)
self.session.commit()

View File

@ -4,6 +4,7 @@ from elasticsearch import Elasticsearch
from haystack.database.base import Document, Label
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.database.faiss import FAISSDocumentStore
def test_get_all_documents_without_filters(document_store_with_docs):
@ -42,6 +43,39 @@ def test_get_documents_by_id(document_store_with_docs):
assert doc.text == documents[0].text
def test_write_document_meta(document_store):
documents = [
{"text": "dict_without_meta", "id": "1"},
{"text": "dict_with_meta", "meta_field": "test2", "name": "filename2", "id": "2"},
Document(text="document_object_without_meta", id="3"),
Document(text="document_object_with_meta", meta={"meta_field": "test4", "name": "filename3"}, id="4"),
]
document_store.write_documents(documents)
documents_in_store = document_store.get_all_documents()
assert len(documents_in_store) == 4
assert not document_store.get_document_by_id("1").meta
assert document_store.get_document_by_id("2").meta["meta_field"] == "test2"
assert not document_store.get_document_by_id("3").meta
assert document_store.get_document_by_id("4").meta["meta_field"] == "test4"
def test_write_document_index(document_store):
documents = [
{"text": "text1", "id": "1"},
{"text": "text2", "id": "2"},
]
document_store.write_documents([documents[0]], index="haystack_test_1")
assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
if not isinstance(document_store, FAISSDocumentStore): # addition of more documents is not supported in FAISS
document_store.write_documents([documents[1]], index="haystack_test_2")
assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
assert len(document_store.get_all_documents()) == 0
def test_labels(document_store):
label = Label(
question="question",