diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 9a7cd0e28..89a7ab220 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -399,6 +399,8 @@ class SQLDocumentStore(BaseDocumentStore): docs_orm = [] for doc in document_objects[i : i + batch_size]: meta_fields = doc.meta or {} + if "classification" in meta_fields: + meta_fields = self._flatten_classification_meta_fields(meta_fields) vector_id = meta_fields.pop("vector_id", None) meta_orms = [] for key, value in meta_fields.items(): @@ -785,3 +787,14 @@ class SQLDocumentStore(BaseDocumentStore): for whereclause in self._column_windows(q.session, column, windowsize): for row in q.filter(whereclause).order_by(column): yield row + + def _flatten_classification_meta_fields(self, meta_fields: dict) -> dict: + """ + Since SQLDocumentStore does not support dictionaries for metadata values, + the DocumentClassifier output is flattened + """ + meta_fields["classification.label"] = meta_fields["classification"]["label"] + meta_fields["classification.score"] = meta_fields["classification"]["score"] + meta_fields["classification.details"] = str(meta_fields["classification"]["details"]) + del meta_fields["classification"] + return meta_fields diff --git a/test/conftest.py b/test/conftest.py index b6fcfc3c2..4c8ea3819 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -106,24 +106,6 @@ posthog.disabled = True requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE}) -def _sql_session_rollback(self, attr): - """ - Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch - errors where an intended operation is still in a transaction, but not committed to the database. - """ - method = object.__getattribute__(self, attr) - if callable(method): - try: - self.session.rollback() - except AttributeError: - pass - - return method - - -SQLDocumentStore.__getattribute__ = _sql_session_rollback - - def pytest_collection_modifyitems(config, items): # add pytest markers for tests that are not explicitly marked but include some keywords name_to_markers = { diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py index bd8d1d481..cd8bc4651 100644 --- a/test/document_stores/test_document_store.py +++ b/test/document_stores/test_document_store.py @@ -99,7 +99,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS def test_get_all_documents_without_filters(document_store_with_docs): - print("hey!") documents = document_store_with_docs.get_all_documents() assert all(isinstance(d, Document) for d in documents) assert len(documents) == 5 diff --git a/test/document_stores/test_sql.py b/test/document_stores/test_sql.py index 8852f28b6..f26153b20 100644 --- a/test/document_stores/test_sql.py +++ b/test/document_stores/test_sql.py @@ -62,6 +62,42 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract): with pytest.raises(Exception, match=r"(?i)unique"): ds.write_documents([doc2], index="index3") + @pytest.mark.integration + def test_sql_get_documents_using_nested_filters_about_classification(self, ds): + documents = [ + Document( + content="That's good. I like it.", + id="1", + meta={ + "classification": { + "label": "LABEL_1", + "score": 0.694, + "details": {"LABEL_1": 0.694, "LABEL_0": 0.306}, + } + }, + ), + Document( + content="That's bad. I don't like it.", + id="2", + meta={ + "classification": { + "label": "LABEL_0", + "score": 0.898, + "details": {"LABEL_0": 0.898, "LABEL_1": 0.102}, + } + }, + ), + ] + ds.write_documents(documents) + + assert ds.get_document_count() == 2 + assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.1}})) == 2 + assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1", "LABEL_0"]})) == 2 + assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.8}})) == 1 + assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1"]})) == 1 + assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0 + assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0 + # NOTE: the SQLDocumentStore behaves differently to the others when filters are applied. # While this should be considered a bug, the relative tests are skipped in the meantime