Fix usage of filters in /query endpoint in REST API (#1774)

* WIP filter refactoring

* fix filter formatting

* remove inplace modification of filters
This commit is contained in:
Malte Pietsch 2021-11-18 18:13:03 +01:00 committed by GitHub
parent 31e22012da
commit c0892717a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 17 deletions

View File

@ -737,7 +737,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
filter_clause = [] filter_clause = []
for key, values in filters.items(): for key, values in filters.items():
if type(values) != list: if type(values) != list:
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. ' raise ValueError(f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. '
'Example: {"name": ["some", "more"], "category": ["only_one"]} ') 'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
filter_clause.append( filter_clause.append(
{ {

View File

@ -50,7 +50,7 @@ use_route_names_as_operation_ids(app)
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.") logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info( logger.info(
""" """
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Did Albus Dumbledore die?"}' Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Who is the father of Arya Stark?"}'
""" """
) )

View File

@ -53,19 +53,38 @@ def _process_request(pipeline, request) -> QueryResponse:
start_time = time.time() start_time = time.time()
params = request.params or {} params = request.params or {}
params["Retriever"] = params.get("Retriever", {})
filters = {} # format global, top-level filters (e.g. "params": {"filters": {"name": ["some"]}})
if "filters" in params["Retriever"]: # put filter values into a list and remove filters with null value if "filters" in params.keys():
for key, values in params["Retriever"]["filters"].items(): params["filters"] = _format_filters(params["filters"])
if values is None:
continue # format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}})
if not isinstance(values, list): for key, value in params.items():
values = [values] if "filters" in params[key].keys():
filters[key] = values params[key]["filters"] = _format_filters(params[key]["filters"])
params["Retriever"]["filters"] = filters
result = pipeline.run(query=request.query, params=params) result = pipeline.run(query=request.query, params=params,debug=request.debug)
end_time = time.time() end_time = time.time()
logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"}) logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"})
return result return result
def _format_filters(filters):
    """
    Adjust filters to compliant format:
    Put filter values into a list and remove filters with null value.

    :param filters: Dict mapping a filter key to a list of allowed values. Two deprecated
                    shapes are still accepted: a single non-list value (wrapped into a
                    one-element list) and None (the key is dropped).
    :return: A newly built dict with the normalized filters. The dict passed in is not modified.
    """
    new_filters = {}
    for key, values in filters.items():
        if values is None:
            # Null filters carry no constraint: warn and leave the key out of the result.
            logger.warning(f"Request with deprecated filter format ('{key}': null). "
                           f"Remove null values from filters to be compliant with future versions")
            continue
        if not isinstance(values, list):
            # Single values are tolerated for now, but callers should send a list.
            logger.warning(f"Request with deprecated filter format ('{key}': {values}). "
                           f"Change to '{key}': [{values}] to be compliant with future versions")
            values = [values]
        new_filters[key] = values
    return new_filters

View File

@ -1,5 +1,5 @@
from typing import Dict, List, Optional, Union, Any from typing import Dict, List, Optional, Union, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, Extra
from haystack.schema import Answer, Document, Label, Span from haystack.schema import Answer, Document, Label, Span
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.dataclasses import dataclass as pydantic_dataclass from pydantic.dataclasses import dataclass as pydantic_dataclass
@ -11,11 +11,13 @@ except ImportError:
BaseConfig.arbitrary_types_allowed = True BaseConfig.arbitrary_types_allowed = True
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
params: Optional[dict] = None params: Optional[dict] = None
debug: Optional[bool] = False
class Config:
# Forbid any extra fields in the request to avoid silent failures
extra = Extra.forbid
class FilterRequest(BaseModel): class FilterRequest(BaseModel):
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
@ -41,4 +43,5 @@ class QueryResponse(BaseModel):
query: str query: str
answers: List[AnswerSerialized] answers: List[AnswerSerialized]
documents: Optional[List[DocumentSerialized]] documents: Optional[List[DocumentSerialized]]
debug: Optional[Dict] = Field(None, alias="_debug")

View File

@ -157,6 +157,14 @@ def test_query_with_one_filter(populated_client: TestClient):
assert response_json["answers"][0]["answer"] == "Adobe Systems" assert response_json["answers"][0]["answer"] == "Adobe Systems"
def test_query_with_one_global_filter(populated_client: TestClient):
    # The filter is passed at the top level of "params" (not under a node name)
    # and its value is a plain string rather than a list.
    payload = {
        "query": "Who made the PDF specification?",
        "params": {"filters": {"meta_key": "meta_value"}},
    }
    response = populated_client.post(url="/query", json=payload)
    assert 200 == response.status_code
    top_answer = response.json()["answers"][0]
    assert top_answer["answer"] == "Adobe Systems"
def test_query_with_filter_list(populated_client: TestClient): def test_query_with_filter_list(populated_client: TestClient):
query_with_filter_list = { query_with_filter_list = {
"query": "Who made the PDF specification?", "query": "Who made the PDF specification?",