Fix document filtering in SQLDocumentStore (#396)

This commit is contained in:
Tanay Soni 2020-09-18 12:22:52 +02:00 committed by GitHub
parent 3399fc784d
commit 0859da8f74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 23 deletions

View File

@@ -26,7 +26,7 @@ class DocumentORM(ORMBase):
text = Column(String, nullable=False)
index = Column(String, nullable=False)
meta = relationship("MetaORM", secondary="document_meta", backref="Document")
meta = relationship("MetaORM", backref="Document")
class MetaORM(ORMBase):
@@ -34,15 +34,9 @@ class MetaORM(ORMBase):
name = Column(String, index=True)
value = Column(String, index=True)
documents = relationship(DocumentORM, secondary="document_meta", backref="Meta")
class DocumentMetaORM(ORMBase):
__tablename__ = "document_meta"
document_id = Column(String, ForeignKey("document.id"), nullable=False)
meta_id = Column(Integer, ForeignKey("meta.id"), nullable=False)
documents = relationship(DocumentORM, backref="Meta")
class LabelORM(ORMBase):
@@ -88,9 +82,9 @@ class SQLDocumentStore(BaseDocumentStore):
query = self.session.query(DocumentORM).filter_by(index=index)
if filters:
query = query.join(MetaORM)
for key, values in filters.items():
query = query.filter(DocumentORM.meta.any(MetaORM.name.in_([key])))\
.filter(DocumentORM.meta.any(MetaORM.value.in_(values)))
query = query.filter(MetaORM.name == key, MetaORM.value.in_(values))
documents = [self._convert_sql_row_to_document(row) for row in query.all()]
return documents
@@ -148,12 +142,10 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.commit()
def update_document_meta(self, id: str, meta: Dict[str, str]):
# Set the metadata of the document identified by `id` from the `meta` mapping
# and commit the session.
# NOTE(review): this listing appears to be a flattened diff without +/- markers,
# so the body below presumably interleaves two alternative implementations —
# confirm against the repository before relying on the statement order.
#
# First variant: load the document row, obtain one MetaORM per key/value pair
# through self._get_or_create, and assign the resulting list to document.meta.
document = self.session.query(DocumentORM).get(id)
meta_orms = [
self._get_or_create(session=self.session, model=MetaORM, name=key, value=value)
for key, value in meta.items()
]
document.meta = meta_orms
# Second variant: delete the MetaORM rows whose document_id equals `id`, then
# add one new MetaORM row per key/value pair to the session.
self.session.query(MetaORM).filter_by(document_id=id).delete()
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
for m in meta_orms:
self.session.add(m)
# Commit the pending changes on the session.
self.session.commit()
def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "label"):

View File

@@ -15,6 +15,28 @@ def test_get_all_documents_without_filters(document_store_with_docs):
assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3"}
def test_get_all_document_filter_duplicate_value(document_store):
    """Filter on one metadata key while the same value occurs under other keys.

    Writes three documents in which the value "0" appears under the keys
    "f1", "vector_id" and "f3", filters on f1 == "1", and expects exactly
    one document back whose "vector_id" metadata is "0".
    """
    documents = [
        Document(text="Doc1", meta={"f1": "0"}),
        Document(text="Doc1", meta={"f1": "1", "vector_id": "0"}),
        Document(text="Doc2", meta={"f3": "0"}),
    ]
    document_store.write_documents(documents)

    documents = document_store.get_all_documents(filters={"f1": ["1"]})

    # Check the count before indexing so an empty result fails with a clear
    # assertion instead of an IndexError.
    assert len(documents) == 1
    assert documents[0].text == "Doc1"
    assert {d.meta["vector_id"] for d in documents} == {"0"}
def test_get_all_documents_with_correct_filters(document_store_with_docs):
documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
assert len(documents) == 1
@@ -234,12 +256,29 @@ def test_multilabel_no_answer(document_store):
document_store.delete_all_documents(index="haystack_test_multilabel_no_answer")
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_elasticsearch_update_meta(document_store_with_docs):
    """Update one metadata field of a stored document and read it back via query."""

    def first_match():
        # Fetch the first document whose "name" metadata equals "filename1".
        hits = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})
        return hits[0]

    original = first_match()
    document_store_with_docs.update_document_meta(original.id, meta={"meta_field": "updated_meta"})
    refreshed = first_match()
    assert refreshed.meta["meta_field"] == "updated_meta"
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql"], indirect=True)
def test_elasticsearch_update_meta(document_store):
    """Update the metadata of one of three stored documents and verify the result.

    The document selected by meta_key == "2" gets a new "vector_id"; afterwards
    it must carry exactly two metadata keys with the updated values.
    """
    seed = (
        ("Doc1", {"vector_id": "1", "meta_key": "1"}),
        ("Doc2", {"vector_id": "2", "meta_key": "2"}),
        ("Doc3", {"vector_id": "3", "meta_key": "3"}),
    )
    document_store.write_documents([Document(text=text, meta=meta) for text, meta in seed])

    target = document_store.get_all_documents(filters={"meta_key": ["2"]})[0]
    replacement_meta = {"vector_id": "99", "meta_key": "2"}
    document_store.update_document_meta(target.id, meta=replacement_meta)

    refreshed = document_store.get_document_by_id(target.id)
    assert len(refreshed.meta.keys()) == 2
    assert refreshed.meta["vector_id"] == "99"
    assert refreshed.meta["meta_key"] == "2"
def test_elasticsearch_custom_fields(elasticsearch_fixture):