diff --git a/haystack/document_store/sql.py b/haystack/document_store/sql.py index d3763e686..a6d0c2afc 100644 --- a/haystack/document_store/sql.py +++ b/haystack/document_store/sql.py @@ -26,7 +26,7 @@ class DocumentORM(ORMBase): text = Column(String, nullable=False) index = Column(String, nullable=False) - meta = relationship("MetaORM", secondary="document_meta", backref="Document") + meta = relationship("MetaORM", backref="Document") class MetaORM(ORMBase): @@ -34,15 +34,9 @@ class MetaORM(ORMBase): name = Column(String, index=True) value = Column(String, index=True) - - documents = relationship(DocumentORM, secondary="document_meta", backref="Meta") - - -class DocumentMetaORM(ORMBase): - __tablename__ = "document_meta" - document_id = Column(String, ForeignKey("document.id"), nullable=False) - meta_id = Column(Integer, ForeignKey("meta.id"), nullable=False) + + documents = relationship(DocumentORM, backref="Meta") class LabelORM(ORMBase): @@ -88,9 +82,9 @@ class SQLDocumentStore(BaseDocumentStore): query = self.session.query(DocumentORM).filter_by(index=index) if filters: + query = query.join(MetaORM) for key, values in filters.items(): - query = query.filter(DocumentORM.meta.any(MetaORM.name.in_([key])))\ - .filter(DocumentORM.meta.any(MetaORM.value.in_(values))) + query = query.filter(MetaORM.name == key, MetaORM.value.in_(values)) documents = [self._convert_sql_row_to_document(row) for row in query.all()] return documents @@ -148,12 +142,10 @@ class SQLDocumentStore(BaseDocumentStore): self.session.commit() def update_document_meta(self, id: str, meta: Dict[str, str]): - document = self.session.query(DocumentORM).get(id) - meta_orms = [ - self._get_or_create(session=self.session, model=MetaORM, name=key, value=value) - for key, value in meta.items() - ] - document.meta = meta_orms + self.session.query(MetaORM).filter_by(document_id=id).delete() + meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()] + for m in meta_orms: + self.session.add(m) self.session.commit() def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "label"): diff --git a/test/test_db.py b/test/test_db.py index affe38071..502294fb9 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -15,6 +15,28 @@ def test_get_all_documents_without_filters(document_store_with_docs): assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3"} +def test_get_all_document_filter_duplicate_value(document_store): + documents = [ + Document( + text="Doc1", + meta={"f1": "0"} + ), + Document( + text="Doc1", + meta={"f1": "1", "vector_id": "0"} + ), + Document( + text="Doc2", + meta={"f3": "0"} + ) + ] + document_store.write_documents(documents) + documents = document_store.get_all_documents(filters={"f1": ["1"]}) + assert documents[0].text == "Doc1" + assert len(documents) == 1 + assert {d.meta["vector_id"] for d in documents} == {"0"} + + def test_get_all_documents_with_correct_filters(document_store_with_docs): documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]}) assert len(documents) == 1 @@ -234,12 +256,29 @@ def test_multilabel_no_answer(document_store): document_store.delete_all_documents(index="haystack_test_multilabel_no_answer") -@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) -def test_elasticsearch_update_meta(document_store_with_docs): - document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0] - document_store_with_docs.update_document_meta(document.id, meta={"meta_field": "updated_meta"}) - updated_document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0] - assert updated_document.meta["meta_field"] == "updated_meta" +@pytest.mark.parametrize("document_store", ["elasticsearch", "sql"], indirect=True) +def test_elasticsearch_update_meta(document_store): + documents = [ + Document( + text="Doc1", + meta={"vector_id": "1", "meta_key": "1"} + ), + Document( + text="Doc2", + meta={"vector_id": "2", "meta_key": "2"} + ), + Document( + text="Doc3", + meta={"vector_id": "3", "meta_key": "3"} + ) + ] + document_store.write_documents(documents) + document_2 = document_store.get_all_documents(filters={"meta_key": ["2"]})[0] + document_store.update_document_meta(document_2.id, meta={"vector_id": "99", "meta_key": "2"}) + updated_document = document_store.get_document_by_id(document_2.id) + assert len(updated_document.meta.keys()) == 2 + assert updated_document.meta["vector_id"] == "99" + assert updated_document.meta["meta_key"] == "2" def test_elasticsearch_custom_fields(elasticsearch_fixture):