mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
feat: support dynamic filters in custom_query (#5427)
* support filters in custom_query * better tests * Update docstrings --------- Co-authored-by: agnieszka-m <amarzec13@gmail.com>
This commit is contained in:
parent
3f472995bb
commit
d46c84bb61
@ -706,11 +706,8 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
}
|
||||
```
|
||||
:param top_k: How many documents to return per query.
|
||||
:param custom_query: query string containing a mandatory `${query}` placeholder.
|
||||
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||
|
||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
||||
names must match with the filters dict supplied in self.retrieve().
|
||||
::
|
||||
|
||||
**An example custom_query:**
|
||||
@ -723,17 +720,13 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
|
||||
"query": ${query}, // mandatory query placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}],
|
||||
"filter": [ // optional custom filters
|
||||
{"terms": {"year": ${years}}},
|
||||
{"terms": {"quarter": ${quarters}}},
|
||||
{"range": {"date": {"gte": ${date}}}}
|
||||
],
|
||||
"filter": ${filters} // optional filters placeholder
|
||||
}
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
**For this custom_query, a sample retrieve() could be:**
|
||||
**For this custom_query, a sample `retrieve()` could be:**
|
||||
```python
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
|
@ -824,11 +824,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
}
|
||||
```
|
||||
:param top_k: How many documents to return per query.
|
||||
:param custom_query: query string containing a mandatory `${query}` placeholder.
|
||||
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||
|
||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
||||
names must match with the filters dict supplied in self.retrieve().
|
||||
::
|
||||
|
||||
**An example custom_query:**
|
||||
@ -841,11 +838,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
"query": ${query}, // mandatory query placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}],
|
||||
"filter": [ // optional custom filters
|
||||
{"terms": {"year": ${years}}},
|
||||
{"terms": {"quarter": ${quarters}}},
|
||||
{"range": {"date": {"gte": ${date}}}}
|
||||
],
|
||||
"filter": ${filters} // optional filters placeholder
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -1084,16 +1077,13 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
body["query"]["bool"]["filter"] = LogicalFilterClause.parse(filters).convert_to_elasticsearch()
|
||||
|
||||
# Retrieval via custom query
|
||||
elif custom_query: # substitute placeholder for query and filters for the custom_query template string
|
||||
elif custom_query:
|
||||
template = Template(custom_query)
|
||||
# replace all "${query}" placeholder(s) with query
|
||||
substitutions = {"query": json.dumps(query)}
|
||||
# For each filter we got passed, we'll try to find & replace the corresponding placeholder in the template
|
||||
# Example: filters={"years":[2018]} => replaces {$years} in custom_query with '[2018]'
|
||||
if filters:
|
||||
for key, values in filters.items():
|
||||
values_str = json.dumps(values)
|
||||
substitutions[key] = values_str
|
||||
# substitute placeholder for query and filters for the custom_query template string
|
||||
substitutions = {
|
||||
"query": json.dumps(query),
|
||||
"filters": json.dumps(LogicalFilterClause.parse(filters or {}).convert_to_elasticsearch()),
|
||||
}
|
||||
custom_query_json = template.substitute(**substitutions)
|
||||
body = json.loads(custom_query_json)
|
||||
# add top_k
|
||||
|
@ -33,11 +33,7 @@ class BM25Retriever(BaseRetriever):
|
||||
If true, all query terms must be present in a document in order to be retrieved (i.e., the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
|
||||
Otherwise, at least one query term must be present in a document in order to be retrieved (i.e., the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
|
||||
Defaults to False.
|
||||
:param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query).
|
||||
|
||||
Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
|
||||
that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
|
||||
names must match with the filters dict supplied in self.retrieve().
|
||||
:param custom_query: The query string containing a mandatory `${query}` and an optional `${filters}` placeholder.
|
||||
|
||||
**An example custom_query:**
|
||||
|
||||
@ -50,17 +46,13 @@ class BM25Retriever(BaseRetriever):
|
||||
"query": ${query}, // mandatory query placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}],
|
||||
"filter": [ // optional custom filters
|
||||
{"terms": {"year": ${years}}},
|
||||
{"terms": {"quarter": ${quarters}}},
|
||||
{"range": {"date": {"gte": ${date}}}}
|
||||
],
|
||||
"filter": ${filters} // optional filter placeholder
|
||||
}
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
**For this custom_query, a sample retrieve() could be:**
|
||||
**For this custom_query, a sample `retrieve()` could be:**
|
||||
|
||||
```python
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
|
@ -0,0 +1,82 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
The OpenSearch custom query syntax has changed: the old filter placeholders for ``custom_query`` are no longer supported.
|
||||
Replace your custom filter expressions with the new ``${filters}`` placeholder:
|
||||
|
||||
**Old:**
|
||||
```python
|
||||
retriever = BM25Retriever(
|
||||
custom_query="""
|
||||
{
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{"multi_match": {
|
||||
"query": ${query},
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}
|
||||
],
|
||||
"filter": [
|
||||
{"terms": {"year": ${years}}},
|
||||
{"terms": {"quarter": ${quarters}}},
|
||||
{"range": {"date": {"gte": ${date}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
retriever.retrieve(
|
||||
query="What is the meaning of life?",
|
||||
filters={"years": [2019, 2020], "quarters": [1, 2, 3], "date": "2019-03-01"}
|
||||
)
|
||||
```
|
||||
|
||||
**New:**
|
||||
```python
|
||||
retriever = BM25Retriever(
|
||||
custom_query="""
|
||||
{
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{"multi_match": {
|
||||
"query": ${query},
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}
|
||||
],
|
||||
"filter": ${filters}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
retriever.retrieve(
|
||||
query="What is the meaning of life?",
|
||||
filters={"year": [2019, 2020], "quarter": [1, 2, 3], "date": {"$gte": "2019-03-01"}}
|
||||
)
|
||||
```
|
||||
features:
|
||||
- |
|
||||
When using ``custom_query`` in ``BM25Retriever`` along with ``OpenSearch``
|
||||
or ``Elasticsearch``, we added support for dynamic ``filters``, like in regular queries.
|
||||
With this change, you can pass filters at query time without having to modify the ``custom_query``.
|
||||
Instead of defining filter expressions and field placeholders, all you have to do
|
||||
is to set the ``${filters}`` placeholder, analogous to the ``${query}`` placeholder, in
|
||||
your ``custom_query``.
|
||||
**For example:**
|
||||
```python
|
||||
{
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{"multi_match": {
|
||||
"query": ${query}, // mandatory query placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["content", "title"]}}
|
||||
],
|
||||
"filter": ${filters} // optional filters placeholder
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
@ -1,8 +1,10 @@
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from haystack.document_stores.search_engine import SearchEngineDocumentStore
|
||||
from haystack.schema import FilterType
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@ -205,6 +207,53 @@ class SearchEngineDocumentStoreTestAbstract:
|
||||
labels = ds.get_all_labels()
|
||||
assert labels[0].meta["version"] == "2023.1"
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
"query,filters,result_count",
|
||||
[
|
||||
# test happy path
|
||||
("tost", {"year": ["2020", "2021", "1990"]}, 4),
|
||||
# test empty filters
|
||||
("test", None, 5),
|
||||
# test linefeeds in query
|
||||
("test\n", {"year": "2021"}, 3),
|
||||
# test double quote in query
|
||||
('test"', {"year": "2021"}, 3),
|
||||
# test non-matching query
|
||||
("toast", None, 0),
|
||||
],
|
||||
)
|
||||
def test_custom_query(
|
||||
self, query: str, filters: Optional[FilterType], result_count: int, ds: SearchEngineDocumentStore
|
||||
):
|
||||
documents = [
|
||||
{"id": "1", "content": "test", "meta": {"year": "2019"}},
|
||||
{"id": "2", "content": "test", "meta": {"year": "2020"}},
|
||||
{"id": "3", "content": "test", "meta": {"year": "2021"}},
|
||||
{"id": "4", "content": "test", "meta": {"year": "2021"}},
|
||||
{"id": "5", "content": "test", "meta": {"year": "2021"}},
|
||||
]
|
||||
ds.write_documents(documents)
|
||||
custom_query = """
|
||||
{
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{
|
||||
"multi_match": {
|
||||
"query": ${query},
|
||||
"fields": ["content"],
|
||||
"fuzziness": "AUTO"
|
||||
}
|
||||
}],
|
||||
"filter": ${filters}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
results = ds.query(query=query, filters=filters, custom_query=custom_query)
|
||||
assert len(results) == result_count
|
||||
|
||||
|
||||
@pytest.mark.document_store
|
||||
class TestSearchEngineDocumentStore:
|
||||
|
@ -21,7 +21,7 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
_optional_component_not_installed(__name__, "elasticsearch", ie)
|
||||
|
||||
|
||||
from haystack.document_stores.base import BaseDocumentStore, FilterType
|
||||
from haystack.document_stores.base import BaseDocumentStore, FilterType, KeywordDocumentStore
|
||||
from haystack.document_stores.memory import InMemoryDocumentStore
|
||||
from haystack.document_stores import WeaviateDocumentStore
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
@ -268,59 +268,25 @@ def test_embed_meta_fields_list_with_one_item():
|
||||
assert docs_with_embedded_meta[0].content == "one_item\nMy name is Matteo and I live in Rome"
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
def test_elasticsearch_custom_query():
|
||||
client = Elasticsearch()
|
||||
client.indices.delete(index="haystack_test_custom", ignore=[404])
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
index="haystack_test_custom", content_field="custom_text_field", embedding_field="custom_embedding_field"
|
||||
)
|
||||
documents = [
|
||||
{"content": "test_1", "meta": {"year": "2019"}},
|
||||
{"content": "test_2", "meta": {"year": "2020"}},
|
||||
{"content": "test_3", "meta": {"year": "2021"}},
|
||||
{"content": "test_4", "meta": {"year": "2021"}},
|
||||
{"content": "test_5", "meta": {"year": "2021"}},
|
||||
]
|
||||
document_store.write_documents(documents)
|
||||
@pytest.mark.unit
|
||||
def test_custom_query():
|
||||
mock_document_store = Mock(spec=KeywordDocumentStore)
|
||||
mock_document_store.index = "test"
|
||||
|
||||
# test custom "terms" query
|
||||
retriever = BM25Retriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
custom_query = """
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
|
||||
"filter": [{"terms": {"year": ${years}}}]}}}""",
|
||||
)
|
||||
results = retriever.retrieve(query="test", filters={"years": ["2020", "2021"]})
|
||||
assert len(results) == 4
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["custom_text_field"]}}],
|
||||
"filter": ${filters}}}}"""
|
||||
|
||||
# test linefeeds in query
|
||||
results = retriever.retrieve(query="test\n", filters={"years": ["2020", "2021"]})
|
||||
assert len(results) == 3
|
||||
|
||||
# test double quote in query
|
||||
results = retriever.retrieve(query='test"', filters={"years": ["2020", "2021"]})
|
||||
assert len(results) == 3
|
||||
|
||||
# test custom "term" query
|
||||
retriever = BM25Retriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
|
||||
"filter": [{"term": {"year": ${years}}}]}}}""",
|
||||
)
|
||||
results = retriever.retrieve(query="test", filters={"years": "2021"})
|
||||
assert len(results) == 3
|
||||
retriever = BM25Retriever(document_store=mock_document_store, custom_query=custom_query)
|
||||
retriever.retrieve(query="test", filters={"year": ["2020", "2021"]})
|
||||
assert mock_document_store.query.call_args.kwargs["custom_query"] == custom_query
|
||||
assert mock_document_store.query.call_args.kwargs["filters"] == {"year": ["2020", "2021"]}
|
||||
assert mock_document_store.query.call_args.kwargs["query"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
Loading…
x
Reference in New Issue
Block a user