Allow list of filter values in REST API (#568)

This commit is contained in:
Tanay Soni 2020-11-09 20:41:53 +01:00 committed by GitHub
parent 2b352d6ac4
commit acd088808b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 8 deletions

View File

@ -1,5 +1,5 @@
import sys
from typing import Any, Collection, Dict, List, Optional
from typing import Any, Collection, Dict, List, Optional, Union
from pydantic import BaseModel
@ -10,7 +10,7 @@ MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
class Question(BaseModel):
questions: List[str]
filters: Optional[Dict[str, Optional[str]]] = None
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
top_k_reader: int = DEFAULT_TOP_K_READER
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER

View File

@ -117,7 +117,7 @@ def doc_qa(model_id: int, question_request: Question):
finder = FINDERS.get(model_id, None)
if not finder:
raise HTTPException(
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
)
results = search_documents(finder, question_request, start_time)
@ -130,14 +130,20 @@ def faq_qa(model_id: int, request: Question):
finder = FINDERS.get(model_id, None)
if not finder:
raise HTTPException(
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
)
results = []
for question in request.questions:
if request.filters:
# put filter values into a list and remove filters with null value
filters = {key: [value] for key, value in request.filters.items() if value is not None}
filters = {}
for key, values in request.filters.items():
if values is None:
continue
if not isinstance(values, list):
values = [values]
filters[key] = values
logger.info(f" [{datetime.now()}] Request: {request}")
else:
filters = {}
@ -160,7 +166,7 @@ def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFA
finder = FINDERS.get(model_id, None)
if not finder:
raise HTTPException(
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
)
question_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
@ -178,7 +184,13 @@ def search_documents(finder, question_request, start_time) -> List[AnswersToIndi
for question in question_request.questions:
if question_request.filters:
# put filter values into a list and remove filters with null value
filters = {key: [value] for key, value in question_request.filters.items() if value is not None}
filters = {}
for key, values in question_request.filters.items():
if values is None:
continue
if not isinstance(values, list):
values = [values]
filters[key] = values
logger.info(f" [{datetime.now()}] Request: {question_request}")
else:
filters = {}

View File

@ -22,7 +22,45 @@ def get_test_client_and_override_dependencies(reader, document_store_with_docs):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_query_api(reader, document_store_with_docs):
def test_qa_api_filters(reader, document_store_with_docs):
# Posts the same question to the /models/1/doc-qa endpoint with four different
# "filters" shapes (absent, single string, list of strings, list with a value
# that is expected to match nothing) and checks the status code and answers.
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
# Case 1: request without a "filters" key; the top answer must be "Berlin".
query_with_no_filter_value = {"questions": ["Where does Carla lives?"]}
response = client.post(url="/models/1/doc-qa", json=query_with_no_filter_value)
assert 200 == response.status_code
response_json = response.json()
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
# Case 2: filter value given as a plain string; same top answer expected.
query_with_single_filter_value = {"questions": ["Where does Carla lives?"], "filters": {"name": "filename1"}}
response = client.post(url="/models/1/doc-qa", json=query_with_single_filter_value)
assert 200 == response.status_code
response_json = response.json()
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
# Case 3: filter value given as a list of strings; same top answer expected.
query_with_a_list_of_filter_values = {
"questions": ["Where does Carla lives?"],
"filters": {"name": ["filename1", "filename2"]},
}
response = client.post(url="/models/1/doc-qa", json=query_with_a_list_of_filter_values)
assert 200 == response.status_code
response_json = response.json()
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
# Case 4: filter value that presumably matches no document in the fixture
# (verify against document_store_with_docs); the request must still succeed
# with HTTP 200 but return an empty answer list.
query_with_non_existing_filter_value = {
"questions": ["Where does Carla lives?"],
"filters": {"name": ["invalid-name"]},
}
response = client.post(url="/models/1/doc-qa", json=query_with_non_existing_filter_value)
assert 200 == response.status_code
response_json = response.json()
assert len(response_json["results"][0]["answers"]) == 0
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_query_api_filters(reader, document_store_with_docs):
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
query = {