2020-08-06 04:36:56 -04:00
|
|
|
import json
|
2020-04-15 14:04:30 +02:00
|
|
|
import logging
|
2020-07-07 12:28:41 +02:00
|
|
|
import time
|
2020-04-15 14:04:30 +02:00
|
|
|
from datetime import datetime
|
2020-10-16 13:25:31 +02:00
|
|
|
from typing import Any, Dict, List, Optional
|
2020-04-15 14:04:30 +02:00
|
|
|
|
2020-04-22 11:28:23 +02:00
|
|
|
import elasticapm
|
2020-04-15 14:04:30 +02:00
|
|
|
from fastapi import APIRouter
|
|
|
|
from fastapi import HTTPException
|
|
|
|
|
|
|
|
from haystack import Finder
|
2020-10-16 13:25:31 +02:00
|
|
|
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DEFAULT_TOP_K_READER, ES_CONN_SCHEME, \
|
|
|
|
TEXT_FIELD_NAME, SEARCH_FIELD_NAME, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, \
|
|
|
|
RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, BATCHSIZE, CONTEXT_WINDOW_SIZE, \
|
|
|
|
TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, CONCURRENT_REQUEST_PER_WORKER, \
|
|
|
|
FAQ_QUESTION_FIELD_NAME, EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME, \
|
|
|
|
VECTOR_SIMILARITY_METRIC, CREATE_INDEX
|
|
|
|
|
|
|
|
from rest_api.controller.request import Question
|
|
|
|
from rest_api.controller.response import Answers, AnswersToIndividualQuestion
|
2020-10-15 18:41:36 +02:00
|
|
|
|
2020-06-22 12:07:12 +02:00
|
|
|
from rest_api.controller.utils import RequestLimiter
|
2020-09-16 18:33:23 +02:00
|
|
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
2020-10-21 17:15:35 +02:00
|
|
|
from haystack.reader.base import BaseReader
|
2020-04-15 14:04:30 +02:00
|
|
|
from haystack.reader.farm import FARMReader
|
2020-07-07 16:25:36 +02:00
|
|
|
from haystack.reader.transformers import TransformersReader
|
2020-06-10 17:22:37 +02:00
|
|
|
from haystack.retriever.base import BaseRetriever
|
2020-07-15 17:22:17 +02:00
|
|
|
from haystack.retriever.sparse import ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever
|
2020-06-30 19:05:45 +02:00
|
|
|
from haystack.retriever.dense import EmbeddingRetriever
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
# FastAPI router on which the search endpoints below are registered.
router = APIRouter()

# Module-level logger used for request/response logging in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Initialise the global pipeline components: DocumentStore, Retriever, Reader, Finder.
# Every setting below comes from rest_api.config.
document_store = ElasticsearchDocumentStore(
    # --- connection ---
    host=DB_HOST,
    port=DB_PORT,
    username=DB_USER,
    password=DB_PW,
    scheme=ES_CONN_SCHEME,
    # NOTE(review): certificate checks are disabled unconditionally here. If
    # ES_CONN_SCHEME is "https", the server certificate is presumably not
    # verified -- confirm this is intended for every deployment.
    ca_certs=False,
    verify_certs=False,
    # --- index and field mapping ---
    index=DB_INDEX,
    create_index=CREATE_INDEX,
    text_field=TEXT_FIELD_NAME,
    name_field=NAME_FIELD_NAME,
    search_fields=SEARCH_FIELD_NAME,
    faq_question_field=FAQ_QUESTION_FIELD_NAME,
    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
    # --- dense-vector settings ---
    embedding_field=EMBEDDING_FIELD_NAME,
    embedding_dim=EMBEDDING_DIM,
    similarity=VECTOR_SIMILARITY_METRIC,
)
|
|
|
|
|
2020-07-15 17:22:17 +02:00
|
|
|
# Select the Retriever implementation named by RETRIEVER_TYPE.
# An unset RETRIEVER_TYPE (None) falls back to ElasticsearchFilterOnlyRetriever.
if RETRIEVER_TYPE == "EmbeddingRetriever":
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=EMBEDDING_MODEL_PATH,
        model_format=EMBEDDING_MODEL_FORMAT,
        use_gpu=USE_GPU
    )  # type: BaseRetriever
elif RETRIEVER_TYPE == "ElasticsearchRetriever":
    retriever = ElasticsearchRetriever(document_store=document_store)
elif RETRIEVER_TYPE is None or RETRIEVER_TYPE == "ElasticsearchFilterOnlyRetriever":
    retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
else:
    # The literals below are concatenated implicitly. Each piece must end with a
    # space, otherwise the options list runs straight into the trailing hint
    # (previously rendered as "...None" + "OR modify").
    raise ValueError(f"Could not load Retriever of type '{RETRIEVER_TYPE}'. "
                     "Please adjust RETRIEVER_TYPE to one of: "
                     "'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None "
                     "OR modify rest_api/search.py to support your retriever"
                     )
|
|
|
|
|
2020-04-23 16:09:53 +02:00
|
|
|
# Build the Reader used for extractive doc-QA. Without a READER_MODEL_PATH no
# reader is created, which is enough for pure FAQ matching.
if READER_MODEL_PATH:  # for extractive doc-qa
    if READER_TYPE == "TransformersReader":
        # -1 is passed when the GPU is disabled; otherwise the configured GPU_NUMBER
        # (presumably a device index -- TODO confirm against TransformersReader).
        use_gpu = -1 if not USE_GPU else GPU_NUMBER
        reader = TransformersReader(
            model_name_or_path=str(READER_MODEL_PATH),
            use_gpu=use_gpu,
            context_window_size=CONTEXT_WINDOW_SIZE,
            # str(None) would produce the literal string "None", which would then be
            # treated as a tokenizer name. Pass None through when no tokenizer is
            # configured; presumably the reader then falls back to the model's own
            # tokenizer -- TODO confirm.
            tokenizer=str(READER_TOKENIZER) if READER_TOKENIZER else None
        )  # type: Optional[BaseReader]
    elif READER_TYPE == "FARMReader":
        reader = FARMReader(
            model_name_or_path=str(READER_MODEL_PATH),
            batch_size=BATCHSIZE,
            use_gpu=USE_GPU,
            context_window_size=CONTEXT_WINDOW_SIZE,
            top_k_per_candidate=TOP_K_PER_CANDIDATE,
            no_ans_boost=NO_ANS_BOOST,
            num_processes=MAX_PROCESSES,
            max_seq_len=MAX_SEQ_LEN,
            doc_stride=DOC_STRIDE,
        )  # type: Optional[BaseReader]
    else:
        raise ValueError(f"Could not load Reader of type '{READER_TYPE}'. "
                         f"Please adjust READER_TYPE to one of: "
                         f"'FARMReader', 'TransformersReader', None"
                         )
else:
    reader = None  # don't need one for pure FAQ matching
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
# Finder pipelines keyed by the `model_id` path parameter of the endpoints.
# Only a single pipeline (ID 1) combining the global retriever and reader exists.
FINDERS = {
    1: Finder(retriever=retriever, reader=reader),
}


#############################################
# Endpoints
#############################################

# Shared by the doc-qa and query endpoints via `doc_qa_limiter.run()`; presumably
# caps concurrent requests per worker at CONCURRENT_REQUEST_PER_WORKER -- see
# rest_api.controller.utils.RequestLimiter.
doc_qa_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)
|
|
|
|
|
2020-10-16 13:25:31 +02:00
|
|
|
|
2020-04-15 14:04:30 +02:00
|
|
|
@router.post("/models/{model_id}/doc-qa", response_model=Answers, response_model_exclude_unset=True)
def doc_qa(model_id: int, question_request: Question):
    """Answer every question in ``question_request`` using the Finder ``model_id``.

    The whole request runs inside ``doc_qa_limiter.run()``.

    Returns:
        ``{"results": [...]}`` with one entry per question, as produced by
        ``search_documents``.

    Raises:
        HTTPException: 404 when no Finder is registered under ``model_id``.
    """
    with doc_qa_limiter.run():
        started = time.time()
        finder = FINDERS.get(model_id)
        if finder:
            return {"results": search_documents(finder, question_request, started)}
        raise HTTPException(
            status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
        )
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/models/{model_id}/faq-qa", response_model=Answers, response_model_exclude_unset=True)
def faq_qa(model_id: int, request: Question):
    """Match each question in ``request`` against stored FAQ questions.

    Calls ``finder.get_answers_via_similar_questions`` once per question. All
    questions share the same ``top_k_retriever`` and filters. The collected results
    are attached to the elasticapm custom context and logged as JSON together
    with the request.

    Returns:
        ``{"results": [...]}`` with one entry per question, in request order.

    Raises:
        HTTPException: 404 when no Finder is registered under ``model_id``.
    """
    finder = FINDERS.get(model_id, None)
    if not finder:
        raise HTTPException(
            status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
        )

    # Filters depend only on the request, not on the individual question, so
    # build them (and log the filtered request) once instead of once per question.
    if request.filters:
        # put filter values into a list and remove filters with null value
        filters = {key: [value] for key, value in request.filters.items() if value is not None}
        logger.info(f" [{datetime.now()}] Request: {request}")
    else:
        filters = {}

    results = [
        finder.get_answers_via_similar_questions(
            question=question, top_k_retriever=request.top_k_retriever, filters=filters,
        )
        for question in request.questions
    ]

    elasticapm.set_custom_context({"results": results})
    logger.info(json.dumps({"request": request.dict(), "results": results}))

    return {"results": results}
|
2020-10-16 13:25:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
@router.post("/models/{model_id}/query", response_model=Dict[str, Any], response_model_exclude_unset=True)
def query(model_id: int, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER):
    """Answer an Elasticsearch-query-DSL style request with the Finder ``model_id``.

    ``query_request`` is converted with ``Question.from_elastic_query_dsl`` and
    answered via ``search_documents``. Only the first question's answers are
    converted back with ``AnswersToIndividualQuestion.to_elastic_response_dsl``.
    The request runs inside ``doc_qa_limiter.run()``.

    Returns:
        The converted response for the first question, or ``{}`` when there
        are no results.

    Raises:
        HTTPException: 404 when no Finder is registered under ``model_id``.
    """
    with doc_qa_limiter.run():
        started = time.time()
        finder = FINDERS.get(model_id)
        if not finder:
            raise HTTPException(
                status_code=404, detail=f"Couldn't get Finder with ID {model_id}. Available IDs: {list(FINDERS.keys())}"
            )

        parsed_request = Question.from_elastic_query_dsl(query_request, top_k_reader)
        answers = search_documents(finder, parsed_request, started)

        # An empty result list yields an empty response body.
        if not answers:
            return {}
        return AnswersToIndividualQuestion.to_elastic_response_dsl(dict(answers[0]))
|
|
|
|
|
|
|
|
|
|
|
|
def search_documents(finder, question_request, start_time) -> List[AnswersToIndividualQuestion]:
    """Run ``finder.get_answers`` for every question in ``question_request``.

    Args:
        finder: Finder pipeline used to answer the questions.
        question_request: Parsed request. Reads ``questions``, ``filters``,
            ``top_k_retriever``, ``top_k_reader`` and ``dict()``.
        start_time: ``time.time()`` value taken at the start of the request.
            Used only to log the elapsed time.

    Returns:
        One ``finder.get_answers`` result per question, in question order.
        The results are also attached to the elasticapm custom context and
        logged as JSON with the request and the elapsed time.
    """
    # Filters are shared by all questions, so build them (and log the filtered
    # request) once rather than once per question.
    if question_request.filters:
        # put filter values into a list and remove filters with null value
        filters = {key: [value] for key, value in question_request.filters.items() if value is not None}
        logger.info(f" [{datetime.now()}] Request: {question_request}")
    else:
        filters = {}

    results = []
    for question in question_request.questions:
        result = finder.get_answers(
            question=question,
            top_k_retriever=question_request.top_k_retriever,
            top_k_reader=question_request.top_k_reader,
            filters=filters,
        )
        results.append(result)

    elasticapm.set_custom_context({"results": results})
    end_time = time.time()
    logger.info(
        json.dumps({"request": question_request.dict(), "results": results,
                    "time": f"{(end_time - start_time):.2f}"}))
    return results
|