mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-07 13:24:16 +00:00
Allow list of filter values in REST API (#568)
This commit is contained in:
parent
2b352d6ac4
commit
acd088808b
@ -1,5 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import Any, Collection, Dict, List, Optional
|
from typing import Any, Collection, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
|
|||||||
|
|
||||||
class Question(BaseModel):
|
class Question(BaseModel):
|
||||||
questions: List[str]
|
questions: List[str]
|
||||||
filters: Optional[Dict[str, Optional[str]]] = None
|
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||||
top_k_reader: int = DEFAULT_TOP_K_READER
|
top_k_reader: int = DEFAULT_TOP_K_READER
|
||||||
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
|
top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER
|
||||||
|
|
||||||
|
|||||||
@ -117,7 +117,7 @@ def doc_qa(model_id: int, question_request: Question):
|
|||||||
finder = FINDERS.get(model_id, None)
|
finder = FINDERS.get(model_id, None)
|
||||||
if not finder:
|
if not finder:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
results = search_documents(finder, question_request, start_time)
|
results = search_documents(finder, question_request, start_time)
|
||||||
@ -130,14 +130,20 @@ def faq_qa(model_id: int, request: Question):
|
|||||||
finder = FINDERS.get(model_id, None)
|
finder = FINDERS.get(model_id, None)
|
||||||
if not finder:
|
if not finder:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for question in request.questions:
|
for question in request.questions:
|
||||||
if request.filters:
|
if request.filters:
|
||||||
# put filter values into a list and remove filters with null value
|
# put filter values into a list and remove filters with null value
|
||||||
filters = {key: [value] for key, value in request.filters.items() if value is not None}
|
filters = {}
|
||||||
|
for key, values in request.filters.items():
|
||||||
|
if values is None:
|
||||||
|
continue
|
||||||
|
if not isinstance(values, list):
|
||||||
|
values = [values]
|
||||||
|
filters[key] = values
|
||||||
logger.info(f" [{datetime.now()}] Request: {request}")
|
logger.info(f" [{datetime.now()}] Request: {request}")
|
||||||
else:
|
else:
|
||||||
filters = {}
|
filters = {}
|
||||||
@ -160,7 +166,7 @@ def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFA
|
|||||||
finder = FINDERS.get(model_id, None)
|
finder = FINDERS.get(model_id, None)
|
||||||
if not finder:
|
if not finder:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
status_code=404, detail=f"Could not get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
question_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
|
question_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
|
||||||
@ -178,7 +184,13 @@ def search_documents(finder, question_request, start_time) -> List[AnswersToIndi
|
|||||||
for question in question_request.questions:
|
for question in question_request.questions:
|
||||||
if question_request.filters:
|
if question_request.filters:
|
||||||
# put filter values into a list and remove filters with null value
|
# put filter values into a list and remove filters with null value
|
||||||
filters = {key: [value] for key, value in question_request.filters.items() if value is not None}
|
filters = {}
|
||||||
|
for key, values in question_request.filters.items():
|
||||||
|
if values is None:
|
||||||
|
continue
|
||||||
|
if not isinstance(values, list):
|
||||||
|
values = [values]
|
||||||
|
filters[key] = values
|
||||||
logger.info(f" [{datetime.now()}] Request: {question_request}")
|
logger.info(f" [{datetime.now()}] Request: {question_request}")
|
||||||
else:
|
else:
|
||||||
filters = {}
|
filters = {}
|
||||||
|
|||||||
@ -22,7 +22,45 @@ def get_test_client_and_override_dependencies(reader, document_store_with_docs):
|
|||||||
@pytest.mark.elasticsearch
|
@pytest.mark.elasticsearch
|
||||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||||
def test_query_api(reader, document_store_with_docs):
|
def test_qa_api_filters(reader, document_store_with_docs):
|
||||||
|
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||||
|
|
||||||
|
query_with_no_filter_value = {"questions": ["Where does Carla lives?"]}
|
||||||
|
response = client.post(url="/models/1/doc-qa", json=query_with_no_filter_value)
|
||||||
|
assert 200 == response.status_code
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||||
|
|
||||||
|
query_with_single_filter_value = {"questions": ["Where does Carla lives?"], "filters": {"name": "filename1"}}
|
||||||
|
response = client.post(url="/models/1/doc-qa", json=query_with_single_filter_value)
|
||||||
|
assert 200 == response.status_code
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||||
|
|
||||||
|
query_with_a_list_of_filter_values = {
|
||||||
|
"questions": ["Where does Carla lives?"],
|
||||||
|
"filters": {"name": ["filename1", "filename2"]},
|
||||||
|
}
|
||||||
|
response = client.post(url="/models/1/doc-qa", json=query_with_a_list_of_filter_values)
|
||||||
|
assert 200 == response.status_code
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["results"][0]["answers"][0]["answer"] == "Berlin"
|
||||||
|
|
||||||
|
query_with_non_existing_filter_value = {
|
||||||
|
"questions": ["Where does Carla lives?"],
|
||||||
|
"filters": {"name": ["invalid-name"]},
|
||||||
|
}
|
||||||
|
response = client.post(url="/models/1/doc-qa", json=query_with_non_existing_filter_value)
|
||||||
|
assert 200 == response.status_code
|
||||||
|
response_json = response.json()
|
||||||
|
assert len(response_json["results"][0]["answers"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.elasticsearch
|
||||||
|
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||||
|
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||||
|
def test_query_api_filters(reader, document_store_with_docs):
|
||||||
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
client = get_test_client_and_override_dependencies(reader, document_store_with_docs)
|
||||||
|
|
||||||
query = {
|
query = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user