feat: store metadata using JSON in SQLDocumentStore (#3547)

* add warnings

* make the field cachable

* review comment
This commit is contained in:
Massimiliano Pippi 2022-11-18 08:26:19 +01:00 committed by GitHub
parent 1399681c81
commit ea75e2aab5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 46 deletions

View File

@ -2,6 +2,7 @@ from typing import Any, Dict, Union, List, Optional, Generator
import logging import logging
import itertools import itertools
import json
from uuid import uuid4 from uuid import uuid4
import numpy as np import numpy as np
@ -20,9 +21,10 @@ try:
JSON, JSON,
ForeignKeyConstraint, ForeignKeyConstraint,
UniqueConstraint, UniqueConstraint,
TypeDecorator,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker, validates from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case, null from sqlalchemy.sql import case, null
except (ImportError, ModuleNotFoundError) as ie: except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed from haystack.utils.import_utils import _optional_component_not_installed
@ -38,6 +40,20 @@ logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any Base = declarative_base() # type: Any
class ArrayType(TypeDecorator):
    """Column type that marshals arbitrary Python values to/from JSON text.

    Values are stored in a plain ``String`` column as their ``json.dumps``
    representation and decoded back with ``json.loads`` when read, so
    compound metadata (lists, dicts) survives a round trip through SQL.
    """

    impl = String
    # The type's behavior depends only on its configuration, so SQLAlchemy
    # may safely cache generated statements that use it.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Encode the outgoing Python value as a JSON string."""
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        """Decode a JSON string coming from the DB; SQL NULL passes through."""
        return value if value is None else json.loads(value)
class ORMBase(Base): class ORMBase(Base):
__abstract__ = True __abstract__ = True
@ -64,7 +80,7 @@ class MetaDocumentORM(ORMBase):
__tablename__ = "meta_document" __tablename__ = "meta_document"
name = Column(String(100), index=True) name = Column(String(100), index=True)
value = Column(String(1000), index=True) value = Column(ArrayType(100), index=True)
documents = relationship("DocumentORM", back_populates="meta") documents = relationship("DocumentORM", back_populates="meta")
document_id = Column(String(100), nullable=False, index=True) document_id = Column(String(100), nullable=False, index=True)
@ -76,17 +92,6 @@ class MetaDocumentORM(ORMBase):
{}, {},
) # type: ignore ) # type: ignore
# Metadata values must be scalar-like: anything else cannot be cast to a
# string column without losing information.
valid_metadata_types = (str, int, float, bool, bytes, bytearray, type(None))

@validates("value")
def validate_value(self, key, value):
    """Reject metadata values whose type cannot be stored as a string."""
    if isinstance(value, self.valid_metadata_types):
        return value
    raise TypeError(
        f"Discarded metadata '{self.name}', since it has invalid type: {type(value).__name__}.\n"
        f"SQLDocumentStore can accept and cast to string only the following types: {', '.join([el.__name__ for el in self.valid_metadata_types])}"
    )
class LabelORM(ORMBase): class LabelORM(ORMBase):
__tablename__ = "label" __tablename__ = "label"
@ -298,6 +303,7 @@ class SQLDocumentStore(BaseDocumentStore):
).filter_by(index=index) ).filter_by(index=index)
if filters: if filters:
logger.warning("filters won't work on metadata fields containing compound data types")
parsed_filter = LogicalFilterClause.parse(filters) parsed_filter = LogicalFilterClause.parse(filters)
select_ids = parsed_filter.convert_to_sql(MetaDocumentORM) select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
documents_query = documents_query.filter(DocumentORM.id.in_(select_ids)) documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
@ -402,12 +408,7 @@ class SQLDocumentStore(BaseDocumentStore):
if "classification" in meta_fields: if "classification" in meta_fields:
meta_fields = self._flatten_classification_meta_fields(meta_fields) meta_fields = self._flatten_classification_meta_fields(meta_fields)
vector_id = meta_fields.pop("vector_id", None) vector_id = meta_fields.pop("vector_id", None)
meta_orms = [] meta_orms = [MetaDocumentORM(name=key, value=value) for key, value in meta_fields.items()]
for key, value in meta_fields.items():
try:
meta_orms.append(MetaDocumentORM(name=key, value=value))
except TypeError as ex:
logger.error("Document %s - %s", doc.id, ex)
doc_orm = DocumentORM( doc_orm = DocumentORM(
id=doc.id, id=doc.id,
content=doc.to_dict()["content"], content=doc.to_dict()["content"],

View File

@ -1,3 +1,5 @@
import logging
import pytest import pytest
from haystack.document_stores.sql import SQLDocumentStore from haystack.document_stores.sql import SQLDocumentStore
@ -24,28 +26,6 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
ds.delete_index(index="custom_index") ds.delete_index(index="custom_index")
assert ds.get_document_count(index="custom_index") == 0 assert ds.get_document_count(index="custom_index") == 0
@pytest.mark.integration
def test_sql_write_document_invalid_meta(self, ds):
    """Documents with non-castable metadata are stored, but the offending
    meta fields are dropped rather than written to the store."""
    docs = [
        # plain-dict document carrying one valid and one invalid meta field
        {
            "content": "dict_with_invalid_meta",
            "valid_meta_field": "test1",
            "invalid_meta_field": [1, 2, 3],
            "name": "filename1",
            "id": "1",
        },
        # Document object carrying the same mix of meta fields
        Document(
            content="document_object_with_invalid_meta",
            meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"},
            id="2",
        ),
    ]
    ds.write_documents(docs)
    stored = ds.get_all_documents()
    # both documents land in the store, minus their invalid meta fields
    assert len(stored) == 2
    assert ds.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"}
    assert ds.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}
@pytest.mark.integration @pytest.mark.integration
def test_sql_write_different_documents_same_vector_id(self, ds): def test_sql_write_different_documents_same_vector_id(self, ds):
doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"} doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
@ -98,13 +78,15 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0 assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0
assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0 assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0
# NOTE: the SQLDocumentStore behaves differently to the others when filters are applied. # NOTE: the SQLDocumentStore marshals metadata values with JSON so querying
# While this should be considered a bug, the relative tests are skipped in the meantime # using filters doesn't always work. While this should be considered a bug,
# the relative tests are either customized or skipped while we work on a fix.
@pytest.mark.skip
@pytest.mark.integration @pytest.mark.integration
def test_ne_filters(self, ds, documents): def test_ne_filters(self, ds, caplog):
pass with caplog.at_level(logging.WARNING):
ds.get_all_documents(filters={"year": {"$ne": "2020"}})
assert "filters won't work on metadata fields" in caplog.text
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.integration @pytest.mark.integration