From c0892717a061eef0d9a0a185b0ec1b53f0b467d3 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Thu, 18 Nov 2021 18:13:03 +0100 Subject: [PATCH] Fix usage of filters in `/query` endpoint in REST API (#1774) * WIP filter refactoring * fix filter formatting * remove inplace modification of filters --- haystack/document_stores/elasticsearch.py | 2 +- rest_api/application.py | 2 +- rest_api/controller/search.py | 43 ++++++++++++++++------- rest_api/schema.py | 9 +++-- test/test_rest_api.py | 8 +++++ 5 files changed, 47 insertions(+), 17 deletions(-) diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index f86892bc8..40c30c3cf 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -737,7 +737,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore): filter_clause = [] for key, values in filters.items(): if type(values) != list: - raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' + raise ValueError(f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. ' 'Example: {"name": ["some", "more"], "category": ["only_one"]} ') filter_clause.append( { diff --git a/rest_api/application.py b/rest_api/application.py index f3ac38e19..37656c257 100644 --- a/rest_api/application.py +++ b/rest_api/application.py @@ -50,7 +50,7 @@ use_route_names_as_operation_ids(app) logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.") logger.info( """ - Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Did Albus Dumbledore die?"}' + Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Who is the father of Arya Stark?"}' """ ) diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py index 082479526..b9010bed3 100644 --- a/rest_api/controller/search.py +++ b/rest_api/controller/search.py @@ -53,19 +53,38 @@ def _process_request(pipeline, request) -> QueryResponse: start_time = time.time() params = request.params or {} - params["Retriever"] = params.get("Retriever", {}) - filters = {} - if "filters" in params["Retriever"]: # put filter values into a list and remove filters with null value - for key, values in params["Retriever"]["filters"].items(): - if values is None: - continue - if not isinstance(values, list): - values = [values] - filters[key] = values - params["Retriever"]["filters"] = filters - result = pipeline.run(query=request.query, params=params) - + + # format global, top-level filters (e.g. "params": {"filters": {"name": ["some"]}}) + if "filters" in params.keys(): + params["filters"] = _format_filters(params["filters"]) + + # format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}}) + for key, value in params.items(): + if "filters" in params[key].keys(): + params[key]["filters"] = _format_filters(params[key]["filters"]) + + result = pipeline.run(query=request.query, params=params,debug=request.debug) end_time = time.time() logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"}) return result + + +def _format_filters(filters): + """ + Adjust filters to compliant format: + Put filter values into a list and remove filters with null value. + """ + new_filters = {} + for key, values in filters.items(): + if values is None: + logger.warning(f"Request with deprecated filter format ('{key}: null'). " + f"Remove null values from filters to be compliant with future versions") + continue + elif not isinstance(values, list): + logger.warning(f"Request with deprecated filter format ('{key}': {values}). " + f"Change to '{key}':[{values}]' to be compliant with future versions") + values = [values] + + new_filters[key] = values + return new_filters diff --git a/rest_api/schema.py b/rest_api/schema.py index a2b0c3505..a77ec15ab 100644 --- a/rest_api/schema.py +++ b/rest_api/schema.py @@ -1,5 +1,5 @@ from typing import Dict, List, Optional, Union, Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, Extra from haystack.schema import Answer, Document, Label, Span from pydantic import BaseConfig from pydantic.dataclasses import dataclass as pydantic_dataclass @@ -11,11 +11,13 @@ except ImportError: BaseConfig.arbitrary_types_allowed = True - class QueryRequest(BaseModel): query: str params: Optional[dict] = None - + debug: Optional[bool] = False + class Config: + # Forbid any extra fields in the request to avoid silent failures + extra = Extra.forbid class FilterRequest(BaseModel): filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None @@ -41,4 +43,5 @@ class QueryResponse(BaseModel): query: str answers: List[AnswerSerialized] documents: Optional[List[DocumentSerialized]] + debug: Optional[Dict] = Field(None, alias="_debug") diff --git a/test/test_rest_api.py b/test/test_rest_api.py index 8e956bfff..2efb25cd9 100644 --- a/test/test_rest_api.py +++ b/test/test_rest_api.py @@ -157,6 +157,14 @@ def test_query_with_one_filter(populated_client: TestClient): assert response_json["answers"][0]["answer"] == "Adobe Systems" +def test_query_with_one_global_filter(populated_client: TestClient): + query_with_filter = {"query": "Who made the PDF specification?", "params": {"filters": {"meta_key": "meta_value"}}} + response = populated_client.post(url="/query", json=query_with_filter) + assert 200 == response.status_code + response_json = response.json() + assert response_json["answers"][0]["answer"] == "Adobe Systems" + + def test_query_with_filter_list(populated_client: TestClient): query_with_filter_list = { "query": "Who made the PDF specification?",