diff --git a/rest_api/controller/request.py b/rest_api/controller/request.py index d4804842a..94aa5f87f 100644 --- a/rest_api/controller/request.py +++ b/rest_api/controller/request.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Collection, Dict, List, Optional +from typing import Any, Collection, Dict, List, Optional, Union from pydantic import BaseModel @@ -10,7 +10,7 @@ MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1 class Question(BaseModel): questions: List[str] - filters: Optional[Dict[str, Optional[str]]] = None + filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None top_k_reader: int = DEFAULT_TOP_K_READER top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py index d264620fd..d060dd68b 100644 --- a/rest_api/controller/search.py +++ b/rest_api/controller/search.py @@ -117,7 +117,7 @@ def doc_qa(model_id: int, question_request: Question): finder = FINDERS.get(model_id, None) if not finder: raise HTTPException( - status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" + status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" ) results = search_documents(finder, question_request, start_time) @@ -130,14 +130,20 @@ def faq_qa(model_id: int, request: Question): finder = FINDERS.get(model_id, None) if not finder: raise HTTPException( - status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" + status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" ) results = [] for question in request.questions: if request.filters: # put filter values into a list and remove filters with null value - filters = {key: [value] for key, value in request.filters.items() if value is not None} + filters = {} + for key, values in request.filters.items(): + if values is None: + continue + if not isinstance(values, list): + values = [values] + filters[key] = values logger.info(f" [{datetime.now()}] Request: {request}") else: filters = {} @@ -160,7 +166,7 @@ def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFA finder = FINDERS.get(model_id, None) if not finder: raise HTTPException( - status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" + status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" ) question_request = Question.from_elastic_query_dsl(query_request, top_k_reader) @@ -178,7 +184,13 @@ def search_documents(finder, question_request, start_time) -> List[AnswersToIndi for question in question_request.questions: if question_request.filters: # put filter values into a list and remove filters with null value - filters = {key: [value] for key, value in question_request.filters.items() if value is not None} + filters = {} + for key, values in question_request.filters.items(): + if values is None: + continue + if not isinstance(values, list): + values = [values] + filters[key] = values logger.info(f" [{datetime.now()}] Request: {question_request}") else: filters = {} diff --git a/test/test_rest_api.py b/test/test_rest_api.py index f4700bfbd..35f4433b0 100644 --- a/test/test_rest_api.py +++ b/test/test_rest_api.py @@ -22,7 +22,45 @@ def get_test_client_and_override_dependencies(reader, document_store_with_docs): @pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True) -def test_query_api(reader, document_store_with_docs): +def test_qa_api_filters(reader, document_store_with_docs): + client = get_test_client_and_override_dependencies(reader, document_store_with_docs) + + query_with_no_filter_value = {"questions": ["Where does Carla lives?"]} + response = client.post(url="/models/1/doc-qa", json=query_with_no_filter_value) + assert 200 == response.status_code + response_json = response.json() + assert response_json["results"][0]["answers"][0]["answer"] == "Berlin" + + query_with_single_filter_value = {"questions": ["Where does Carla lives?"], "filters": {"name": "filename1"}} + response = client.post(url="/models/1/doc-qa", json=query_with_single_filter_value) + assert 200 == response.status_code + response_json = response.json() + assert response_json["results"][0]["answers"][0]["answer"] == "Berlin" + + query_with_a_list_of_filter_values = { + "questions": ["Where does Carla lives?"], + "filters": {"name": ["filename1", "filename2"]}, + } + response = client.post(url="/models/1/doc-qa", json=query_with_a_list_of_filter_values) + assert 200 == response.status_code + response_json = response.json() + assert response_json["results"][0]["answers"][0]["answer"] == "Berlin" + + query_with_non_existing_filter_value = { + "questions": ["Where does Carla lives?"], + "filters": {"name": ["invalid-name"]}, + } + response = client.post(url="/models/1/doc-qa", json=query_with_non_existing_filter_value) + assert 200 == response.status_code + response_json = response.json() + assert len(response_json["results"][0]["answers"]) == 0 + + +@pytest.mark.slow +@pytest.mark.elasticsearch +@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("reader", ["farm"], indirect=True) +def test_query_api_filters(reader, document_store_with_docs): client = get_test_client_and_override_dependencies(reader, document_store_with_docs) query = {