mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-11 07:13:45 +00:00
Extend metadata support for SQLDocumentStore (#2199)
* update remaining occurrences of get_connection * first commit to add extended metadata filtering support to sql * fix bugs * adding sql doc store instead of milvus * removing updates to milvus2 from other PR * fixing not operator * delete leftover line * remove unnecessary import * Update Documentation & Code Style * fix circular import * fix leftover merge conflict * Update Documentation & Code Style * fix abstract class Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
8de1aa3e43
commit
116fe2db26
@ -1,6 +1,10 @@
|
|||||||
from typing import Union, List, Dict, Optional, Tuple
|
from typing import Union, List, Dict, Optional, Tuple
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
from sqlalchemy.sql import select
|
||||||
|
from sqlalchemy import and_, or_
|
||||||
|
|
||||||
from haystack.document_stores.utils import convert_date_to_rfc3339
|
from haystack.document_stores.utils import convert_date_to_rfc3339
|
||||||
|
|
||||||
@ -125,6 +129,12 @@ class LogicalFilterClause(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LogicalFilterClause instance to an SQL filter.

    Abstract hook: the base implementation does nothing and returns None.

    :param meta_document_orm: Object describing the metadata table that the SQL filter is built against.
    """
|
||||||
|
|
||||||
def convert_to_weaviate(self):
|
def convert_to_weaviate(self):
|
||||||
"""
|
"""
|
||||||
Converts the LogicalFilterClause instance to a Weaviate filter.
|
Converts the LogicalFilterClause instance to a Weaviate filter.
|
||||||
@ -205,6 +215,13 @@ class ComparisonOperation(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def convert_to_sql(self, meta_document_orm):
    """
    Converts the ComparisonOperation instance to an SQL filter.

    Abstract hook: the base implementation does nothing and returns None.

    :param meta_document_orm: Object describing the metadata table that the SQL filter is built against.
    """
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_to_weaviate(self):
|
def convert_to_weaviate(self):
|
||||||
"""
|
"""
|
||||||
@ -266,6 +283,13 @@ class NotOperation(LogicalFilterClause):
|
|||||||
conditions = self._merge_es_range_queries(conditions)
|
conditions = self._merge_es_range_queries(conditions)
|
||||||
return {"bool": {"must_not": conditions}}
|
return {"bool": {"must_not": conditions}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NotOperation instance to an SQL filter.

    Each child condition is converted to its own id-selecting query; the result selects
    the document ids that are in none of those queries (NOT (c1 OR c2 OR ...)).

    :param meta_document_orm: ORM class of the metadata table; its `document_id` column is selected.
    :return: SQLAlchemy select of `document_id` values.
    """
    doc_id = meta_document_orm.document_id
    membership_checks = []
    for sub_filter in self.conditions:
        id_subquery = sub_filter.convert_to_sql(meta_document_orm)
        membership_checks.append(doc_id.in_(id_subquery))
    # NOTE(review): ids are drawn from the metadata table itself, so a document with no
    # metadata rows cannot be returned by this negation -- confirm that is intended.
    negated = ~or_(*membership_checks)
    return select(doc_id).filter(negated)
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]:
|
||||||
conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions]
|
conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions]
|
||||||
if len(conditions) > 1:
|
if len(conditions) > 1:
|
||||||
@ -295,6 +319,13 @@ class AndOperation(LogicalFilterClause):
|
|||||||
conditions = self._merge_es_range_queries(conditions)
|
conditions = self._merge_es_range_queries(conditions)
|
||||||
return {"bool": {"must": conditions}}
|
return {"bool": {"must": conditions}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the AndOperation instance to an SQL filter.

    Each child condition is converted to its own id-selecting query; the result selects
    the document ids that are contained in every one of those queries.

    :param meta_document_orm: ORM class of the metadata table; its `document_id` column is selected.
    :return: SQLAlchemy select of `document_id` values.
    """
    doc_id = meta_document_orm.document_id
    membership_checks = []
    for sub_filter in self.conditions:
        id_subquery = sub_filter.convert_to_sql(meta_document_orm)
        membership_checks.append(doc_id.in_(id_subquery))
    return select(doc_id).filter(and_(*membership_checks))
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||||
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
||||||
return {"operator": "And", "operands": conditions}
|
return {"operator": "And", "operands": conditions}
|
||||||
@ -313,6 +344,13 @@ class OrOperation(LogicalFilterClause):
|
|||||||
conditions = self._merge_es_range_queries(conditions)
|
conditions = self._merge_es_range_queries(conditions)
|
||||||
return {"bool": {"should": conditions}}
|
return {"bool": {"should": conditions}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the OrOperation instance to an SQL filter.

    Each child condition is converted to its own id-selecting query; the result selects
    the document ids that are contained in at least one of those queries.

    :param meta_document_orm: ORM class of the metadata table; its `document_id` column is selected.
    :return: SQLAlchemy select of `document_id` values.
    """
    doc_id = meta_document_orm.document_id
    membership_checks = []
    for sub_filter in self.conditions:
        id_subquery = sub_filter.convert_to_sql(meta_document_orm)
        membership_checks.append(doc_id.in_(id_subquery))
    return select(doc_id).filter(or_(*membership_checks))
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||||
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
conditions = [condition.convert_to_weaviate() for condition in self.conditions]
|
||||||
return {"operator": "Or", "operands": conditions}
|
return {"operator": "Or", "operands": conditions}
|
||||||
@ -330,6 +368,11 @@ class EqOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values."
|
assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values."
|
||||||
return {"term": {self.field_name: self.comparison_value}}
|
return {"term": {self.field_name: self.comparison_value}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the EqOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` equals `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value == self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value}
|
return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value}
|
||||||
@ -347,6 +390,11 @@ class InOperation(ComparisonOperation):
|
|||||||
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
||||||
return {"terms": {self.field_name: self.comparison_value}}
|
return {"terms": {self.field_name: self.comparison_value}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the InOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is one of the entries of `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value.in_(self.comparison_value),
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||||
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []}
|
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []}
|
||||||
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
|
||||||
@ -372,6 +420,11 @@ class NeOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values."
|
assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values."
|
||||||
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
|
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NeOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` differs from `self.comparison_value`. Because the `name` condition
    is part of the criteria, only rows that carry the field at all can match.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value != self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value}
|
return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value}
|
||||||
@ -389,6 +442,11 @@ class NinOperation(ComparisonOperation):
|
|||||||
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
||||||
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
|
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NinOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is not among the entries of `self.comparison_value`. Because the
    `name` condition is part of the criteria, only rows that carry the field at all can match.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value.notin_(self.comparison_value),
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
|
||||||
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []}
|
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []}
|
||||||
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
|
||||||
@ -414,6 +472,11 @@ class GtOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list."
|
assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list."
|
||||||
return {"range": {self.field_name: {"gt": self.comparison_value}}}
|
return {"range": {self.field_name: {"gt": self.comparison_value}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the GtOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is greater than `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value > self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list."
|
assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list."
|
||||||
@ -432,6 +495,11 @@ class GteOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list."
|
assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list."
|
||||||
return {"range": {self.field_name: {"gte": self.comparison_value}}}
|
return {"range": {self.field_name: {"gte": self.comparison_value}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the GteOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is greater than or equal to `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value >= self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list."
|
assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list."
|
||||||
@ -450,6 +518,11 @@ class LtOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list."
|
assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list."
|
||||||
return {"range": {self.field_name: {"lt": self.comparison_value}}}
|
return {"range": {self.field_name: {"lt": self.comparison_value}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LtOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is less than `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value < self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list."
|
assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list."
|
||||||
@ -468,6 +541,11 @@ class LteOperation(ComparisonOperation):
|
|||||||
assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list."
|
assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list."
|
||||||
return {"range": {self.field_name: {"lte": self.comparison_value}}}
|
return {"range": {self.field_name: {"lte": self.comparison_value}}}
|
||||||
|
|
||||||
|
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LteOperation instance to an SQL filter.

    Selects the `document_id` of every metadata row whose `name` equals `self.field_name`
    and whose `value` is less than or equal to `self.comparison_value`.

    :param meta_document_orm: ORM class of the metadata table; must expose `document_id`, `name` and `value`.
    :return: SQLAlchemy select of matching `document_id` values.
    """
    # Positional-column select() matches the logical operations in this module; the legacy
    # list form select([...]) is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name,
        meta_document_orm.value <= self.comparison_value,
    )
|
||||||
|
|
||||||
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
|
||||||
comp_value_type, comp_value = self._get_weaviate_datatype()
|
comp_value_type, comp_value = self._get_weaviate_datatype()
|
||||||
assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list."
|
assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list."
|
||||||
|
|||||||
@ -31,6 +31,8 @@ except (ImportError, ModuleNotFoundError) as ie:
|
|||||||
from haystack.schema import Document, Label, Answer
|
from haystack.schema import Document, Label, Answer
|
||||||
from haystack.document_stores.base import BaseDocumentStore
|
from haystack.document_stores.base import BaseDocumentStore
|
||||||
|
|
||||||
|
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
Base = declarative_base() # type: Any
|
Base = declarative_base() # type: Any
|
||||||
@ -294,11 +296,9 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
).filter_by(index=index)
|
).filter_by(index=index)
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
for key, values in filters.items():
|
parsed_filter = LogicalFilterClause.parse(filters)
|
||||||
documents_query = documents_query.join(MetaDocumentORM, aliased=True).filter(
|
select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
|
||||||
MetaDocumentORM.name == key,
|
documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
|
||||||
MetaDocumentORM.value.in_(values),
|
|
||||||
)
|
|
||||||
|
|
||||||
if only_documents_without_embedding:
|
if only_documents_without_embedding:
|
||||||
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
|
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
|
||||||
|
|||||||
@ -216,7 +216,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)
|
|||||||
assert len(documents) == 0
|
assert len(documents) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate"], indirect=True)
|
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate"], indirect=True)
|
||||||
def test_extended_filter(document_store_with_docs):
|
def test_extended_filter(document_store_with_docs):
|
||||||
# Test comparison operators individually
|
# Test comparison operators individually
|
||||||
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})
|
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user