Extend meta data support for SQLDocumentStore (#2199)

* update remaining occurences of get_connection

* first commit to add extended metadata filtering support to sql

* fix bugs

* adding sql doc store instead of milvus

* removing updates to milvus2 from other PR

* fixing not operator

* delete left over line

* remove unnecessary import

* Update Documentation & Code Style

* fix circular import

* fix left over merge conflict

* Update Documentation & Code Style

* fix abstract class

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-02-21 20:40:32 +01:00 committed by GitHub
parent 8de1aa3e43
commit 116fe2db26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 6 deletions

View File

@ -1,6 +1,10 @@
from typing import Union, List, Dict, Optional, Tuple from typing import Union, List, Dict, Optional, Tuple
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from functools import reduce
from sqlalchemy.sql import select
from sqlalchemy import and_, or_
from haystack.document_stores.utils import convert_date_to_rfc3339 from haystack.document_stores.utils import convert_date_to_rfc3339
@ -125,6 +129,12 @@ class LogicalFilterClause(ABC):
pass pass
@abstractmethod @abstractmethod
def convert_to_sql(self, meta_document_orm):
"""
Converts the LogicalFilterClause instance to an SQL filter.
"""
pass
def convert_to_weaviate(self): def convert_to_weaviate(self):
""" """
Converts the LogicalFilterClause instance to a Weaviate filter. Converts the LogicalFilterClause instance to a Weaviate filter.
@ -205,6 +215,13 @@ class ComparisonOperation(ABC):
""" """
pass pass
@abstractmethod
def convert_to_sql(self, meta_document_orm):
"""
Converts the ComparisonOperation instance to an SQL filter.
"""
pass
@abstractmethod @abstractmethod
def convert_to_weaviate(self): def convert_to_weaviate(self):
""" """
@ -266,6 +283,13 @@ class NotOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions) conditions = self._merge_es_range_queries(conditions)
return {"bool": {"must_not": conditions}} return {"bool": {"must_not": conditions}}
def convert_to_sql(self, meta_document_orm):
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
]
return select(meta_document_orm.document_id).filter(~or_(*conditions))
def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]: def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]:
conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions] conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions]
if len(conditions) > 1: if len(conditions) > 1:
@ -295,6 +319,13 @@ class AndOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions) conditions = self._merge_es_range_queries(conditions)
return {"bool": {"must": conditions}} return {"bool": {"must": conditions}}
def convert_to_sql(self, meta_document_orm):
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
]
return select(meta_document_orm.document_id).filter(and_(*conditions))
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
conditions = [condition.convert_to_weaviate() for condition in self.conditions] conditions = [condition.convert_to_weaviate() for condition in self.conditions]
return {"operator": "And", "operands": conditions} return {"operator": "And", "operands": conditions}
@ -313,6 +344,13 @@ class OrOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions) conditions = self._merge_es_range_queries(conditions)
return {"bool": {"should": conditions}} return {"bool": {"should": conditions}}
def convert_to_sql(self, meta_document_orm):
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
]
return select(meta_document_orm.document_id).filter(or_(*conditions))
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
conditions = [condition.convert_to_weaviate() for condition in self.conditions] conditions = [condition.convert_to_weaviate() for condition in self.conditions]
return {"operator": "Or", "operands": conditions} return {"operator": "Or", "operands": conditions}
@ -330,6 +368,11 @@ class EqOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values." assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values."
return {"term": {self.field_name: self.comparison_value}} return {"term": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value} return {"path": [self.field_name], "operator": "Equal", comp_value_type: comp_value}
@ -347,6 +390,11 @@ class InOperation(ComparisonOperation):
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list." assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
return {"terms": {self.field_name: self.comparison_value}} return {"terms": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
)
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []} filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []}
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list." assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
@ -372,6 +420,11 @@ class NeOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values." assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values."
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}} return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value} return {"path": [self.field_name], "operator": "NotEqual", comp_value_type: comp_value}
@ -389,6 +442,11 @@ class NinOperation(ComparisonOperation):
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list." assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}} return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
)
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]: def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []} filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []}
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list." assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
@ -414,6 +472,11 @@ class GtOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list." assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list."
return {"range": {self.field_name: {"gt": self.comparison_value}}} return {"range": {self.field_name: {"gt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list." assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list."
@ -432,6 +495,11 @@ class GteOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list." assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list."
return {"range": {self.field_name: {"gte": self.comparison_value}}} return {"range": {self.field_name: {"gte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list." assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list."
@ -450,6 +518,11 @@ class LtOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list." assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list."
return {"range": {self.field_name: {"lt": self.comparison_value}}} return {"range": {self.field_name: {"lt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list." assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list."
@ -468,6 +541,11 @@ class LteOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list." assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list."
return {"range": {self.field_name: {"lte": self.comparison_value}}} return {"range": {self.field_name: {"lte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
)
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]: def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype() comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list." assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list."

View File

@ -31,6 +31,8 @@ except (ImportError, ModuleNotFoundError) as ie:
from haystack.schema import Document, Label, Answer from haystack.schema import Document, Label, Answer
from haystack.document_stores.base import BaseDocumentStore from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.filter_utils import LogicalFilterClause
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any Base = declarative_base() # type: Any
@ -294,11 +296,9 @@ class SQLDocumentStore(BaseDocumentStore):
).filter_by(index=index) ).filter_by(index=index)
if filters: if filters:
for key, values in filters.items(): parsed_filter = LogicalFilterClause.parse(filters)
documents_query = documents_query.join(MetaDocumentORM, aliased=True).filter( select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
MetaDocumentORM.name == key, documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
MetaDocumentORM.value.in_(values),
)
if only_documents_without_embedding: if only_documents_without_embedding:
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None)) documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))

View File

@ -216,7 +216,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)
assert len(documents) == 0 assert len(documents) == 0
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate"], indirect=True)
def test_extended_filter(document_store_with_docs): def test_extended_filter(document_store_with_docs):
# Test comparison operators individually # Test comparison operators individually
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}}) documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})