Extend meta data support for SQLDocumentStore (#2199)

* update remaining occurrences of get_connection

* first commit to add extended metadata filtering support to sql

* fix bugs

* adding sql doc store instead of milvus

* removing updates to milvus2 from other PR

* fixing not operator

* delete leftover line

* remove unnecessary import

* Update Documentation & Code Style

* fix circular import

* fix leftover merge conflict

* Update Documentation & Code Style

* fix abstract class

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-02-21 20:40:32 +01:00 committed by GitHub
parent 8de1aa3e43
commit 116fe2db26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 6 deletions

View File

@ -1,6 +1,10 @@
from typing import Union, List, Dict, Optional, Tuple
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import reduce
from sqlalchemy.sql import select
from sqlalchemy import and_, or_
from haystack.document_stores.utils import convert_date_to_rfc3339
@ -125,6 +129,12 @@ class LogicalFilterClause(ABC):
pass
@abstractmethod
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LogicalFilterClause instance to an SQL filter.

    Abstract hook: this base implementation does nothing and returns None.

    :param meta_document_orm: Object the SQL filter is built against. The concrete logical
                              operations in this file read its `document_id` attribute.
    """
def convert_to_weaviate(self):
"""
Converts the LogicalFilterClause instance to a Weaviate filter.
@ -205,6 +215,13 @@ class ComparisonOperation(ABC):
"""
pass
@abstractmethod
def convert_to_sql(self, meta_document_orm):
    """
    Converts the ComparisonOperation instance to an SQL filter.

    Abstract hook: this base implementation does nothing and returns None.

    :param meta_document_orm: Object the SQL filter is built against. The concrete comparison
                              operations in this file read its `document_id`, `name` and `value`
                              attributes.
    """
@abstractmethod
def convert_to_weaviate(self):
"""
@ -266,6 +283,13 @@ class NotOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions)
return {"bool": {"must_not": conditions}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NotOperation instance to an SQL filter.

    Each child condition is converted to its own SQL sub-select; the returned select yields the
    `document_id` values of `meta_document_orm` that are contained in none of those sub-selects.

    :param meta_document_orm: Object exposing the `document_id` column the filter is built on.
    :return: A select over `meta_document_orm.document_id`.
    """
    doc_id_column = meta_document_orm.document_id
    membership_tests = []
    for child in self.conditions:
        child_ids = child.convert_to_sql(meta_document_orm)
        membership_tests.append(doc_id_column.in_(child_ids))

    # NOTE(review): the select reads from meta_document_orm, so documents without any row there
    # presumably can never be returned by a negation - confirm this is intended.
    matches_any_child = or_(*membership_tests)
    return select(doc_id_column).filter(~matches_any_child)
def convert_to_weaviate(self) -> Dict[str, Union[str, int, float, bool, List[Dict]]]:
conditions = [condition.invert().convert_to_weaviate() for condition in self.conditions]
if len(conditions) > 1:
@ -295,6 +319,13 @@ class AndOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions)
return {"bool": {"must": conditions}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the AndOperation instance to an SQL filter.

    Each child condition is converted to its own SQL sub-select; the returned select yields the
    `document_id` values of `meta_document_orm` that are contained in every one of those
    sub-selects.

    :param meta_document_orm: Object exposing the `document_id` column the filter is built on.
    :return: A select over `meta_document_orm.document_id`.
    """
    doc_id_column = meta_document_orm.document_id
    membership_tests = []
    for child in self.conditions:
        child_ids = child.convert_to_sql(meta_document_orm)
        membership_tests.append(doc_id_column.in_(child_ids))

    matches_all_children = and_(*membership_tests)
    return select(doc_id_column).filter(matches_all_children)
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
    """
    Converts the AndOperation instance to a Weaviate filter.

    :return: Dict with the operator "And" and one operand per child condition, in order.
    """
    operands: List[Dict] = []
    for child in self.conditions:
        operands.append(child.convert_to_weaviate())

    weaviate_filter: Dict[str, Union[str, List[Dict]]] = {"operator": "And"}
    weaviate_filter["operands"] = operands
    return weaviate_filter
@ -313,6 +344,13 @@ class OrOperation(LogicalFilterClause):
conditions = self._merge_es_range_queries(conditions)
return {"bool": {"should": conditions}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the OrOperation instance to an SQL filter.

    Each child condition is converted to its own SQL sub-select; the returned select yields the
    `document_id` values of `meta_document_orm` that are contained in at least one of those
    sub-selects.

    :param meta_document_orm: Object exposing the `document_id` column the filter is built on.
    :return: A select over `meta_document_orm.document_id`.
    """
    doc_id_column = meta_document_orm.document_id
    membership_tests = []
    for child in self.conditions:
        child_ids = child.convert_to_sql(meta_document_orm)
        membership_tests.append(doc_id_column.in_(child_ids))

    matches_any_child = or_(*membership_tests)
    return select(doc_id_column).filter(matches_any_child)
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
    """
    Converts the OrOperation instance to a Weaviate filter.

    :return: Dict with the operator "Or" and one operand per child condition, in order.
    """
    operands: List[Dict] = []
    for child in self.conditions:
        operands.append(child.convert_to_weaviate())

    weaviate_filter: Dict[str, Union[str, List[Dict]]] = {"operator": "Or"}
    weaviate_filter["operands"] = operands
    return weaviate_filter
@ -330,6 +368,11 @@ class EqOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Use '$in' operation for lists as comparison values."
return {"term": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the EqOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` equals `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
    """
    Converts the EqOperation instance to a Weaviate filter.

    :return: Dict holding the field path, the operator "Equal" and the comparison value stored
             under the key returned by `self._get_weaviate_datatype()`.
    """
    value_key, value = self._get_weaviate_datatype()
    weaviate_filter: Dict[str, Union[List[str], str, int, float, bool]] = {
        "path": [self.field_name],
        "operator": "Equal",
    }
    weaviate_filter[value_key] = value
    return weaviate_filter
@ -347,6 +390,11 @@ class InOperation(ComparisonOperation):
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
return {"terms": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the InOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is one of the entries of `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
    )
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "Or", "operands": []}
assert isinstance(self.comparison_value, list), "'$in' operation requires comparison value to be a list."
@ -372,6 +420,11 @@ class NeOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Use '$nin' operation for lists as comparison values."
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NeOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` differs from `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): only rows named `self.field_name` are inspected, so documents that lack the
    # field altogether are presumably not returned - confirm this is the intended semantics.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, int, float, bool]]:
    """
    Converts the NeOperation instance to a Weaviate filter.

    :return: Dict holding the field path, the operator "NotEqual" and the comparison value
             stored under the key returned by `self._get_weaviate_datatype()`.
    """
    value_key, value = self._get_weaviate_datatype()
    weaviate_filter: Dict[str, Union[List[str], str, int, float, bool]] = {
        "path": [self.field_name],
        "operator": "NotEqual",
    }
    weaviate_filter[value_key] = value
    return weaviate_filter
@ -389,6 +442,11 @@ class NinOperation(ComparisonOperation):
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the NinOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is none of the entries of `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): only rows named `self.field_name` are inspected, so documents that lack the
    # field altogether are presumably not returned - confirm this is the intended semantics.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
    )
def convert_to_weaviate(self) -> Dict[str, Union[str, List[Dict]]]:
filter_dict: Dict[str, Union[str, List[Dict]]] = {"operator": "And", "operands": []}
assert isinstance(self.comparison_value, list), "'$nin' operation requires comparison value to be a list."
@ -414,6 +472,11 @@ class GtOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$gt' operation must not be a list."
return {"range": {self.field_name: {"gt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the GtOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is greater than `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): the ordering is evaluated by the database on the `value` column; if that
    # column stores text, the comparison is presumably lexicographic - confirm against the schema.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$gt' operation must not be a list."
@ -432,6 +495,11 @@ class GteOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$gte' operation must not be a list."
return {"range": {self.field_name: {"gte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the GteOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is greater than or equal to `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): the ordering is evaluated by the database on the `value` column; if that
    # column stores text, the comparison is presumably lexicographic - confirm against the schema.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$gte' operation must not be a list."
@ -450,6 +518,11 @@ class LtOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$lt' operation must not be a list."
return {"range": {self.field_name: {"lt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LtOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is less than `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): the ordering is evaluated by the database on the `value` column; if that
    # column stores text, the comparison is presumably lexicographic - confirm against the schema.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$lt' operation must not be a list."
@ -468,6 +541,11 @@ class LteOperation(ComparisonOperation):
assert not isinstance(self.comparison_value, list), "Comparison value for '$lte' operation must not be a list."
return {"range": {self.field_name: {"lte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Converts the LteOperation instance to an SQL filter.

    Selects the `document_id` of every `meta_document_orm` row whose `name` equals
    `self.field_name` and whose `value` is less than or equal to `self.comparison_value`.

    :param meta_document_orm: Object exposing the `document_id`, `name` and `value` columns.
    :return: A select over `meta_document_orm.document_id`.
    """
    # The column is passed positionally: the legacy list form `select([...])` is deprecated in
    # SQLAlchemy 1.4 and no longer accepted by 2.0. The logical operations in this module
    # already use the positional form.
    # NOTE(review): the ordering is evaluated by the database on the `value` column; if that
    # column stores text, the comparison is presumably lexicographic - confirm against the schema.
    return select(meta_document_orm.document_id).where(
        meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
    )
def convert_to_weaviate(self) -> Dict[str, Union[List[str], str, float, int]]:
comp_value_type, comp_value = self._get_weaviate_datatype()
assert not isinstance(comp_value, list), "Comparison value for '$lte' operation must not be a list."

View File

@ -31,6 +31,8 @@ except (ImportError, ModuleNotFoundError) as ie:
from haystack.schema import Document, Label, Answer
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.filter_utils import LogicalFilterClause
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
@ -294,11 +296,9 @@ class SQLDocumentStore(BaseDocumentStore):
).filter_by(index=index)
if filters:
for key, values in filters.items():
documents_query = documents_query.join(MetaDocumentORM, aliased=True).filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
)
parsed_filter = LogicalFilterClause.parse(filters)
select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))
if only_documents_without_embedding:
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))

View File

@ -216,7 +216,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)
assert len(documents) == 0
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate"], indirect=True)
def test_extended_filter(document_store_with_docs):
# Test comparison operators individually
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})