2022-04-12 16:41:05 +02:00
|
|
|
from typing import Any, Dict, List, Union, Optional
|
2022-02-09 18:27:12 +01:00
|
|
|
|
2021-04-07 17:53:32 +02:00
|
|
|
import json
|
|
|
|
import logging
|
2020-04-15 14:04:30 +02:00
|
|
|
|
2022-04-12 16:41:05 +02:00
|
|
|
from fastapi import FastAPI, APIRouter
|
2022-02-14 11:43:26 +01:00
|
|
|
from haystack.schema import Label
|
2022-04-12 16:41:05 +02:00
|
|
|
from haystack.document_stores import BaseDocumentStore
|
2022-02-14 11:43:26 +01:00
|
|
|
from rest_api.schema import FilterRequest, LabelSerialized, CreateLabelSerialized
|
2022-04-12 16:41:05 +02:00
|
|
|
from rest_api.utils import get_app, get_pipelines
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
|
2021-04-07 17:53:32 +02:00
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that the feedback endpoints in this module are registered on.
router = APIRouter()
# Application instance as returned by the REST API utilities.
app: FastAPI = get_app()
# Document store taken from the loaded pipelines under the "document_store" key.
# NOTE(review): `.get(..., None)` yields None when that key is absent, in which case
# every endpoint below fails on attribute access - confirm the key is always present.
document_store: BaseDocumentStore = get_pipelines().get("document_store", None)
|
|
|
|
|
2020-04-15 14:04:30 +02:00
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
@router.post("/feedback")
def post_feedback(feedback: Union[LabelSerialized, CreateLabelSerialized]):
    """
    Submit feedback on an answer that was returned for a particular query.

    The feedback can state, for example, whether the answer was correct and
    whether the right snippet was identified as the answer.

    Information submitted through this endpoint is used to train the underlying QA model.

    Feedback that arrives without an origin is tagged as "user-feedback" before it
    is converted into a label and written to the document store.
    """
    origin_missing = feedback.origin is None
    if origin_missing:
        feedback.origin = "user-feedback"

    label_fields = feedback.dict()
    new_label = Label(**label_fields)
    document_store.write_labels([new_label])
|
2021-10-13 14:23:23 +02:00
|
|
|
|
|
|
|
|
2022-04-12 16:41:05 +02:00
|
|
|
@router.get("/feedback", response_model=List[LabelSerialized])
def get_feedback():
    """
    Retrieve the feedback that has been submitted through the `POST /feedback` endpoint.

    The result is the unfiltered list of labels returned by the document store.
    """
    return document_store.get_all_labels()
|
2020-04-15 14:04:30 +02:00
|
|
|
|
|
|
|
|
2022-02-14 11:43:26 +01:00
|
|
|
@router.delete("/feedback")
def delete_feedback():
    """
    This endpoint allows the API user to delete all the
    feedback that has been submitted through the
    `POST /feedback` endpoint.

    Only labels whose origin is "user-feedback" are deleted; labels with any
    other origin are left untouched.
    """
    all_labels = document_store.get_all_labels()
    user_label_ids = [label.id for label in all_labels if label.origin == "user-feedback"]

    if not user_label_ids:
        # Nothing to delete. Do not forward an empty id list to the store:
        # NOTE(review): several document store implementations appear to treat an
        # empty/absent `ids` without filters as "delete every label", which would
        # also remove labels that did not come from user feedback - confirm per store.
        return

    document_store.delete_labels(ids=user_label_ids)
|
2022-02-14 11:43:26 +01:00
|
|
|
|
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
@router.post("/eval-feedback")
def get_feedback_metrics(filters: Optional[FilterRequest] = None):
    """
    This endpoint returns basic accuracy metrics based on user feedback,
    e.g., the ratio of correct answers or correctly identified documents.
    You can filter the output by document or label.

    Example:

    `curl --location --request POST 'http://127.0.0.1:8000/eval-feedback' \
     --header 'Content-Type: application/json' \
     --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'`

    Returns a dict with the keys "answer_accuracy", "document_accuracy" and
    "n_feedback". Both accuracies are None when no matching feedback exists.
    """
    # Build a fresh dict rather than adding the "origin" key to the request model's
    # own dict, so the incoming request object is never mutated.
    requested_filters = (filters.filters if filters else None) or {}
    # Only user feedback is evaluated; a caller-supplied "origin" filter is overridden.
    filters_content = {**requested_filters, "origin": ["user-feedback"]}

    labels = document_store.get_all_labels(filters=filters_content)

    res: Dict[str, Optional[Union[float, int]]]
    n_feedback = len(labels)
    if n_feedback > 0:
        correct_answers = sum(1 for label in labels if label.is_correct_answer)
        correct_documents = sum(1 for label in labels if label.is_correct_document)
        res = {
            "answer_accuracy": correct_answers / n_feedback,
            "document_accuracy": correct_documents / n_feedback,
            "n_feedback": n_feedback,
        }
    else:
        # Avoid a division by zero: without feedback there is no accuracy to report.
        res = {"answer_accuracy": None, "document_accuracy": None, "n_feedback": 0}
    return res
|
|
|
|
|
2021-04-07 17:53:32 +02:00
|
|
|
|
|
|
|
def _context_window_bounds(text_length: int, answer_start: int, answer_length: int, context_size: int) -> tuple:
    """
    Compute the `(start, end)` slice bounds of a context window around an answer.

    The window aims at `context_size` characters in total (answer included), with
    equal padding before and after the answer. Padding that does not fit on one
    side because the text begins or ends there is added to the other side.
    `end` is exclusive and never exceeds `text_length`; `start` is never negative.
    """
    # An answer longer than the requested context gets no padding (never a negative one).
    padding = max(context_size - answer_length, 0) // 2
    answer_end = answer_start + answer_length

    unused_before = max(padding - answer_start, 0)
    unused_after = max(answer_end + padding - text_length, 0)

    window_start = max(answer_start - padding - unused_after, 0)
    # Slicing is end-exclusive, so the upper bound is `text_length` itself; clamping to
    # `text_length - 1` would always cut off the last character of the document.
    window_end = min(answer_end + padding + unused_before, text_length)
    return window_start, window_end


@router.get("/export-feedback")
def export_feedback(
    context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
):
    """
    This endpoint returns JSON output in the SQuAD format for question/answer pairs
    that were marked as "relevant" by user feedback through the `POST /feedback` endpoint.

    The context_size param can be used to limit response size for large documents.

    :param context_size: Target number of characters of context (answer included) that
                         is exported per label. Only used when `full_document_context` is False.
    :param full_document_context: If True, the whole document content is exported as context.
    :param only_positive_labels: If True, only feedback with a correct answer is exported.
    :return: A dict of the form `{"data": [...]}` with one SQuAD entry per exported label.
             The list of entries is additionally written to `feedback_squad_direct.json`
             in the current working directory.
    """
    if only_positive_labels:
        labels = document_store.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
    else:
        labels = document_store.get_all_labels(filters={"origin": ["user-feedback"]})
        # Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
        # neither a "positive example" nor a negative "is_impossible" one)
        labels = [
            label
            for label in labels
            if not (label.is_correct_document is True and label.is_correct_answer is False)
        ]

    export_data = []

    for label in labels:
        answer_text = label.answer.answer if label and label.answer else ""

        offset_start_in_document = 0
        if label.answer and label.answer.offsets_in_document:
            offset_start_in_document = label.answer.offsets_in_document[0].start

        if full_document_context:
            context = label.document.content
            answer_start = offset_start_in_document
        else:
            text = label.document.content
            start_pos, end_pos = _context_window_bounds(
                text_length=len(text),
                answer_start=offset_start_in_document,
                answer_length=len(answer_text),
                context_size=context_size,
            )
            context = text[start_pos:end_pos]
            # The answer offset must be expressed relative to the exported window.
            answer_start = offset_start_in_document - start_pos

        qa_entry: Dict[str, Any]
        if label.is_correct_answer is False and label.is_correct_document is False:  # No answer
            qa_entry = {"question": label.query, "id": label.id, "is_impossible": True, "answers": []}
        else:
            # quality check: the exported offsets must point at the answer string.
            # It only applies here, because "no answer" entries carry no answers to index.
            text_at_offsets = context[answer_start : answer_start + len(answer_text)]
            if text_at_offsets != answer_text:
                logger.error(
                    "Skipping invalid squad label as string via offsets "
                    "('%s') does not match answer string ('%s') ",
                    text_at_offsets,
                    answer_text,
                )
                # Actually skip it, as the message states: a misaligned answer span
                # would be an invalid SQuAD example.
                continue
            qa_entry = {
                "question": label.query,
                "id": label.id,
                "is_impossible": False,
                "answers": [{"text": answer_text, "answer_start": answer_start}],
            }

        squad_label: Dict[str, Any] = {
            "paragraphs": [
                {
                    "context": context,
                    "id": label.document.id,
                    "qas": [qa_entry],
                }
            ]
        }
        export_data.append(squad_label)

    export = {"data": export_data}

    # NOTE(review): the file receives the bare list, not the {"data": ...} wrapper that is
    # returned to the caller. Left unchanged because consumers of the file may rely on it.
    with open("feedback_squad_direct.json", "w", encoding="utf8") as f:
        json.dump(export_data, f, ensure_ascii=False, sort_keys=True, indent=4)
    return export
|