diff --git a/haystack/document_stores/filter_utils.py b/haystack/document_stores/filter_utils.py index 6b97a0f6b..8267c2940 100644 --- a/haystack/document_stores/filter_utils.py +++ b/haystack/document_stores/filter_utils.py @@ -1,6 +1,10 @@ from typing import Union, List, Dict, Optional, Tuple from abc import ABC, abstractmethod from collections import defaultdict +from functools import reduce + +from sqlalchemy.sql import select +from sqlalchemy import and_, or_ from haystack.document_stores.utils import convert_date_to_rfc3339 @@ -125,6 +129,12 @@ class LogicalFilterClause(ABC): pass @abstractmethod + def convert_to_sql(self, meta_document_orm): + """ + Converts the LogicalFilterClause instance to an SQL filter. + """ + pass + def convert_to_weaviate(self): """ Converts the LogicalFilterClause instance to a Weaviate filter. @@ -205,6 +215,13 @@ class ComparisonOperation(ABC): """ pass + @abstractmethod + def convert_to_sql(self, meta_document_orm): + """ + Converts the ComparisonOperation instance to an SQL filter. + """ + pass + @abstractmethod def convert_to_weaviate(self): """ @@ -266,6 +283,13 @@ class NotOperation(LogicalFilterClause): conditions = self._merge_es_range_queries(conditions) return {"bool": {"must_not": conditions}} + def convert_to_sql(self, meta_document_orm): + conditions = [ + meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) + for condition in self.conditions + ] + return select(meta_document_orm.document_id).filter(~or_(*conditions)) + def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]: conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions] if len(conditions) > 1: @@ -295,6 +319,13 @@ class AndOperation(LogicalFilterClause): conditions = self._merge_es_range_queries(conditions) return {"bool": {"must": conditions}} + def convert_to_sql(self, meta_document_orm): + conditions = [ + meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) + for condition in self.conditions + ] + return select(meta_document_orm.document_id).filter(and_(*conditions)) + def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: conditions = [condition.convert_to_weaviate() for condition in self.conditions] return {"operator": "And", "operands": conditions} @@ -313,6 +344,13 @@ class OrOperation(LogicalFilterClause): conditions = self._merge_es_range_queries(conditions) return {"bool": {"should": conditions}} + def convert_to_sql(self, meta_document_orm): + conditions = [ + meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) + for condition in self.conditions + ] + return select(meta_document_orm.document_id).filter(or_(*conditions)) + def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: conditions = [condition.convert_to_weaviate() for condition in self.conditions] return {"operator": "Or", "operands": conditions} @@ -330,6 +368,11 @@ class EqOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values." return {"term": {self.field_name: self.comparison_value}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]: comp_value_type, comp_value = self._get_weaviate_datatype() return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value} @@ -347,6 +390,11 @@ class InOperation(ComparisonOperation): assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list." return {"terms": {self.field_name: self.comparison_value}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value) + ) + def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []} assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list." @@ -372,6 +420,11 @@ class NeOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values." return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]: comp_value_type, comp_value = self._get_weaviate_datatype() return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value} @@ -389,6 +442,11 @@ class NinOperation(ComparisonOperation): assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list." return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value) + ) + def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []} assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list." @@ -414,6 +472,11 @@ class GtOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list." return {"range": {self.field_name: {"gt": self.comparison_value}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: comp_value_type, comp_value = self._get_weaviate_datatype() assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list." @@ -432,6 +495,11 @@ class GteOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list." return {"range": {self.field_name: {"gte": self.comparison_value}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: comp_value_type, comp_value = self._get_weaviate_datatype() assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list." @@ -450,6 +518,11 @@ class LtOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list." return {"range": {self.field_name: {"lt": self.comparison_value}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: comp_value_type, comp_value = self._get_weaviate_datatype() assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list." @@ -468,6 +541,11 @@ class LteOperation(ComparisonOperation): assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list." return {"range": {self.field_name: {"lte": self.comparison_value}}} + def convert_to_sql(self, meta_document_orm): + return select([meta_document_orm.document_id]).where( + meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value + ) + def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: comp_value_type, comp_value = self._get_weaviate_datatype() assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list." diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 39fa07a85..41177f592 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -31,6 +31,8 @@ except (ImportError, ModuleNotFoundError) as ie: from haystack.schema import Document, Label, Answer from haystack.document_stores.base import BaseDocumentStore +from haystack.document_stores.filter_utils import LogicalFilterClause + logger = logging.getLogger(__name__) Base = declarative_base() # type: Any @@ -294,11 +296,9 @@ class SQLDocumentStore(BaseDocumentStore): ).filter_by(index=index) if filters: - for key, values in filters.items(): - documents_query = documents_query.join(MetaDocumentORM, aliased=True).filter( - MetaDocumentORM.name == key, - MetaDocumentORM.value.in_(values), - ) + parsed_filter = LogicalFilterClause.parse(filters) + select_ids = parsed_filter.convert_to_sql(MetaDocumentORM) + documents_query = documents_query.filter(DocumentORM.id.in_(select_ids)) if only_documents_without_embedding: documents_query = documents_query.filter(DocumentORM.vector_id.is_(None)) diff --git a/test/test_document_store.py b/test/test_document_store.py index 1d817d745..b1163b539 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -216,7 +216,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs) assert len(documents) == 0 -@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate"], indirect=True) +@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate"], indirect=True) def test_extended_filter(document_store_with_docs): # Test comparison operators individually documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})