diff --git a/rest_api/controller/request.py b/rest_api/controller/request.py new file mode 100644 index 000000000..9e32d5ffb --- /dev/null +++ b/rest_api/controller/request.py @@ -0,0 +1,73 @@ +import sys +from typing import Any, Collection, Dict, List, Optional + +from pydantic import BaseModel + +from rest_api.config import DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER + +MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1 + + +class Question(BaseModel): + questions: List[str] + filters: Optional[Dict[str, str]] = None + top_k_reader: int = DEFAULT_TOP_K_READER + top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER + + @classmethod + def from_elastic_query_dsl(cls, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER): + + # Refer Query DSL + # https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html + # Currently do not support query matching with field parameter + query_strings: List[str] = [] + filters: Dict[str, str] = {} + top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER if "size" not in query_request else query_request["size"] + + cls._iterate_dsl_request(query_request, query_strings, filters) + + if len(query_strings) != 1: + raise SyntaxError('Exactly one valid `query` field expected, ' + 'refer https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html') + + return cls(questions=query_strings, filters=filters if len(filters) else None, top_k_retriever=top_k_retriever, + top_k_reader=top_k_reader) + + @classmethod + def _iterate_dsl_request(cls, query_dsl: Any, query_strings: List[str], filters: Dict[str, str], depth: int = 0): + if depth == MAX_RECURSION_DEPTH: + raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit') + + # For question: Only consider values of "query" key for "match" and "multi_match" request. 
+ # For filter: Only consider Dict[str, str] value of "term" or "terms" key + if isinstance(query_dsl, List): + for item in query_dsl: + cls._iterate_dsl_request(item, query_strings, filters, depth + 1) + elif isinstance(query_dsl, Dict): + for key, value in query_dsl.items(): + # "query" value should be "str" type + if key == 'query' and isinstance(value, str): + query_strings.append(value) + elif key in ["filter", "filters"]: + cls._iterate_filters(value, filters, depth + 1) + elif isinstance(value, Collection): + cls._iterate_dsl_request(value, query_strings, filters, depth + 1) + + @classmethod + def _iterate_filters(cls, filter_dsl: Any, filters: Dict[str, str], depth: int = 0): + if depth == MAX_RECURSION_DEPTH: + raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit') + + if isinstance(filter_dsl, List): + for item in filter_dsl: + cls._iterate_filters(item, filters, depth + 1) + elif isinstance(filter_dsl, Dict): + for key, value in filter_dsl.items(): + if key in ["term", "terms"]: + if isinstance(value, Dict): + for filter_key, filter_value in value.items(): + # Currently only accepting Dict[str, str] + if isinstance(filter_value, str): + filters[filter_key] = filter_value + elif isinstance(value, Collection): + cls._iterate_filters(value, filters, depth + 1) diff --git a/rest_api/controller/response.py b/rest_api/controller/response.py new file mode 100644 index 000000000..e05ed9efd --- /dev/null +++ b/rest_api/controller/response.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + + +class Answer(BaseModel): + answer: Optional[str] + question: Optional[str] + score: Optional[float] = None + probability: Optional[float] = None + context: Optional[str] + offset_start: int + offset_end: int + offset_start_in_doc: Optional[int] + offset_end_in_doc: Optional[int] + document_id: Optional[str] = None + meta: Optional[Dict[str, str]] + + +class 
AnswersToIndividualQuestion(BaseModel): + question: str + answers: List[Optional[Answer]] + + @staticmethod + def to_elastic_response_dsl(data: Dict[str, Any]): + result_dsl = {'hits': {'hits': [], 'total': {'value': len(data["answers"])}}} + for answer in data["answers"]: + + record = {"_source": {k: v for k, v in dict(answer).items()}} + record["_id"] = record["_source"].pop("document_id", None) + record["_score"] = record["_source"].pop("score", None) + + result_dsl['hits']['hits'].append(record) + + return result_dsl + + +class Answers(BaseModel): + results: List[AnswersToIndividualQuestion] + diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py index 9d6f2ae8a..ae9bd7189 100644 --- a/rest_api/controller/search.py +++ b/rest_api/controller/search.py @@ -2,20 +2,22 @@ import json import logging import time from datetime import datetime -from typing import List, Dict, Optional +from typing import Any, Dict, List, Optional import elasticapm from fastapi import APIRouter from fastapi import HTTPException -from pydantic import BaseModel from haystack import Finder -from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_SCHEME, TEXT_FIELD_NAME, SEARCH_FIELD_NAME, \ - EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \ - BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \ - DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \ - EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME, VECTOR_SIMILARITY_METRIC, \ - CREATE_INDEX +from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DEFAULT_TOP_K_READER, ES_CONN_SCHEME, \ + TEXT_FIELD_NAME, SEARCH_FIELD_NAME, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, \ + RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, 
BATCHSIZE, CONTEXT_WINDOW_SIZE, \ + TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, CONCURRENT_REQUEST_PER_WORKER, \ + FAQ_QUESTION_FIELD_NAME, EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME, \ + VECTOR_SIMILARITY_METRIC, CREATE_INDEX + +from rest_api.controller.request import Question +from rest_api.controller.response import Answers, AnswersToIndividualQuestion from rest_api.controller.utils import RequestLimiter from haystack.document_store.elasticsearch import ElasticsearchDocumentStore @@ -49,7 +51,6 @@ document_store = ElasticsearchDocumentStore( similarity=VECTOR_SIMILARITY_METRIC ) - if RETRIEVER_TYPE == "EmbeddingRetriever": retriever = EmbeddingRetriever( document_store=document_store, @@ -68,7 +69,6 @@ else: f"OR modify rest_api/search.py to support your retriever" ) - if READER_MODEL_PATH: # for extractive doc-qa if READER_TYPE == "TransformersReader": use_gpu = -1 if not USE_GPU else GPU_NUMBER @@ -101,46 +101,14 @@ else: FINDERS = {1: Finder(reader=reader, retriever=retriever)} -############################################# -# Data schema for request & response -############################################# -class Question(BaseModel): - questions: List[str] - filters: Optional[Dict[str, str]] = None - top_k_reader: int = DEFAULT_TOP_K_READER - top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER - - -class Answer(BaseModel): - answer: Optional[str] - question: Optional[str] - score: Optional[float] = None - probability: Optional[float] = None - context: Optional[str] - offset_start: int - offset_end: int - offset_start_in_doc: Optional[int] - offset_end_in_doc: Optional[int] - document_id: Optional[str] = None - meta: Optional[Dict[str, str]] - - -class AnswersToIndividualQuestion(BaseModel): - question: str - answers: List[Optional[Answer]] - - -class Answers(BaseModel): - results: List[AnswersToIndividualQuestion] - - ############################################# # Endpoints 
############################################# doc_qa_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER) + @router.post("/models/{model_id}/doc-qa", response_model=Answers, response_model_exclude_unset=True) -def doc_qa(model_id: int, request: Question): +def doc_qa(model_id: int, question_request: Question): with doc_qa_limiter.run(): start_time = time.time() finder = FINDERS.get(model_id, None) @@ -149,26 +117,7 @@ def doc_qa(model_id: int, request: Question): status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}" ) - results = [] - for question in request.questions: - if request.filters: - # put filter values into a list and remove filters with null value - filters = {key: [value] for key, value in request.filters.items() if value is not None} - logger.info(f" [{datetime.now()}] Request: {request}") - else: - filters = {} - - result = finder.get_answers( - question=question, - top_k_retriever=request.top_k_retriever, - top_k_reader=request.top_k_reader, - filters=filters, - ) - results.append(result) - - elasticapm.set_custom_context({"results": results}) - end_time = time.time() - logger.info(json.dumps({"request": request.dict(), "results": results, "time": f"{(end_time - start_time):.2f}"})) + results = search_documents(finder, question_request, start_time) return {"results": results} @@ -199,3 +148,48 @@ def faq_qa(model_id: int, request: Question): logger.info(json.dumps({"request": request.dict(), "results": results})) return {"results": results} + + +@router.post("/models/{model_id}/query", response_model=Dict[str, Any], response_model_exclude_unset=True) +def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER): + with doc_qa_limiter.run(): + start_time = time.time() + finder = FINDERS.get(model_id, None) + if not finder: + raise HTTPException( + status_code=404, detail=f"Couldn't get Finder with ID {model_id}. 
Available IDs: {list(FINDERS.keys())}" + ) + + question_request = Question.from_elastic_query_dsl(query_request, top_k_reader) + + answers = search_documents(finder, question_request, start_time) + response: Dict[str, Any] = {} + if answers and len(answers) > 0: + response = AnswersToIndividualQuestion.to_elastic_response_dsl(dict(answers[0])) + + return response + + +def search_documents(finder, question_request, start_time) -> List[AnswersToIndividualQuestion]: + results = [] + for question in question_request.questions: + if question_request.filters: + # put filter values into a list and remove filters with null value + filters = {key: [value] for key, value in question_request.filters.items() if value is not None} + logger.info(f" [{datetime.now()}] Request: {question_request}") + else: + filters = {} + + result = finder.get_answers( + question=question, + top_k_retriever=question_request.top_k_retriever, + top_k_reader=question_request.top_k_reader, + filters=filters, + ) + results.append(result) + elasticapm.set_custom_context({"results": results}) + end_time = time.time() + logger.info( + json.dumps({"request": question_request.dict(), "results": results, + "time": f"{(end_time - start_time):.2f}"})) + return results diff --git a/test/test_elastic_dsl_convertor.py b/test/test_elastic_dsl_convertor.py new file mode 100644 index 000000000..828fdabf0 --- /dev/null +++ b/test/test_elastic_dsl_convertor.py @@ -0,0 +1,201 @@ +import pytest + +from rest_api.controller.request import Question +from rest_api.controller.response import Answer, AnswersToIndividualQuestion + + +def test_query_dsl_with_without_valid_query_field(): + query = { + "query": { + "bool": { + "must": [ + {"match": {"title": "Search"}}, + {"match": {"content": "Elasticsearch"}} + ], + "filter": [ + {"term": {"status": "published"}}, + {"range": {"publish_date": {"gte": "2015-01-01"}}} + ] + } + } + } + with pytest.raises(Exception): + Question.from_elastic_query_dsl(query) + + +def 
test_query_dsl_with_without_multiple_query_field(): + query = { + "query": { + "bool": { + "should": [ + {"match": {"name.first": {"query": "shay", "_name": "first"}}}, + {"match": {"name.last": {"query": "banon", "_name": "last"}}} + ], + "filter": { + "terms": { + "name.last": ["banon", "kimchy"], + "_name": "test" + } + } + } + } + } + with pytest.raises(Exception): + Question.from_elastic_query_dsl(query) + + +def test_query_dsl_with_single_query(): + query = { + "query": { + "match": { + "message": { + "query": "this is a test" + } + } + } + } + question = Question.from_elastic_query_dsl(query) + assert 1 == len(question.questions) + assert question.questions.__contains__("this is a test") + assert question.filters is None + + +def test_query_dsl_with_filter(): + query = { + "query": { + "bool": { + "should": [ + {"match": {"name.first": {"query": "shay", "_name": "first"}}} + ], + "filter": { + "terms": { + "name.last": ["banon", "kimchy"], + "_name": "test" + } + } + } + } + } + question = Question.from_elastic_query_dsl(query) + assert 1 == len(question.questions) + assert question.questions.__contains__("shay") + assert len(question.filters) == 1 + assert question.filters["_name"] == "test" + + +def test_query_dsl_with_complex_query(): + query = { + "size": 17, + "query": { + "bool": { + "should": [ + { + "multi_match": { + "query": "I am test1", + "type": "most_fields", + "fields": ["text", "title"] + } + } + ], + "filter": [ + { + "terms": { + "year": "2020" + } + }, + { + "terms": { + "quarter": "1" + } + }, + { + "range": { + "date": { + "gte": "12-12-12" + } + } + } + ] + } + } + } + top_k_reader = 7 + question = Question.from_elastic_query_dsl(query, top_k_reader) + assert 1 == len(question.questions) + assert question.questions.__contains__("I am test1") + assert 2 == len(question.filters) + assert question.filters["year"] == "2020" + assert question.filters["quarter"] == "1" + assert 17 == question.top_k_retriever + assert 7 == 
question.top_k_reader + + +def test_response_dsl_with_empty_answers(): + sample_answer = AnswersToIndividualQuestion(question="test question", answers=[]) + response = AnswersToIndividualQuestion.to_elastic_response_dsl(sample_answer.__dict__) + assert 0 == response['hits']['total']['value'] + assert 0 == len(response['hits']['hits']) + + +def test_response_dsl_with_answers(): + full_answer = Answer( + answer="answer", + question="question", + score=0.1234, + probability=0.5678, + context="context", + offset_start=200, + offset_end=300, + offset_start_in_doc=2000, + offset_end_in_doc=2100, + document_id="id_1", + meta={ + "meta1": "meta_value" + } + ) + empty_answer = Answer( + answer=None, + question=None, + score=None, + probability=None, + context=None, + offset_start=250, + offset_end=350, + offset_start_in_doc=None, + offset_end_in_doc=None, + document_id=None, + meta=None + ) + sample_answer = AnswersToIndividualQuestion(question="test question", answers=[full_answer, empty_answer]) + response = AnswersToIndividualQuestion.to_elastic_response_dsl(sample_answer.__dict__) + + # Test number of returned answers + assert response['hits']['total']['value'] == 2 + + # Test converted answers + hits = response['hits']['hits'] + assert len(hits) == 2 + # Test full answer record + assert hits[0]["_score"] == 0.1234 + assert hits[0]["_id"] == "id_1" + assert hits[0]["_source"]["answer"] == "answer" + assert hits[0]["_source"]["question"] == "question" + assert hits[0]["_source"]["context"] == "context" + assert hits[0]["_source"]["probability"] == 0.5678 + assert hits[0]["_source"]["offset_start"] == 200 + assert hits[0]["_source"]["offset_end"] == 300 + assert hits[0]["_source"]["offset_start_in_doc"] == 2000 + assert hits[0]["_source"]["offset_end_in_doc"] == 2100 + assert hits[0]["_source"]["meta"] == {"meta1": "meta_value"} + # Test empty answer record + assert hits[1]["_score"] is None + assert hits[1]["_id"] is None + assert hits[1]["_source"]["answer"] is None + 
assert hits[1]["_source"]["question"] is None + assert hits[1]["_source"]["context"] is None + assert hits[1]["_source"]["probability"] is None + assert hits[1]["_source"]["offset_start"] == 250 + assert hits[1]["_source"]["offset_end"] == 350 + assert hits[1]["_source"]["offset_start_in_doc"] is None + assert hits[1]["_source"]["offset_end_in_doc"] is None + assert hits[1]["_source"]["meta"] is None diff --git a/test/test_rest_api.py b/test/test_rest_api.py new file mode 100644 index 000000000..19927acea --- /dev/null +++ b/test/test_rest_api.py @@ -0,0 +1,58 @@ +import pytest +from fastapi.testclient import TestClient + +from haystack import Finder +from haystack.retriever.sparse import ElasticsearchRetriever + +# TODO: Add integration tests for other APIs + + +def get_test_client_and_override_dependencies(reader, document_store_with_docs): + from rest_api.application import app + from rest_api.controller import search + + search.document_store = document_store_with_docs + search.retriever = ElasticsearchRetriever(document_store=document_store_with_docs) + search.FINDERS = {1: Finder(reader=reader, retriever=search.retriever)} + + return TestClient(app) + + +@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("reader", ["farm"], indirect=True) +def test_query_api(reader, document_store_with_docs): + client = get_test_client_and_override_dependencies(reader, document_store_with_docs) + + query = { + "size": 1, + "query": { + "bool": { + "should": [ + { + "multi_match": { + "query": "Where Paul lives?" 
+ } + } + ], + "filter": [ + { + "terms": { + "name": "filename2" + } + } + ] + } + } + } + + response = client.post(url="/models/1/query?top_k_reader=1", json=query) + assert 200 == response.status_code + response_json = response.json() + assert 1 == response_json['hits']['total']['value'] + assert 1 == len(response_json['hits']['hits']) + assert response_json['hits']['hits'][0]["_score"] is not None + assert response_json['hits']['hits'][0]["_source"]["meta"] is not None + assert response_json['hits']['hits'][0]["_id"] is not None + assert "New York" == response_json['hits']['hits'][0]["_source"]["answer"] + assert "My name is Paul and I live in New York" == response_json['hits']['hits'][0]["_source"]["context"] +