fix: Flatten DocumentClassifier output in SQLDocumentStore; remove _sql_session_rollback hack in tests (#3273)

* first draft

* fix

* fix

* move test to test_sql
This commit is contained in:
Stefano Fiorucci 2022-11-16 12:20:57 +01:00 committed by GitHub
parent af78f8b431
commit dc26e6d43e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 19 deletions

View File

@ -399,6 +399,8 @@ class SQLDocumentStore(BaseDocumentStore):
docs_orm = []
for doc in document_objects[i : i + batch_size]:
meta_fields = doc.meta or {}
if "classification" in meta_fields:
meta_fields = self._flatten_classification_meta_fields(meta_fields)
vector_id = meta_fields.pop("vector_id", None)
meta_orms = []
for key, value in meta_fields.items():
@ -785,3 +787,14 @@ class SQLDocumentStore(BaseDocumentStore):
for whereclause in self._column_windows(q.session, column, windowsize):
for row in q.filter(whereclause).order_by(column):
yield row
def _flatten_classification_meta_fields(self, meta_fields: dict) -> dict:
"""
Since SQLDocumentStore does not support dictionaries for metadata values,
the DocumentClassifier output is flattened
"""
meta_fields["classification.label"] = meta_fields["classification"]["label"]
meta_fields["classification.score"] = meta_fields["classification"]["score"]
meta_fields["classification.details"] = str(meta_fields["classification"]["details"])
del meta_fields["classification"]
return meta_fields

View File

@ -106,24 +106,6 @@ posthog.disabled = True
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
def _sql_session_rollback(self, attr):
    """
    Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
    errors where an intended operation is still in a transaction, but not committed to the database.
    """
    # Use object.__getattribute__ directly so this override does not recurse
    # into itself when looking up the requested attribute.
    method = object.__getattribute__(self, attr)
    if callable(method):
        try:
            # Roll back any pending transaction before a method call so that
            # uncommitted writes surface as test failures instead of passing silently.
            self.session.rollback()
        except AttributeError:
            # NOTE(review): presumably some instances/lookup paths have no
            # `session` attribute yet; this is deliberately best-effort.
            pass
    return method
# HACK: monkeypatch every attribute access on SQLDocumentStore for the test session.
SQLDocumentStore.__getattribute__ = _sql_session_rollback
def pytest_collection_modifyitems(config, items):
# add pytest markers for tests that are not explicitly marked but include some keywords
name_to_markers = {

View File

@ -99,7 +99,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS
def test_get_all_documents_without_filters(document_store_with_docs):
    """
    With no filters, the pre-populated store returns all 5 documents,
    each one a `Document` instance.
    """
    # Defect fixed: removed leftover debug statement `print("hey!")`.
    documents = document_store_with_docs.get_all_documents()
    assert all(isinstance(d, Document) for d in documents)
    assert len(documents) == 5

View File

@ -62,6 +62,42 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
with pytest.raises(Exception, match=r"(?i)unique"):
ds.write_documents([doc2], index="index3")
@pytest.mark.integration
def test_sql_get_documents_using_nested_filters_about_classification(self, ds):
    """
    Documents carrying nested DocumentClassifier metadata can be written to the
    SQLDocumentStore and then retrieved via the flattened keys
    (``classification.label`` / ``classification.score``).
    """
    docs = [
        Document(
            content="That's good. I like it.",
            id="1",
            meta={
                "classification": {
                    "label": "LABEL_1",
                    "score": 0.694,
                    "details": {"LABEL_1": 0.694, "LABEL_0": 0.306},
                }
            },
        ),
        Document(
            content="That's bad. I don't like it.",
            id="2",
            meta={
                "classification": {
                    "label": "LABEL_0",
                    "score": 0.898,
                    "details": {"LABEL_0": 0.898, "LABEL_1": 0.102},
                }
            },
        ),
    ]
    ds.write_documents(docs)
    assert ds.get_document_count() == 2

    # Table of (filters, expected number of matching documents).
    cases = [
        ({"classification.score": {"$gt": 0.1}}, 2),
        ({"classification.label": ["LABEL_1", "LABEL_0"]}, 2),
        ({"classification.score": {"$gt": 0.8}}, 1),
        ({"classification.label": ["LABEL_1"]}, 1),
        ({"classification.score": {"$gt": 0.95}}, 0),
        ({"classification.label": ["LABEL_100"]}, 0),
    ]
    for filters, expected_count in cases:
        assert len(ds.get_all_documents(filters=filters)) == expected_count
# NOTE: the SQLDocumentStore behaves differently from the other document stores when filters are applied.
# While this should be considered a bug, the related tests are skipped in the meantime.