mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
Extend meta data support for SQLDocumentStore (#2199)
* update remaining occurences of get_connection * first commit to add extended metadata filtering support to sql * fix bugs * adding sql doc store instead of milvus * removing updates to milvus2 from other PR * fixing not operator * delete left over line * remove unnecessary import * Update Documentation & Code Style * fix circular import * fix left over merge conflict * Update Documentation & Code Style * fix abstract class Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
8de1aa3e43
commit
116fe2db26
@ -1,6 +1,10 @@
|
||||
from typing import Union, List, Dict, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from haystack.document_stores.utils import convert_date_to_rfc3339
|
||||
|
||||
@ -125,6 +129,12 @@ class LogicalFilterClause(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
"""
|
||||
Converts the LogicalFilterClause instance to an SQL filter.
|
||||
"""
|
||||
pass
|
||||
|
||||
def convert_to_weaviate(self):
|
||||
"""
|
||||
Converts the LogicalFilterClause instance to a Weaviate filter.
|
||||
@ -205,6 +215,13 @@ class ComparisonOperation(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
"""
|
||||
Converts the ComparisonOperation instance to an SQL filter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_to_weaviate(self):
|
||||
"""
|
||||
@ -266,6 +283,13 @@ class NotOperation(LogicalFilterClause):
|
||||
conditions = self._merge_es_range_queries(conditions)
|
||||
return {"bool": {"must_not": conditions}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
conditions = [
|
||||
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
|
||||
for condition in self.conditions
|
||||
]
|
||||
return select(meta_document_orm.document_id).filter(~or_(*conditions))
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]:
|
||||
conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions]
|
||||
if len(conditions) > 1:
|
||||
@ -295,6 +319,13 @@ class AndOperation(LogicalFilterClause):
|
||||
conditions = self._merge_es_range_queries(conditions)
|
||||
return {"bool": {"must": conditions}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
conditions = [
|
||||
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
|
||||
for condition in self.conditions
|
||||
]
|
||||
return select(meta_document_orm.document_id).filter(and_(*conditions))
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
||||
return {"operator": "And", "operands": conditions}
|
||||
@ -313,6 +344,13 @@ class OrOperation(LogicalFilterClause):
|
||||
conditions = self._merge_es_range_queries(conditions)
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
conditions = [
|
||||
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
|
||||
for condition in self.conditions
|
||||
]
|
||||
return select(meta_document_orm.document_id).filter(or_(*conditions))
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
||||
return {"operator": "Or", "operands": conditions}
|
||||
@ -330,6 +368,11 @@ class EqOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values."
|
||||
return {"term": {self.field_name: self.comparison_value}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value}
|
||||
@ -347,6 +390,11 @@ class InOperation(ComparisonOperation):
|
||||
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
||||
return {"terms": {self.field_name: self.comparison_value}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []}
|
||||
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
||||
@ -372,6 +420,11 @@ class NeOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values."
|
||||
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value}
|
||||
@ -389,6 +442,11 @@ class NinOperation(ComparisonOperation):
|
||||
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
||||
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []}
|
||||
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
||||
@ -414,6 +472,11 @@ class GtOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list."
|
||||
return {"range": {self.field_name: {"gt": self.comparison_value}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list."
|
||||
@ -432,6 +495,11 @@ class GteOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list."
|
||||
return {"range": {self.field_name: {"gte": self.comparison_value}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list."
|
||||
@ -450,6 +518,11 @@ class LtOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list."
|
||||
return {"range": {self.field_name: {"lt": self.comparison_value}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list."
|
||||
@ -468,6 +541,11 @@ class LteOperation(ComparisonOperation):
|
||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list."
|
||||
return {"range": {self.field_name: {"lte": self.comparison_value}}}
|
||||
|
||||
def convert_to_sql(self, meta_document_orm):
|
||||
return select([meta_document_orm.document_id]).where(
|
||||
meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
|
||||
)
|
||||
|
||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||
assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list."
|
||||
|
||||
@ -31,6 +31,8 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
from haystack.schema import Document, Label, Answer
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
|
||||
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Base = declarative_base() # type: Any
|
||||
@ -294,11 +296,9 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
).filter_by(index=index)
|
||||
|
||||
if filters:
|
||||
for key, values in filters.items():
|
||||
documents_query = documents_query.join(MetaDocumentORM, aliased=True).filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
)
|
||||
parsed_filter = LogicalFilterClause.parse(filters)
|
||||
select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
|
||||
documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
|
||||
|
||||
if only_documents_without_embedding:
|
||||
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
|
||||
|
||||
@ -216,7 +216,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)
|
||||
assert len(documents) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate"], indirect=True)
|
||||
def test_extended_filter(document_store_with_docs):
|
||||
# Test comparison operators individually
|
||||
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user