mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 10:26:27 +00:00
feat: support dynamic filters in custom_query (#5427)
* support filters in custom_query * better tests * Update docstrings --------- Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent
3f472995bb
commit
d46c84bb61
@ -706,11 +706,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
:param top_k: How many documents to return per query.
|
:param top_k: How many documents to return per query.
|
||||||
:param custom_query: query string containing a mandatory `${query}` placeholder.
|
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||||
|
|
||||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
|
||||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
|
||||||
names must match with the filters dict supplied in self.retrieve().
|
|
||||||
::
|
::
|
||||||
|
|
||||||
**An example custom_query:**
|
**An example custom_query:**
|
||||||
@ -723,17 +720,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
|||||||
"query": ${query}, // mandatory query placeholder
|
"query": ${query}, // mandatory query placeholder
|
||||||
"type": "most_fields",
|
"type": "most_fields",
|
||||||
"fields": ["content", "title"]}}],
|
"fields": ["content", "title"]}}],
|
||||||
"filter": [ // optional custom filters
|
"filter": ${filters} // optional filters placeholder
|
||||||
{"terms": {"year": ${years}}},
|
|
||||||
{"terms": {"quarter": ${quarters}}},
|
|
||||||
{"range": {"date": {"gte": ${date}}}}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**For this custom_query, a sample retrieve() could be:**
|
**For this custom_query, a sample `retrieve()` could be:**
|
||||||
```python
|
```python
|
||||||
self.retrieve(query="Why did the revenue increase?",
|
self.retrieve(query="Why did the revenue increase?",
|
||||||
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
filters={"year": ["2019"], "quarter": ["Q1", "Q2"]})
|
||||||
|
@ -824,11 +824,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
:param top_k: How many documents to return per query.
|
:param top_k: How many documents to return per query.
|
||||||
:param custom_query: query string containing a mandatory `${query}` placeholder.
|
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||||
|
|
||||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
|
||||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
|
||||||
names must match with the filters dict supplied in self.retrieve().
|
|
||||||
::
|
::
|
||||||
|
|
||||||
**An example custom_query:**
|
**An example custom_query:**
|
||||||
@ -841,11 +838,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
"query": ${query}, // mandatory query placeholder
|
"query": ${query}, // mandatory query placeholder
|
||||||
"type": "most_fields",
|
"type": "most_fields",
|
||||||
"fields": ["content", "title"]}}],
|
"fields": ["content", "title"]}}],
|
||||||
"filter": [ // optional custom filters
|
"filter": ${filters} // optional filters placeholder
|
||||||
{"terms": {"year": ${years}}},
|
|
||||||
{"terms": {"quarter": ${quarters}}},
|
|
||||||
{"range": {"date": {"gte": ${date}}}}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -1084,16 +1077,13 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
|||||||
body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
||||||
|
|
||||||
# Retrieval via custom query
|
# Retrieval via custom query
|
||||||
elif custom_query: # substitute placeholder for query and filters for the custom_query template string
|
elif custom_query:
|
||||||
template = Template(custom_query)
|
template = Template(custom_query)
|
||||||
# replace all "${query}" placeholder(s) with query
|
# substitute placeholder for query and filters for the custom_query template string
|
||||||
substitutions = {"query": json.dumps(query)}
|
substitutions = {
|
||||||
# For each filter we got passed, we'll try to find & replace the corresponding placeholder in the template
|
"query": json.dumps(query),
|
||||||
# Example: filters={"years":[2018]} => replaces {$years} in custom_query with '[2018]'
|
"filters": json.dumps(LogicalFilterClause.parse(filters or {}).convert_to_elasticsearch()),
|
||||||
if filters:
|
}
|
||||||
for key, values in filters.items():
|
|
||||||
values_str = json.dumps(values)
|
|
||||||
substitutions[key] = values_str
|
|
||||||
custom_query_json = template.substitute(**substitutions)
|
custom_query_json = template.substitute(**substitutions)
|
||||||
body = json.loads(custom_query_json)
|
body = json.loads(custom_query_json)
|
||||||
# add top_k
|
# add top_k
|
||||||
|
@ -33,11 +33,7 @@ class BM25Retriever(BaseRetriever):
|
|||||||
If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
|
If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
|
||||||
Otherwise at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
|
Otherwise at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
:param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query).
|
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||||
|
|
||||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
|
||||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
|
||||||
names must match with the filters dict supplied in self.retrieve().
|
|
||||||
|
|
||||||
**An example custom_query:**
|
**An example custom_query:**
|
||||||
|
|
||||||
@ -50,17 +46,13 @@ class BM25Retriever(BaseRetriever):
|
|||||||
"query": ${query}, // mandatory query placeholder
|
"query": ${query}, // mandatory query placeholder
|
||||||
"type": "most_fields",
|
"type": "most_fields",
|
||||||
"fields": ["content", "title"]}}],
|
"fields": ["content", "title"]}}],
|
||||||
"filter": [ // optional custom filters
|
"filter": ${filters} // optional filter placeholder
|
||||||
{"terms": {"year": ${years}}},
|
|
||||||
{"terms": {"quarter": ${quarters}}},
|
|
||||||
{"range": {"date": {"gte": ${date}}}}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**For this custom_query, a sample retrieve() could be:**
|
**For this custom_query, a sample `retrieve()` could be:**
|
||||||
|
|
||||||
```python
|
```python
|
||||||
self.retrieve(query="Why did the revenue increase?",
|
self.retrieve(query="Why did the revenue increase?",
|
||||||
|
@ -0,0 +1,82 @@
|
|||||||
|
---
|
||||||
|
upgrade:
|
||||||
|
- |
|
||||||
|
The OpenSearch and Elasticsearch custom query syntax has changed: the old filter placeholders for ``custom_query`` are no longer supported.
|
||||||
|
Replace your custom filter expressions with the new ``${filters}`` placeholder:
|
||||||
|
|
||||||
|
**Old:**
|
||||||
|
```python
|
||||||
|
retriever = BM25Retriever(
|
||||||
|
custom_query="""
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"should": [{"multi_match": {
|
||||||
|
"query": ${query},
|
||||||
|
"type": "most_fields",
|
||||||
|
"fields": ["content", "title"]}}
|
||||||
|
],
|
||||||
|
"filter": [
|
||||||
|
{"terms": {"year": ${years}}},
|
||||||
|
{"terms": {"quarter": ${quarters}}},
|
||||||
|
{"range": {"date": {"gte": ${date}}}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever.retrieve(
|
||||||
|
query="What is the meaning of life?",
|
||||||
|
filters={"years": [2019, 2020], "quarters": [1, 2, 3], "date": "2019-03-01"}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**New:**
|
||||||
|
```python
|
||||||
|
retriever = BM25Retriever(
|
||||||
|
custom_query="""
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"should": [{"multi_match": {
|
||||||
|
"query": ${query},
|
||||||
|
"type": "most_fields",
|
||||||
|
"fields": ["content", "title"]}}
|
||||||
|
],
|
||||||
|
"filter": ${filters}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever.retrieve(
|
||||||
|
query="What is the meaning of life?",
|
||||||
|
filters={"year": [2019, 2020], "quarter": [1, 2, 3], "date": {"$gte": "2019-03-01"}}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
When using ``custom_query`` in ``BM25Retriever`` along with ``OpenSearch``
|
||||||
|
or ``Elasticsearch``, we added support for dynamic ``filters``, like in regular queries.
|
||||||
|
With this change, you can pass filters at query-time without having to modify the ``custom_query``:
|
||||||
|
Instead of defining filter expressions and field placeholders, all you have to do
|
||||||
|
is add the ``${filters}`` placeholder, analogous to the ``${query}`` placeholder, to
|
||||||
|
your ``custom_query``.
|
||||||
|
**For example:**
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"should": [{"multi_match": {
|
||||||
|
"query": ${query}, // mandatory query placeholder
|
||||||
|
"type": "most_fields",
|
||||||
|
"fields": ["content", "title"]}}
|
||||||
|
],
|
||||||
|
"filter": ${filters} // optional filters placeholder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
@ -1,8 +1,10 @@
|
|||||||
|
from typing import Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from haystack.document_stores.search_engine import SearchEngineDocumentStore
|
from haystack.document_stores.search_engine import SearchEngineDocumentStore
|
||||||
|
from haystack.schema import FilterType
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -205,6 +207,53 @@ class SearchEngineDocumentStoreTestAbstract:
|
|||||||
labels = ds.get_all_labels()
|
labels = ds.get_all_labels()
|
||||||
assert labels[0].meta["version"] == "2023.1"
|
assert labels[0].meta["version"] == "2023.1"
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,filters,result_count",
|
||||||
|
[
|
||||||
|
# test happy path
|
||||||
|
("tost", {"year": ["2020", "2021", "1990"]}, 4),
|
||||||
|
# test empty filters
|
||||||
|
("test", None, 5),
|
||||||
|
# test linefeeds in query
|
||||||
|
("test\n", {"year": "2021"}, 3),
|
||||||
|
# test double quote in query
|
||||||
|
('test"', {"year": "2021"}, 3),
|
||||||
|
# test non-matching query
|
||||||
|
("toast", None, 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_custom_query(
|
||||||
|
self, query: str, filters: Optional[FilterType], result_count: int, ds: SearchEngineDocumentStore
|
||||||
|
):
|
||||||
|
documents = [
|
||||||
|
{"id": "1", "content": "test", "meta": {"year": "2019"}},
|
||||||
|
{"id": "2", "content": "test", "meta": {"year": "2020"}},
|
||||||
|
{"id": "3", "content": "test", "meta": {"year": "2021"}},
|
||||||
|
{"id": "4", "content": "test", "meta": {"year": "2021"}},
|
||||||
|
{"id": "5", "content": "test", "meta": {"year": "2021"}},
|
||||||
|
]
|
||||||
|
ds.write_documents(documents)
|
||||||
|
custom_query = """
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"must": [{
|
||||||
|
"multi_match": {
|
||||||
|
"query": ${query},
|
||||||
|
"fields": ["content"],
|
||||||
|
"fuzziness": "AUTO"
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"filter": ${filters}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
results = ds.query(query=query, filters=filters, custom_query=custom_query)
|
||||||
|
assert len(results) == result_count
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.document_store
|
@pytest.mark.document_store
|
||||||
class TestSearchEngineDocumentStore:
|
class TestSearchEngineDocumentStore:
|
||||||
|
@ -21,7 +21,7 @@ except (ImportError, ModuleNotFoundError) as ie:
|
|||||||
_optional_component_not_installed(__name__, "elasticsearch", ie)
|
_optional_component_not_installed(__name__, "elasticsearch", ie)
|
||||||
|
|
||||||
|
|
||||||
from haystack.document_stores.base import BaseDocumentStore, FilterType
|
from haystack.document_stores.base import BaseDocumentStore, FilterType, KeywordDocumentStore
|
||||||
from haystack.document_stores.memory import InMemoryDocumentStore
|
from haystack.document_stores.memory import InMemoryDocumentStore
|
||||||
from haystack.document_stores import WeaviateDocumentStore
|
from haystack.document_stores import WeaviateDocumentStore
|
||||||
from haystack.nodes.retriever.base import BaseRetriever
|
from haystack.nodes.retriever.base import BaseRetriever
|
||||||
@ -268,59 +268,25 @@ def test_embed_meta_fields_list_with_one_item():
|
|||||||
assert docs_with_embedded_meta[0].content == "one_item\nMy name is Matteo and I live in Rome"
|
assert docs_with_embedded_meta[0].content == "one_item\nMy name is Matteo and I live in Rome"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.elasticsearch
|
@pytest.mark.unit
|
||||||
def test_elasticsearch_custom_query():
|
def test_custom_query():
|
||||||
client = Elasticsearch()
|
mock_document_store = Mock(spec=KeywordDocumentStore)
|
||||||
client.indices.delete(index="haystack_test_custom", ignore=[404])
|
mock_document_store.index = "test"
|
||||||
document_store = ElasticsearchDocumentStore(
|
|
||||||
index="haystack_test_custom", content_field="custom_text_field", embedding_field="custom_embedding_field"
|
|
||||||
)
|
|
||||||
documents = [
|
|
||||||
{"content": "test_1", "meta": {"year": "2019"}},
|
|
||||||
{"content": "test_2", "meta": {"year": "2020"}},
|
|
||||||
{"content": "test_3", "meta": {"year": "2021"}},
|
|
||||||
{"content": "test_4", "meta": {"year": "2021"}},
|
|
||||||
{"content": "test_5", "meta": {"year": "2021"}},
|
|
||||||
]
|
|
||||||
document_store.write_documents(documents)
|
|
||||||
|
|
||||||
# test custom "terms" query
|
custom_query = """
|
||||||
retriever = BM25Retriever(
|
|
||||||
document_store=document_store,
|
|
||||||
custom_query="""
|
|
||||||
{
|
{
|
||||||
"size": 10,
|
"size": 10,
|
||||||
"query": {
|
"query": {
|
||||||
"bool": {
|
"bool": {
|
||||||
"should": [{
|
"should": [{
|
||||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
|
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["custom_text_field"]}}],
|
||||||
"filter": [{"terms": {"year": ${years}}}]}}}""",
|
"filter": ${filters}}}}"""
|
||||||
)
|
|
||||||
results = retriever.retrieve(query="test", filters={"years": ["2020", "2021"]})
|
|
||||||
assert len(results) == 4
|
|
||||||
|
|
||||||
# test linefeeds in query
|
retriever = BM25Retriever(document_store=mock_document_store, custom_query=custom_query)
|
||||||
results = retriever.retrieve(query="test\n", filters={"years": ["2020", "2021"]})
|
retriever.retrieve(query="test", filters={"year": ["2020", "2021"]})
|
||||||
assert len(results) == 3
|
assert mock_document_store.query.call_args.kwargs["custom_query"] == custom_query
|
||||||
|
assert mock_document_store.query.call_args.kwargs["filters"] == {"year": ["2020", "2021"]}
|
||||||
# test double quote in query
|
assert mock_document_store.query.call_args.kwargs["query"] == "test"
|
||||||
results = retriever.retrieve(query='test"', filters={"years": ["2020", "2021"]})
|
|
||||||
assert len(results) == 3
|
|
||||||
|
|
||||||
# test custom "term" query
|
|
||||||
retriever = BM25Retriever(
|
|
||||||
document_store=document_store,
|
|
||||||
custom_query="""
|
|
||||||
{
|
|
||||||
"size": 10,
|
|
||||||
"query": {
|
|
||||||
"bool": {
|
|
||||||
"should": [{
|
|
||||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
|
|
||||||
"filter": [{"term": {"year": ${years}}}]}}}""",
|
|
||||||
)
|
|
||||||
results = retriever.retrieve(query="test", filters={"years": "2021"})
|
|
||||||
assert len(results) == 3
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
|
Loading…
x
Reference in New Issue
Block a user