Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-07 13:24:16 +00:00)
feat: store metadata using JSON in SQLDocumentStore (#3547)
* add warnings
* make the field cacheable
* address review comments
This commit is contained in:
parent 1399681c81
commit ea75e2aab5
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union, List, Optional, Generator

 import logging
 import itertools
+import json
 from uuid import uuid4

 import numpy as np
@@ -20,9 +21,10 @@ try:
         JSON,
         ForeignKeyConstraint,
         UniqueConstraint,
+        TypeDecorator,
     )
     from sqlalchemy.ext.declarative import declarative_base
-    from sqlalchemy.orm import relationship, sessionmaker, validates
+    from sqlalchemy.orm import relationship, sessionmaker
     from sqlalchemy.sql import case, null
 except (ImportError, ModuleNotFoundError) as ie:
     from haystack.utils.import_utils import _optional_component_not_installed
@@ -38,6 +40,20 @@ logger = logging.getLogger(__name__)
 Base = declarative_base()  # type: Any


+class ArrayType(TypeDecorator):
+
+    impl = String
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        return json.dumps(value)
+
+    def process_result_value(self, value, dialect):
+        if value is not None:
+            return json.loads(value)
+        return value
+
+
 class ORMBase(Base):
     __abstract__ = True

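The new ArrayType decorator above is what lets metadata values of any JSON-serializable type live in a plain String column: process_bind_param serializes on the way into the database, process_result_value parses on the way out. As a minimal, self-contained sketch of the same pattern (the standalone table, column, and class names below are illustrative assumptions, not part of this commit):

import json

from sqlalchemy import Column, Integer, String, TypeDecorator, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class JSONText(TypeDecorator):
    """Store arbitrary JSON-serializable values in a plain String column."""

    impl = String
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Called on the way into the database: serialize to a JSON string.
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        # Called on the way out: parse the JSON string back into Python objects.
        if value is not None:
            return json.loads(value)
        return value


class Meta(Base):
    __tablename__ = "meta_example"  # illustrative table, not from the PR
    id = Column(Integer, primary_key=True)
    value = Column(JSONText(100))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with sessionmaker(bind=engine)() as session:
    session.add(Meta(value=[1, 2, 3]))  # a compound value now survives the round trip
    session.commit()
    assert session.query(Meta).first().value == [1, 2, 3]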
@@ -64,7 +80,7 @@ class MetaDocumentORM(ORMBase):
     __tablename__ = "meta_document"

     name = Column(String(100), index=True)
-    value = Column(String(1000), index=True)
+    value = Column(ArrayType(100), index=True)
     documents = relationship("DocumentORM", back_populates="meta")

     document_id = Column(String(100), nullable=False, index=True)
@@ -76,17 +92,6 @@ class MetaDocumentORM(ORMBase):
         {},
     )  # type: ignore

-    valid_metadata_types = (str, int, float, bool, bytes, bytearray, type(None))
-
-    @validates("value")
-    def validate_value(self, key, value):
-        if not isinstance(value, self.valid_metadata_types):
-            raise TypeError(
-                f"Discarded metadata '{self.name}', since it has invalid type: {type(value).__name__}.\n"
-                f"SQLDocumentStore can accept and cast to string only the following types: {', '.join([el.__name__ for el in self.valid_metadata_types])}"
-            )
-        return value
-

 class LabelORM(ORMBase):
     __tablename__ = "label"
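With the value column switched to ArrayType and the old @validates type check removed, compound metadata such as lists no longer raises a TypeError on write and comes back as the original Python object on read. A hedged usage sketch (the sqlite:// URL, field names, and document content are assumptions for illustration, not taken from this commit):

from haystack.document_stores.sql import SQLDocumentStore
from haystack.schema import Document

# In-memory SQLite store; any SQLAlchemy URL accepted by SQLDocumentStore would do.
store = SQLDocumentStore(url="sqlite://")
store.write_documents(
    [Document(content="some text", meta={"tags": ["a", "b"], "year": 2020}, id="1")]
)

doc = store.get_document_by_id("1")
print(doc.meta)  # expected (key order may vary): {'tags': ['a', 'b'], 'year': 2020}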
@@ -298,6 +303,7 @@ class SQLDocumentStore(BaseDocumentStore):
         ).filter_by(index=index)

         if filters:
+            logger.warning("filters won't work on metadata fields containing compound data types")
             parsed_filter = LogicalFilterClause.parse(filters)
             select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
             documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
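Because every metadata value is now stored as a JSON string, the SQL comparisons generated from filter clauses operate on the serialized text, so filtering on compound fields is not reliable; the warning added above fires on every filtered query. A small sketch of what callers will see (the store URL and filter are illustrative assumptions, mirroring the test added further down):

import logging

from haystack.document_stores.sql import SQLDocumentStore

logging.basicConfig(level=logging.WARNING)

store = SQLDocumentStore(url="sqlite://")
# Any filtered query now logs the caveat before the filter is converted to SQL.
store.get_all_documents(filters={"year": {"$ne": "2020"}})
# Expected log record (exact formatting depends on the logging configuration):
#   filters won't work on metadata fields containing compound data types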
@@ -402,12 +408,7 @@ class SQLDocumentStore(BaseDocumentStore):
             if "classification" in meta_fields:
                 meta_fields = self._flatten_classification_meta_fields(meta_fields)
             vector_id = meta_fields.pop("vector_id", None)
-            meta_orms = []
-            for key, value in meta_fields.items():
-                try:
-                    meta_orms.append(MetaDocumentORM(name=key, value=value))
-                except TypeError as ex:
-                    logger.error("Document %s - %s", doc.id, ex)
+            meta_orms = [MetaDocumentORM(name=key, value=value) for key, value in meta_fields.items()]
             doc_orm = DocumentORM(
                 id=doc.id,
                 content=doc.to_dict()["content"],
(The remaining hunks are from the test module for SQLDocumentStore.)

@@ -1,3 +1,5 @@
+import logging
+
 import pytest

 from haystack.document_stores.sql import SQLDocumentStore
@@ -24,28 +26,6 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
         ds.delete_index(index="custom_index")
         assert ds.get_document_count(index="custom_index") == 0

-    @pytest.mark.integration
-    def test_sql_write_document_invalid_meta(self, ds):
-        documents = [
-            {
-                "content": "dict_with_invalid_meta",
-                "valid_meta_field": "test1",
-                "invalid_meta_field": [1, 2, 3],
-                "name": "filename1",
-                "id": "1",
-            },
-            Document(
-                content="document_object_with_invalid_meta",
-                meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"},
-                id="2",
-            ),
-        ]
-        ds.write_documents(documents)
-        documents_in_store = ds.get_all_documents()
-        assert len(documents_in_store) == 2
-        assert ds.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"}
-        assert ds.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}
-
     @pytest.mark.integration
     def test_sql_write_different_documents_same_vector_id(self, ds):
         doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
@@ -98,13 +78,15 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
         assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0
         assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0

-    # NOTE: the SQLDocumentStore behaves differently to the others when filters are applied.
-    # While this should be considered a bug, the relative tests are skipped in the meantime
+    # NOTE: the SQLDocumentStore marshals metadata values with JSON so querying
+    # using filters doesn't always work. While this should be considered a bug,
+    # the relative tests are either customized or skipped while we work on a fix.

-    @pytest.mark.skip
     @pytest.mark.integration
-    def test_ne_filters(self, ds, documents):
-        pass
+    def test_ne_filters(self, ds, caplog):
+        with caplog.at_level(logging.WARNING):
+            ds.get_all_documents(filters={"year": {"$ne": "2020"}})
+            assert "filters won't work on metadata fields" in caplog.text

     @pytest.mark.skip
     @pytest.mark.integration
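The rewritten test_ne_filters leans on pytest's built-in caplog fixture, which captures records emitted through the standard logging module so the test can assert on the warning text instead of on filter results. A generic, standalone illustration of the pattern (the helper function below is made up for the example):

import logging

logger = logging.getLogger(__name__)


def do_filtered_query():
    # Stand-in for a call that emits a warning, like a filtered SQLDocumentStore query.
    logger.warning("filters won't work on metadata fields containing compound data types")


def test_warning_is_logged(caplog):
    with caplog.at_level(logging.WARNING):
        do_filtered_query()
    assert "filters won't work on metadata fields" in caplog.text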