mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-18 03:18:42 +00:00
Fix usage of filters in /query
endpoint in REST API (#1774)
* WIP filter refactoring * fix filter formatting * remove inplace modification of filters
This commit is contained in:
parent
31e22012da
commit
c0892717a0
@ -737,7 +737,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
|||||||
filter_clause = []
|
filter_clause = []
|
||||||
for key, values in filters.items():
|
for key, values in filters.items():
|
||||||
if type(values) != list:
|
if type(values) != list:
|
||||||
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
|
raise ValueError(f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. '
|
||||||
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
|
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
|
||||||
filter_clause.append(
|
filter_clause.append(
|
||||||
{
|
{
|
||||||
|
@ -50,7 +50,7 @@ use_route_names_as_operation_ids(app)
|
|||||||
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
|
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
|
||||||
logger.info(
|
logger.info(
|
||||||
"""
|
"""
|
||||||
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Did Albus Dumbledore die?"}'
|
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Who is the father of Arya Stark?"}'
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,19 +53,38 @@ def _process_request(pipeline, request) -> QueryResponse:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
params = request.params or {}
|
params = request.params or {}
|
||||||
params["Retriever"] = params.get("Retriever", {})
|
|
||||||
filters = {}
|
# format global, top-level filters (e.g. "params": {"filters": {"name": ["some"]}})
|
||||||
if "filters" in params["Retriever"]: # put filter values into a list and remove filters with null value
|
if "filters" in params.keys():
|
||||||
for key, values in params["Retriever"]["filters"].items():
|
params["filters"] = _format_filters(params["filters"])
|
||||||
if values is None:
|
|
||||||
continue
|
# format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}})
|
||||||
if not isinstance(values, list):
|
for key, value in params.items():
|
||||||
values = [values]
|
if "filters" in params[key].keys():
|
||||||
filters[key] = values
|
params[key]["filters"] = _format_filters(params[key]["filters"])
|
||||||
params["Retriever"]["filters"] = filters
|
|
||||||
result = pipeline.run(query=request.query, params=params)
|
result = pipeline.run(query=request.query, params=params,debug=request.debug)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"})
|
logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"})
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _format_filters(filters):
|
||||||
|
"""
|
||||||
|
Adjust filters to compliant format:
|
||||||
|
Put filter values into a list and remove filters with null value.
|
||||||
|
"""
|
||||||
|
new_filters = {}
|
||||||
|
for key, values in filters.items():
|
||||||
|
if values is None:
|
||||||
|
logger.warning(f"Request with deprecated filter format ('{key}: null'). "
|
||||||
|
f"Remove null values from filters to be compliant with future versions")
|
||||||
|
continue
|
||||||
|
elif not isinstance(values, list):
|
||||||
|
logger.warning(f"Request with deprecated filter format ('{key}': {values}). "
|
||||||
|
f"Change to '{key}':[{values}]' to be compliant with future versions")
|
||||||
|
values = [values]
|
||||||
|
|
||||||
|
new_filters[key] = values
|
||||||
|
return new_filters
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from typing import Dict, List, Optional, Union, Any
|
from typing import Dict, List, Optional, Union, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, Extra
|
||||||
from haystack.schema import Answer, Document, Label, Span
|
from haystack.schema import Answer, Document, Label, Span
|
||||||
from pydantic import BaseConfig
|
from pydantic import BaseConfig
|
||||||
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
||||||
@ -11,11 +11,13 @@ except ImportError:
|
|||||||
|
|
||||||
BaseConfig.arbitrary_types_allowed = True
|
BaseConfig.arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
params: Optional[dict] = None
|
params: Optional[dict] = None
|
||||||
|
debug: Optional[bool] = False
|
||||||
|
class Config:
|
||||||
|
# Forbid any extra fields in the request to avoid silent failures
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
class FilterRequest(BaseModel):
|
class FilterRequest(BaseModel):
|
||||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||||
@ -41,4 +43,5 @@ class QueryResponse(BaseModel):
|
|||||||
query: str
|
query: str
|
||||||
answers: List[AnswerSerialized]
|
answers: List[AnswerSerialized]
|
||||||
documents: Optional[List[DocumentSerialized]]
|
documents: Optional[List[DocumentSerialized]]
|
||||||
|
debug: Optional[Dict] = Field(None, alias="_debug")
|
||||||
|
|
||||||
|
@ -157,6 +157,14 @@ def test_query_with_one_filter(populated_client: TestClient):
|
|||||||
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_with_one_global_filter(populated_client: TestClient):
|
||||||
|
query_with_filter = {"query": "Who made the PDF specification?", "params": {"filters": {"meta_key": "meta_value"}}}
|
||||||
|
response = populated_client.post(url="/query", json=query_with_filter)
|
||||||
|
assert 200 == response.status_code
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||||
|
|
||||||
|
|
||||||
def test_query_with_filter_list(populated_client: TestClient):
|
def test_query_with_filter_list(populated_client: TestClient):
|
||||||
query_with_filter_list = {
|
query_with_filter_list = {
|
||||||
"query": "Who made the PDF specification?",
|
"query": "Who made the PDF specification?",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user