mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-19 11:58:44 +00:00
fix: Flatten DocumentClassifier
output in SQLDocumentStore
; remove _sql_session_rollback
hack in tests (#3273)
* first draft * fix * fix * move test to test_sql
This commit is contained in:
parent
af78f8b431
commit
dc26e6d43e
@ -399,6 +399,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
docs_orm = []
|
docs_orm = []
|
||||||
for doc in document_objects[i : i + batch_size]:
|
for doc in document_objects[i : i + batch_size]:
|
||||||
meta_fields = doc.meta or {}
|
meta_fields = doc.meta or {}
|
||||||
|
if "classification" in meta_fields:
|
||||||
|
meta_fields = self._flatten_classification_meta_fields(meta_fields)
|
||||||
vector_id = meta_fields.pop("vector_id", None)
|
vector_id = meta_fields.pop("vector_id", None)
|
||||||
meta_orms = []
|
meta_orms = []
|
||||||
for key, value in meta_fields.items():
|
for key, value in meta_fields.items():
|
||||||
@ -785,3 +787,14 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
for whereclause in self._column_windows(q.session, column, windowsize):
|
for whereclause in self._column_windows(q.session, column, windowsize):
|
||||||
for row in q.filter(whereclause).order_by(column):
|
for row in q.filter(whereclause).order_by(column):
|
||||||
yield row
|
yield row
|
||||||
|
|
||||||
|
def _flatten_classification_meta_fields(self, meta_fields: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Since SQLDocumentStore does not support dictionaries for metadata values,
|
||||||
|
the DocumentClassifier output is flattened
|
||||||
|
"""
|
||||||
|
meta_fields["classification.label"] = meta_fields["classification"]["label"]
|
||||||
|
meta_fields["classification.score"] = meta_fields["classification"]["score"]
|
||||||
|
meta_fields["classification.details"] = str(meta_fields["classification"]["details"])
|
||||||
|
del meta_fields["classification"]
|
||||||
|
return meta_fields
|
||||||
|
@ -106,24 +106,6 @@ posthog.disabled = True
|
|||||||
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
|
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})
|
||||||
|
|
||||||
|
|
||||||
def _sql_session_rollback(self, attr):
    """
    Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
    errors where an intended operation is still in a transaction, but not committed to the database.

    :param attr: name of the attribute being accessed on the store instance.
    :return: the looked-up attribute, unchanged.
    """
    # Look up the attribute through object.__getattribute__ so this patched
    # __getattribute__ does not recurse into itself.
    method = object.__getattribute__(self, attr)
    if callable(method):
        try:
            # Roll back any transaction still open from a previous operation,
            # so uncommitted work surfaces as a test failure.
            self.session.rollback()
        except AttributeError:
            # NOTE(review): presumably the store may not have a `session` yet
            # (e.g. mid-construction) — the rollback is best-effort by design.
            pass

    return method


# Test-only hack: monkey-patch every attribute access on SQLDocumentStore so
# the rollback above runs before each method call (removed by commit #3273).
SQLDocumentStore.__getattribute__ = _sql_session_rollback
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(config, items):
|
def pytest_collection_modifyitems(config, items):
|
||||||
# add pytest markers for tests that are not explicitly marked but include some keywords
|
# add pytest markers for tests that are not explicitly marked but include some keywords
|
||||||
name_to_markers = {
|
name_to_markers = {
|
||||||
|
@ -99,7 +99,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS
|
|||||||
|
|
||||||
|
|
||||||
def test_get_all_documents_without_filters(document_store_with_docs):
    """Fetching without filters returns every seeded document as a Document instance."""
    retrieved = document_store_with_docs.get_all_documents()
    assert all(isinstance(item, Document) for item in retrieved)
    assert len(retrieved) == 5
|
||||||
|
@ -62,6 +62,42 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
|
|||||||
with pytest.raises(Exception, match=r"(?i)unique"):
|
with pytest.raises(Exception, match=r"(?i)unique"):
|
||||||
ds.write_documents([doc2], index="index3")
|
ds.write_documents([doc2], index="index3")
|
||||||
|
|
||||||
|
@pytest.mark.integration
def test_sql_get_documents_using_nested_filters_about_classification(self, ds):
    """
    Nested DocumentClassifier output must be filterable via dotted keys
    (classification.label / classification.score) after the store flattens
    it on write.
    """
    fixtures = [
        ("1", "That's good. I like it.", "LABEL_1", 0.694, {"LABEL_1": 0.694, "LABEL_0": 0.306}),
        ("2", "That's bad. I don't like it.", "LABEL_0", 0.898, {"LABEL_0": 0.898, "LABEL_1": 0.102}),
    ]
    documents = [
        Document(
            content=content,
            id=doc_id,
            meta={"classification": {"label": label, "score": score, "details": details}},
        )
        for doc_id, content, label, score, details in fixtures
    ]
    ds.write_documents(documents)

    assert ds.get_document_count() == 2
    assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.1}})) == 2
    assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1", "LABEL_0"]})) == 2
    assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.8}})) == 1
    assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_1"]})) == 1
    assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0
    assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0
|
||||||
|
|
||||||
# NOTE: the SQLDocumentStore behaves differently to the others when filters are applied.
|
# NOTE: the SQLDocumentStore behaves differently to the others when filters are applied.
|
||||||
# While this should be considered a bug, the relative tests are skipped in the meantime
|
# While this should be considered a bug, the relative tests are skipped in the meantime
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user