mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-31 17:59:27 +00:00
Allow list of filter values in REST API (#568)
This commit is contained in:
parent
2b352d6ac4
commit
acd088808b
@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from typing import Any, Collection, Dict, List, Optional
|
||||
from typing import Any, Collection, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -10,7 +10,7 @@ MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
|
||||
|
||||
class Question(BaseModel):
|
||||
questions: List[str]
|
||||
filters: Optional[Dict[str, Optional[str]]] = None
|
||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||
top_k_reader: int = DEFAULT_TOP_K_READER
|
||||
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ def doc_qa(model_id: int, question_request: Question):
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
results = search_documents(finder, question_request, start_time)
|
||||
@ -130,14 +130,20 @@ def faq_qa(model_id: int, request: Question):
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
results = []
|
||||
for question in request.questions:
|
||||
if request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
||||
filters = {}
|
||||
for key, values in request.filters.items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
logger.info(f" [{datetime.now()}] Request: {request}")
|
||||
else:
|
||||
filters = {}
|
||||
@ -160,7 +166,7 @@ def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFA
|
||||
finder = FINDERS.get(model_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||
)
|
||||
|
||||
question_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
|
||||
@ -178,7 +184,13 @@ def search_documents(finder, question_request, start_time) -> List[AnswersToIndi
|
||||
for question in question_request.questions:
|
||||
if question_request.filters:
|
||||
# put filter values into a list and remove filters with null value
|
||||
filters = {key: [value] for key, value in question_request.filters.items() if value is not None}
|
||||
filters = {}
|
||||
for key, values in question_request.filters.items():
|
||||
if values is None:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
filters[key] = values
|
||||
logger.info(f" [{datetime.now()}] Request: {question_request}")
|
||||
else:
|
||||
filters = {}
|
||||
|
||||
@ -22,7 +22,45 @@ def get_test_client_and_override_dependencies(reader, document_store_with_docs):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_query_api(reader, document_store_with_docs):
|
||||
def test_qa_api_filters(reader, document_store_with_docs):
|
||||
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||
|
||||
query_with_no_filter_value = {"questions": ["Where does Carla lives?"]}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_no_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_single_filter_value = {"questions": ["Where does Carla lives?"], "filters": {"name": "filename1"}}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_single_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_a_list_of_filter_values = {
|
||||
"questions": ["Where does Carla lives?"],
|
||||
"filters": {"name": ["filename1", "filename2"]},
|
||||
}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_a_list_of_filter_values)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||
|
||||
query_with_non_existing_filter_value = {
|
||||
"questions": ["Where does Carla lives?"],
|
||||
"filters": {"name": ["invalid-name"]},
|
||||
}
|
||||
response = client.post(url="/models/1/doc-qa", json=query_with_non_existing_filter_value)
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json["results"][0]["answers"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_query_api_filters(reader, document_store_with_docs):
|
||||
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||
|
||||
query = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user