mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-17 02:48:30 +00:00
Fix usage of filters in /query
endpoint in REST API (#1774)
* WIP filter refactoring * fix filter formatting * remove inplace modification of filters
This commit is contained in:
parent
31e22012da
commit
c0892717a0
@ -737,7 +737,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
filter_clause = []
|
||||
for key, values in filters.items():
|
||||
if type(values) != list:
|
||||
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
|
||||
raise ValueError(f'Wrong filter format: "{key}": {values}. Provide a list of values for each key. '
|
||||
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
|
||||
filter_clause.append(
|
||||
{
|
||||
|
@ -50,7 +50,7 @@ use_route_names_as_operation_ids(app)
|
||||
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
|
||||
logger.info(
|
||||
"""
|
||||
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Did Albus Dumbledore die?"}'
|
||||
Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/query' -H "Content-Type: application/json" --data '{"query": "Who is the father of Arya Stark?"}'
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -53,19 +53,38 @@ def _process_request(pipeline, request) -> QueryResponse:
|
||||
start_time = time.time()
|
||||
|
||||
params = request.params or {}
|
||||
params["Retriever"] = params.get("Retriever", {})
|
||||
filters = {}
|
||||
if "filters" in params["Retriever"]: # put filter values into a list and remove filters with null value
|
||||
for key, values in params["Retriever"]["filters"].items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
params["Retriever"]["filters"] = filters
|
||||
result = pipeline.run(query=request.query, params=params)
|
||||
|
||||
|
||||
# format global, top-level filters (e.g. "params": {"filters": {"name": ["some"]}})
|
||||
if "filters" in params.keys():
|
||||
params["filters"] = _format_filters(params["filters"])
|
||||
|
||||
# format targeted node filters (e.g. "params": {"Retriever": {"filters": {"value"}}})
|
||||
for key, value in params.items():
|
||||
if "filters" in params[key].keys():
|
||||
params[key]["filters"] = _format_filters(params[key]["filters"])
|
||||
|
||||
result = pipeline.run(query=request.query, params=params,debug=request.debug)
|
||||
end_time = time.time()
|
||||
logger.info({"request": request.dict(), "response": result, "time": f"{(end_time - start_time):.2f}"})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _format_filters(filters):
|
||||
"""
|
||||
Adjust filters to compliant format:
|
||||
Put filter values into a list and remove filters with null value.
|
||||
"""
|
||||
new_filters = {}
|
||||
for key, values in filters.items():
|
||||
if values is None:
|
||||
logger.warning(f"Request with deprecated filter format ('{key}: null'). "
|
||||
f"Remove null values from filters to be compliant with future versions")
|
||||
continue
|
||||
elif not isinstance(values, list):
|
||||
logger.warning(f"Request with deprecated filter format ('{key}': {values}). "
|
||||
f"Change to '{key}':[{values}]' to be compliant with future versions")
|
||||
values = [values]
|
||||
|
||||
new_filters[key] = values
|
||||
return new_filters
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
from haystack.schema import Answer, Document, Label, Span
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
||||
@ -11,11 +11,13 @@ except ImportError:
|
||||
|
||||
BaseConfig.arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    # Request model exposing the `query`, `params` and `debug` fields.

    # The query text; the only field without a default.
    query: str
    # Optional parameter dict; None when the client sends no params.
    params: Optional[dict] = None
    # Optional debug flag; False unless the client sets it.
    debug: Optional[bool] = False

    class Config:
        # Forbid any extra fields in the request to avoid silent failures
        extra = Extra.forbid
|
||||
|
||||
class FilterRequest(BaseModel):
|
||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||
@ -41,4 +43,5 @@ class QueryResponse(BaseModel):
|
||||
query: str
|
||||
answers: List[AnswerSerialized]
|
||||
documents: Optional[List[DocumentSerialized]]
|
||||
debug: Optional[Dict] = Field(None, alias="_debug")
|
||||
|
||||
|
@ -157,6 +157,14 @@ def test_query_with_one_filter(populated_client: TestClient):
|
||||
assert response_json["answers"][0]["answer"] == "Adobe Systems"
|
||||
|
||||
|
||||
def test_query_with_one_global_filter(populated_client: TestClient):
    """POST /query with a top-level filter given as a bare (non-list) value and check the first answer."""
    payload = {
        "query": "Who made the PDF specification?",
        "params": {"filters": {"meta_key": "meta_value"}},
    }
    reply = populated_client.post(url="/query", json=payload)
    assert 200 == reply.status_code
    top_answer = reply.json()["answers"][0]["answer"]
    assert top_answer == "Adobe Systems"
|
||||
|
||||
|
||||
def test_query_with_filter_list(populated_client: TestClient):
|
||||
query_with_filter_list = {
|
||||
"query": "Who made the PDF specification?",
|
||||
|
Loading…
x
Reference in New Issue
Block a user