mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-17 19:09:09 +00:00
fix: Flatten DocumentClassifier
output in SQLDocumentStore
; remove _sql_session_rollback
hack in tests (#3273)
* first draft * fix * fix * move test to test_sql
This commit is contained in:
parent
af78f8b431
commit
dc26e6d43e
@ -399,6 +399,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
docs_orm = []
|
||||
for doc in document_objects[i : i + batch_size]:
|
||||
meta_fields = doc.meta or {}
|
||||
if "classification" in meta_fields:
|
||||
meta_fields = self._flatten_classification_meta_fields(meta_fields)
|
||||
vector_id = meta_fields.pop("vector_id", None)
|
||||
meta_orms = []
|
||||
for key, value in meta_fields.items():
|
||||
@ -785,3 +787,14 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
for whereclause in self._column_windows(q.session, column, windowsize):
|
||||
for row in q.filter(whereclause).order_by(column):
|
||||
yield row
|
||||
|
||||
def _flatten_classification_meta_fields(self, meta_fields: dict) -> dict:
|
||||
"""
|
||||
Since SQLDocumentStore does not support dictionaries for metadata values,
|
||||
the DocumentClassifier output is flattened
|
||||
"""
|
||||
meta_fields["classification.label"] = meta_fields["classification"]["label"]
|
||||
meta_fields["classification.score"] = meta_fields["classification"]["score"]
|
||||
meta_fields["classification.details"] = str(meta_fields["classification"]["details"])
|
||||
del meta_fields["classification"]
|
||||
return meta_fields
|
||||
|
@ -106,24 +106,6 @@ posthog.disabled = True
|
||||
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
|
||||
|
||||
|
||||
def _sql_session_rollback(self, attr):
|
||||
"""
|
||||
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
|
||||
errors where an intended operation is still in a transaction, but not committed to the database.
|
||||
"""
|
||||
method = object.__getattribute__(self, attr)
|
||||
if callable(method):
|
||||
try:
|
||||
self.session.rollback()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return method
|
||||
|
||||
|
||||
SQLDocumentStore.__getattribute__ = _sql_session_rollback
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# add pytest markers for tests that are not explicitly marked but include some keywords
|
||||
name_to_markers = {
|
||||
|
@ -99,7 +99,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS
|
||||
|
||||
|
||||
def test_get_all_documents_without_filters(document_store_with_docs):
|
||||
print("hey!")
|
||||
documents = document_store_with_docs.get_all_documents()
|
||||
assert all(isinstance(d, Document) for d in documents)
|
||||
assert len(documents) == 5
|
||||
|
@ -62,6 +62,42 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
with pytest.raises(Exception, match=r"(?i)unique"):
|
||||
ds.write_documents([doc2], index="index3")
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_sql_get_documents_using_nested_filters_about_classification(self, ds):
|
||||
documents = [
|
||||
Document(
|
||||
content="That's good. I like it.",
|
||||
id="1",
|
||||
meta={
|
||||
"classification": {
|
||||
"label": "LABEL_1",
|
||||
"score": 0.694,
|
||||
"details": {"LABEL_1": 0.694, "LABEL_0": 0.306},
|
||||
}
|
||||
},
|
||||
),
|
||||
Document(
|
||||
content="That's bad. I don't like it.",
|
||||
id="2",
|
||||
meta={
|
||||
"classification": {
|
||||
"label": "LABEL_0",
|
||||
"score": 0.898,
|
||||
"details": {"LABEL_0": 0.898, "LABEL_1": 0.102},
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
ds.write_documents(documents)
|
||||
|
||||
assert ds.get_document_count() == 2
|
||||
assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.1}})) == 2
|
||||
assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1", "LABEL_0"]})) == 2
|
||||
assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.8}})) == 1
|
||||
assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1"]})) == 1
|
||||
assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0
|
||||
assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0
|
||||
|
||||
# NOTE: the SQLDocumentStore behaves differently to the others when filters are applied.
|
||||
# While this should be considered a bug, the relative tests are skipped in the meantime
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user