Fix usage of filters in /query endpoint in REST API (#1774)

* WIP filter refactoring

* fix filter formatting

* remove inplace modification of filters
This commit is contained in:
Malte Pietsch 2021-11-18 18:13:03 +01:00 committed by GitHub
parent 31e22012da
commit c0892717a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 17 deletions

View File

@ -737,7 +737,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
filter_clause = []
for key, values in filters.items():
if type(values) != list:
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
raise ValueError(f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. '
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
filter_clause.append(
{

View File

@ -50,7 +50,7 @@ use_route_names_as_operation_ids(app)
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info(
"""
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Did Albus Dumbledore die?"}'
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Who is the father of Arya Stark?"}'
"""
)

View File

@ -53,19 +53,38 @@ def _process_request(pipeline, request) -> QueryResponse:
start_time = time.time()
params = request.params or {}
params["Retriever"] = params.get("Retriever", {})
filters = {}
if "filters" in params["Retriever"]: # put filter values into a list and remove filters with null value
for key, values in params["Retriever"]["filters"].items():
if values is None:
continue
if not isinstance(values, list):
values = [values]
filters[key] = values
params["Retriever"]["filters"] = filters
result = pipeline.run(query=request.query, params=params)
# format global, top-level filters (e.g. "params": {"filters": {"name": ["some"]}})
if "filters" in params.keys():
params["filters"] = _format_filters(params["filters"])
# format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}})
for key, value in params.items():
if "filters" in params[key].keys():
params[key]["filters"] = _format_filters(params[key]["filters"])
result = pipeline.run(query=request.query, params=params,debug=request.debug)
end_time = time.time()
logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"})
return result
def _format_filters(filters):
"""
Adjust filters to compliant format:
Put filter values into a list and remove filters with null value.
"""
new_filters = {}
for key, values in filters.items():
if values is None:
logger.warning(f"Request with deprecated filter format ('{key}: null'). "
f"Remove null values from filters to be compliant with future versions")
continue
elif not isinstance(values, list):
logger.warning(f"Request with deprecated filter format ('{key}': {values}). "
f"Change to '{key}':[{values}]' to be compliant with future versions")
values = [values]
new_filters[key] = values
return new_filters

View File

@ -1,5 +1,5 @@
from typing import Dict, List, Optional, Union, Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, Extra
from haystack.schema import Answer, Document, Label, Span
from pydantic import BaseConfig
from pydantic.dataclasses import dataclass as pydantic_dataclass
@ -11,11 +11,13 @@ except ImportError:
BaseConfig.arbitrary_types_allowed = True
class QueryRequest(BaseModel):
query: str
params: Optional[dict] = None
debug: Optional[bool] = False
class Config:
# Forbid any extra fields in the request to avoid silent failures
extra = Extra.forbid
class FilterRequest(BaseModel):
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
@ -41,4 +43,5 @@ class QueryResponse(BaseModel):
query: str
answers: List[AnswerSerialized]
documents: Optional[List[DocumentSerialized]]
debug: Optional[Dict] = Field(None, alias="_debug")

View File

@ -157,6 +157,14 @@ def test_query_with_one_filter(populated_client: TestClient):
assert response_json["answers"][0]["answer"] == "Adobe Systems"
def test_query_with_one_global_filter(populated_client: TestClient):
query_with_filter = {"query": "Who made the PDF specification?", "params": {"filters": {"meta_key": "meta_value"}}}
response = populated_client.post(url="/query", json=query_with_filter)
assert 200 == response.status_code
response_json = response.json()
assert response_json["answers"][0]["answer"] == "Adobe Systems"
def test_query_with_filter_list(populated_client: TestClient):
query_with_filter_list = {
"query": "Who made the PDF specification?",