2020-04-15 14:04:30 +02:00
|
|
|
from typing import Optional
|
2021-02-15 10:48:59 +01:00
|
|
|
import time
|
2020-04-15 14:04:30 +02:00
|
|
|
|
2020-07-31 11:34:06 +02:00
|
|
|
from fastapi import APIRouter
|
2020-04-15 14:04:30 +02:00
|
|
|
from pydantic import BaseModel, Field
|
2021-02-15 10:48:59 +01:00
|
|
|
from typing import Dict, Union, List
|
2020-04-15 14:04:30 +02:00
|
|
|
|
2020-09-16 18:33:23 +02:00
|
|
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
2020-06-22 12:07:12 +02:00
|
|
|
from rest_api.config import (
|
2020-04-15 14:04:30 +02:00
|
|
|
DB_HOST,
|
2020-06-09 04:56:56 -03:00
|
|
|
DB_PORT,
|
2020-04-15 14:04:30 +02:00
|
|
|
DB_USER,
|
|
|
|
DB_PW,
|
|
|
|
DB_INDEX,
|
2021-01-11 12:24:09 +01:00
|
|
|
DB_INDEX_FEEDBACK,
|
2020-04-15 14:04:30 +02:00
|
|
|
ES_CONN_SCHEME,
|
|
|
|
TEXT_FIELD_NAME,
|
|
|
|
SEARCH_FIELD_NAME,
|
|
|
|
EMBEDDING_DIM,
|
|
|
|
EMBEDDING_FIELD_NAME,
|
|
|
|
EXCLUDE_META_DATA_FIELDS,
|
2020-07-31 11:34:06 +02:00
|
|
|
FAQ_QUESTION_FIELD_NAME,
|
2020-10-15 19:03:58 +02:00
|
|
|
CREATE_INDEX,
|
2021-01-12 13:00:56 +01:00
|
|
|
VECTOR_SIMILARITY_METRIC,
|
|
|
|
UPDATE_EXISTING_DOCUMENTS
|
2020-04-15 14:04:30 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
# All feedback endpoints below are registered on this router.
router = APIRouter()

# Document store shared by every endpoint in this module. Its label_index is the
# feedback index (DB_INDEX_FEEDBACK) that the eval/export endpoints query.
document_store = ElasticsearchDocumentStore(
    # --- connection ---
    scheme=ES_CONN_SCHEME,
    host=DB_HOST,
    port=DB_PORT,
    username=DB_USER,
    password=DB_PW,
    ca_certs=False,
    verify_certs=False,
    # --- indices ---
    index=DB_INDEX,
    label_index=DB_INDEX_FEEDBACK,
    create_index=CREATE_INDEX,
    update_existing_documents=UPDATE_EXISTING_DOCUMENTS,
    # --- field mapping / retrieval settings ---
    text_field=TEXT_FIELD_NAME,
    search_fields=SEARCH_FIELD_NAME,
    faq_question_field=FAQ_QUESTION_FIELD_NAME,
    embedding_field=EMBEDDING_FIELD_NAME,
    embedding_dim=EMBEDDING_DIM,
    similarity=VECTOR_SIMILARITY_METRIC,
    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
)
|
|
|
|
|
|
|
|
|
2020-07-31 11:34:06 +02:00
|
|
|
class FAQQAFeedback(BaseModel):
    # Request body for POST /faq-qa-feedback; DocQAFeedback extends it for doc-qa.
    # Documented with comments rather than a docstring so the published request
    # schema is left unchanged.
    question: str = Field(..., description="The question input by the user, i.e., the query.")
    is_correct_answer: bool = Field(..., description="Whether the answer is correct or not.")
    document_id: str = Field(..., description="The document in the query result for which feedback is given.")
    # Optional: None when the client does not send it.
    model_id: Optional[int] = Field(None, description="The model used for the query.")
|
|
|
|
|
|
|
|
|
2020-07-31 11:34:06 +02:00
|
|
|
class DocQAFeedback(FAQQAFeedback):
    # Request body for POST /doc-qa-feedback: the FAQQAFeedback fields plus the
    # answer span and whether the returned document itself was correct.
    is_correct_document: bool = Field(
        ...,
        description="In case of negative feedback, there could be two cases; incorrect answer but correct "
        "document & incorrect document. This flag denotes if the returned document was correct.",
    )
    answer: str = Field(..., description="The answer string.")
    # Index of the answer's first character in the full document text; the
    # doc-qa export endpoint slices a context window around it.
    offset_start_in_doc: int = Field(
        ..., description="The answer start offset in the original doc. Only required for doc-qa feedback."
    )
|
|
|
|
|
2021-02-15 10:48:59 +01:00
|
|
|
class FilterRequest(BaseModel):
    # Optional request body for POST /eval-doc-qa-feedback.
    # Maps a label field name to a single value, a list of values, or null,
    # e.g. {"document_id": ["XRR3xnEBCYVTkbTystOB"]}. The whole mapping is None when omitted.
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
2020-07-31 11:34:06 +02:00
|
|
|
|
2020-04-15 14:04:30 +02:00
|
|
|
@router.post("/doc-qa-feedback")
def doc_qa_feedback(feedback: DocQAFeedback):
    # Store one label: the "user-feedback" origin tag first, then every submitted
    # field (a submitted key of the same name would override the tag, as before).
    label = {"origin": "user-feedback"}
    label.update(feedback.dict())
    document_store.write_labels([label])
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
|
|
|
|
@router.post("/faq-qa-feedback")
def faq_qa_feedback(feedback: FAQQAFeedback):
    # FAQ feedback carries no answer span, so the label records document
    # correctness as equal to answer correctness and leaves the answer empty.
    label = {
        "origin": "user-feedback-faq",
        "is_correct_document": feedback.is_correct_answer,
        "answer": None,
    }
    # Submitted fields follow, keeping the same key order as the original payload.
    label.update(feedback.dict())
    document_store.write_labels([label])
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
|
2021-02-15 10:48:59 +01:00
|
|
|
@router.post("/eval-doc-qa-feedback")
def eval_doc_qa_feedback(filters: FilterRequest = None):
    r"""
    Return basic accuracy metrics based on the user feedback.
    Which ratio of answers was correct? Which ratio of documents was correct?
    You can supply filters in the request to only use a certain subset of labels.

    **Example:**

        ```
        | curl --location --request POST 'http://127.0.0.1:8000/eval-doc-qa-feedback' \
        | --header 'Content-Type: application/json' \
        | --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
        ```

    Returns `answer_accuracy` and `document_accuracy` (fractions of matching labels marked
    correct, or None when no feedback matches) and `n_feedback` (number of labels used).
    """
    # Copy the client's filters so the request model is not mutated. A body of
    # {"filters": null} yields filters.filters is None, which previously crashed.
    if filters is not None and filters.filters:
        label_filters = dict(filters.filters)
    else:
        label_filters = {}
    # Only labels written by the doc-qa feedback endpoint are evaluated.
    label_filters["origin"] = ["user-feedback"]

    labels = document_store.get_all_labels(index=DB_INDEX_FEEDBACK, filters=label_filters)

    if not labels:
        return {"answer_accuracy": None, "document_accuracy": None, "n_feedback": 0}

    n_feedback = len(labels)
    answer_accuracy = sum(1 for label in labels if label.is_correct_answer) / n_feedback
    doc_accuracy = sum(1 for label in labels if label.is_correct_document) / n_feedback

    return {"answer_accuracy": answer_accuracy, "document_accuracy": doc_accuracy, "n_feedback": n_feedback}
|
|
|
|
|
2020-04-15 14:04:30 +02:00
|
|
|
def _context_window(text_length, answer_start, answer_length, context_size):
    """
    Return the (start, end) slice bounds of the context to export around an answer.

    Aims for a window of `context_size` characters (answer included) with equal context
    before and after the answer. Context that cannot be placed before the answer (start of
    text reached) is added after it, and vice versa. The answer itself is never cut, even
    when it is longer than `context_size`.
    """
    # Clamp at 0: a negative value would shrink the window into the answer.
    context_to_add = max(int((context_size - answer_length) / 2), 0)
    answer_end = answer_start + answer_length

    # End bound is len(text) (exclusive slice end); the old len(text) - 1 always dropped the last char.
    start_pos = max(answer_start - context_to_add, 0)
    end_pos = min(answer_end + context_to_add, text_length)

    # Characters that did not fit on one side are moved to the other side.
    missing_before = max(context_to_add - answer_start, 0)
    missing_after = max(answer_end + context_to_add - text_length, 0)

    start_pos = max(start_pos - missing_after, 0)
    end_pos = min(end_pos + missing_before, text_length)
    return start_pos, end_pos


@router.get("/export-doc-qa-feedback")
def export_doc_qa_feedback(context_size: int = 2_000):
    """
    SQuAD format JSON export for question/answer pairs that were marked as "relevant".

    The context_size param can be used to limit response size for large documents.
    """
    labels = document_store.get_all_labels(
        index=DB_INDEX_FEEDBACK, filters={"is_correct_answer": [True], "origin": ["user-feedback"]}
    )

    export_data = []
    for label in labels:
        document = document_store.get_document_by_id(label.document_id)
        if document is None:
            # The labelled document was not found in the store; there is no context to export.
            continue
        text = document.text

        start_pos, end_pos = _context_window(
            len(text), label.offset_start_in_doc, len(label.answer), context_size
        )
        # NOTE(review): "qas" holds the raw label object, whose offset_start_in_doc is relative
        # to the full document, not to the exported context; strict SQuAD consumers expect a
        # list of question dicts with context-relative answer_start. Confirm before relying on it.
        export_data.append({"paragraphs": [{"qas": label, "context": text[start_pos:end_pos]}]})

    return {"data": export_data}
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/export-faq-qa-feedback")
def export_faq_feedback():
    """
    Export feedback for faq-qa in JSON format.
    """
    labels = document_store.get_all_labels(index=DB_INDEX_FEEDBACK, filters={"origin": ["user-feedback-faq"]})

    export_data = []
    for label in labels:
        document = document_store.get_document_by_id(label.document_id)
        if document is None:
            # The FAQ entry this feedback refers to was not found; skip it instead of failing the export.
            continue
        export_data.append(
            {
                # The stored FAQ question vs. the query the user actually typed.
                "question": document.question,
                "query": label.question,
                "is_correct_answer": label.is_correct_answer,
                # Previously read label.is_correct_answer by mistake.
                "is_correct_document": label.is_correct_document,
            }
        )

    return {"data": export_data}
|