mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-03 19:29:32 +00:00 
			
		
		
		
	feat: store metadata using JSON in SQLDocumentStore (#3547)
* add warnings * make the field cachable * review comment
This commit is contained in:
		
							parent
							
								
									1399681c81
								
							
						
					
					
						commit
						ea75e2aab5
					
				@ -2,6 +2,7 @@ from typing import Any, Dict, Union, List, Optional, Generator
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import itertools
 | 
			
		||||
import json
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
@ -20,9 +21,10 @@ try:
 | 
			
		||||
        JSON,
 | 
			
		||||
        ForeignKeyConstraint,
 | 
			
		||||
        UniqueConstraint,
 | 
			
		||||
        TypeDecorator,
 | 
			
		||||
    )
 | 
			
		||||
    from sqlalchemy.ext.declarative import declarative_base
 | 
			
		||||
    from sqlalchemy.orm import relationship, sessionmaker, validates
 | 
			
		||||
    from sqlalchemy.orm import relationship, sessionmaker
 | 
			
		||||
    from sqlalchemy.sql import case, null
 | 
			
		||||
except (ImportError, ModuleNotFoundError) as ie:
 | 
			
		||||
    from haystack.utils.import_utils import _optional_component_not_installed
 | 
			
		||||
@ -38,6 +40,20 @@ logger = logging.getLogger(__name__)
 | 
			
		||||
Base = declarative_base()  # type: Any
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ArrayType(TypeDecorator):
 | 
			
		||||
 | 
			
		||||
    impl = String
 | 
			
		||||
    cache_ok = True
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        return json.dumps(value)
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            return json.loads(value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ORMBase(Base):
 | 
			
		||||
    __abstract__ = True
 | 
			
		||||
 | 
			
		||||
@ -64,7 +80,7 @@ class MetaDocumentORM(ORMBase):
 | 
			
		||||
    __tablename__ = "meta_document"
 | 
			
		||||
 | 
			
		||||
    name = Column(String(100), index=True)
 | 
			
		||||
    value = Column(String(1000), index=True)
 | 
			
		||||
    value = Column(ArrayType(100), index=True)
 | 
			
		||||
    documents = relationship("DocumentORM", back_populates="meta")
 | 
			
		||||
 | 
			
		||||
    document_id = Column(String(100), nullable=False, index=True)
 | 
			
		||||
@ -76,17 +92,6 @@ class MetaDocumentORM(ORMBase):
 | 
			
		||||
        {},
 | 
			
		||||
    )  # type: ignore
 | 
			
		||||
 | 
			
		||||
    valid_metadata_types = (str, int, float, bool, bytes, bytearray, type(None))
 | 
			
		||||
 | 
			
		||||
    @validates("value")
 | 
			
		||||
    def validate_value(self, key, value):
 | 
			
		||||
        if not isinstance(value, self.valid_metadata_types):
 | 
			
		||||
            raise TypeError(
 | 
			
		||||
                f"Discarded metadata '{self.name}', since it has invalid type: {type(value).__name__}.\n"
 | 
			
		||||
                f"SQLDocumentStore can accept and cast to string only the following types: {', '.join([el.__name__ for el in self.valid_metadata_types])}"
 | 
			
		||||
            )
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LabelORM(ORMBase):
 | 
			
		||||
    __tablename__ = "label"
 | 
			
		||||
@ -298,6 +303,7 @@ class SQLDocumentStore(BaseDocumentStore):
 | 
			
		||||
        ).filter_by(index=index)
 | 
			
		||||
 | 
			
		||||
        if filters:
 | 
			
		||||
            logger.warning("filters won't work on metadata fields containing compound data types")
 | 
			
		||||
            parsed_filter = LogicalFilterClause.parse(filters)
 | 
			
		||||
            select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
 | 
			
		||||
            documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
 | 
			
		||||
@ -402,12 +408,7 @@ class SQLDocumentStore(BaseDocumentStore):
 | 
			
		||||
                if "classification" in meta_fields:
 | 
			
		||||
                    meta_fields = self._flatten_classification_meta_fields(meta_fields)
 | 
			
		||||
                vector_id = meta_fields.pop("vector_id", None)
 | 
			
		||||
                meta_orms = []
 | 
			
		||||
                for key, value in meta_fields.items():
 | 
			
		||||
                    try:
 | 
			
		||||
                        meta_orms.append(MetaDocumentORM(name=key, value=value))
 | 
			
		||||
                    except TypeError as ex:
 | 
			
		||||
                        logger.error("Document %s - %s", doc.id, ex)
 | 
			
		||||
                meta_orms = [MetaDocumentORM(name=key, value=value) for key, value in meta_fields.items()]
 | 
			
		||||
                doc_orm = DocumentORM(
 | 
			
		||||
                    id=doc.id,
 | 
			
		||||
                    content=doc.to_dict()["content"],
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,5 @@
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from haystack.document_stores.sql import SQLDocumentStore
 | 
			
		||||
@ -24,28 +26,6 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
 | 
			
		||||
        ds.delete_index(index="custom_index")
 | 
			
		||||
        assert ds.get_document_count(index="custom_index") == 0
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.integration
 | 
			
		||||
    def test_sql_write_document_invalid_meta(self, ds):
 | 
			
		||||
        documents = [
 | 
			
		||||
            {
 | 
			
		||||
                "content": "dict_with_invalid_meta",
 | 
			
		||||
                "valid_meta_field": "test1",
 | 
			
		||||
                "invalid_meta_field": [1, 2, 3],
 | 
			
		||||
                "name": "filename1",
 | 
			
		||||
                "id": "1",
 | 
			
		||||
            },
 | 
			
		||||
            Document(
 | 
			
		||||
                content="document_object_with_invalid_meta",
 | 
			
		||||
                meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"},
 | 
			
		||||
                id="2",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
        ds.write_documents(documents)
 | 
			
		||||
        documents_in_store = ds.get_all_documents()
 | 
			
		||||
        assert len(documents_in_store) == 2
 | 
			
		||||
        assert ds.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"}
 | 
			
		||||
        assert ds.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.integration
 | 
			
		||||
    def test_sql_write_different_documents_same_vector_id(self, ds):
 | 
			
		||||
        doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
 | 
			
		||||
@ -98,13 +78,15 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
 | 
			
		||||
        assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0
 | 
			
		||||
        assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0
 | 
			
		||||
 | 
			
		||||
    # NOTE: the SQLDocumentStore behaves differently to the others when filters are applied.
 | 
			
		||||
    # While this should be considered a bug, the relative tests are skipped in the meantime
 | 
			
		||||
    # NOTE: the SQLDocumentStore marshals metadata values with JSON so querying
 | 
			
		||||
    # using filters doesn't always work. While this should be considered a bug,
 | 
			
		||||
    # the relative tests are either customized or skipped while we work on a fix.
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.skip
 | 
			
		||||
    @pytest.mark.integration
 | 
			
		||||
    def test_ne_filters(self, ds, documents):
 | 
			
		||||
        pass
 | 
			
		||||
    def test_ne_filters(self, ds, caplog):
 | 
			
		||||
        with caplog.at_level(logging.WARNING):
 | 
			
		||||
            ds.get_all_documents(filters={"year": {"$ne": "2020"}})
 | 
			
		||||
            assert "filters won't work on metadata fields" in caplog.text
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.skip
 | 
			
		||||
    @pytest.mark.integration
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user