import json
import logging
from typing import Dict, Union, List, Optional, Tuple

from fastapi import APIRouter, HTTPException

from rest_api.schema import ExtractiveQAFeedback, FilterRequest
from rest_api.controller.search import DOCUMENT_STORE

router = APIRouter()

logger = logging.getLogger(__name__)


@router.post("/feedback")
def user_feedback(feedback: ExtractiveQAFeedback):
    """
    Store a single piece of user feedback as a label in the document store.

    The feedback fields are written as-is, with an additional
    `"origin": "user-feedback"` entry so the label can be filtered for later.
    """
    DOCUMENT_STORE.write_labels([{"origin": "user-feedback", **feedback.dict()}])


@router.post("/eval-feedback")
def eval_extractive_qa_feedback(filters: Optional[FilterRequest] = None):
    r"""
    Return basic accuracy metrics based on the user feedback.
    Which ratio of answers was correct? Which ratio of documents was correct?
    You can supply filters in the request to only use a certain subset of labels.

    **Example:**

    ```
    curl --location --request POST 'http://127.0.0.1:8000/eval-feedback' \
         --header 'Content-Type: application/json' \
         --data-raw '{ "filters": {"document_id": ["XRR3xnEBCYVTkbTystOB"]} }'
    ```

    The response contains `answer_accuracy`, `document_accuracy` and `n_feedback`.
    Both accuracies are `None` when no label matches the filters.
    """
    # Work on a copy so the request object is never mutated, and tolerate a request
    # whose `filters` attribute is unset (None).
    label_filters = dict(filters.filters or {}) if filters is not None else {}
    # Only labels written by the /feedback endpoint are evaluated; this deliberately
    # overrides any "origin" supplied by the caller.
    label_filters["origin"] = ["user-feedback"]

    labels = DOCUMENT_STORE.get_all_labels(filters=label_filters)
    n_labels = len(labels)
    if n_labels == 0:
        return {"answer_accuracy": None, "document_accuracy": None, "n_feedback": 0}

    n_correct_answers = sum(1 for label in labels if label.is_correct_answer)
    n_correct_documents = sum(1 for label in labels if label.is_correct_document)
    return {
        "answer_accuracy": n_correct_answers / n_labels,
        "document_accuracy": n_correct_documents / n_labels,
        "n_feedback": n_labels,
    }


def _context_window(text: str, answer_start: int, answer_len: int, context_size: int) -> Tuple[int, int]:
    """
    Compute slice bounds of a context window of roughly `context_size` characters around an answer.

    The same amount of context is added before and after the answer. Where the window would run
    past one end of `text`, the unused budget is added at the other end instead.

    :param text: Full document text.
    :param answer_start: Character offset of the answer within `text`.
    :param answer_len: Length of the answer string.
    :param context_size: Target length of the window, including the answer itself.
    :return: `(start, end)` such that `text[start:end]` is the window; `end` is exclusive.
    """
    text_len = len(text)
    answer_end = answer_start + answer_len
    # Never negative: a context_size smaller than the answer yields just the answer.
    padding = max((context_size - answer_len) // 2, 0)

    start = max(answer_start - padding, 0)
    # `end` is an exclusive slice bound, so it may be as large as len(text).
    end = min(answer_end + padding, text_len)

    # Budget that could not be spent on one side is shifted to the other side.
    unused_before = max(padding - answer_start, 0)
    unused_after = max(answer_end + padding - text_len, 0)
    start = max(start - unused_after, 0)
    end = min(end + unused_before, text_len)
    return start, end


@router.get("/export-feedback")
def export_extractive_qa_feedback(
    context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
):
    """
    SQuAD format JSON export for question/answer pairs that were marked as "relevant".

    The context_size param can be used to limit response size for large documents.

    :param context_size: Target length (in characters, answer included) of the exported context.
                         Only used when `full_document_context` is False.
    :param full_document_context: If True, export the whole document text as context.
    :param only_positive_labels: If True, only export labels whose answer was marked as correct.
    :return: Dict of the form `{"data": [...]}`. The same structure is also written to
             `feedback_squad_direct.json` in the current working directory.
    :raises HTTPException: With status 500 if a label references a document that cannot be found.
    """
    if only_positive_labels:
        labels = DOCUMENT_STORE.get_all_labels(filters={"is_correct_answer": [True], "origin": ["user-feedback"]})
    else:
        labels = DOCUMENT_STORE.get_all_labels(filters={"origin": ["user-feedback"]})
        # Filter out the labels where the passage is correct but answer is wrong (in SQuAD this matches
        # neither a "positive example" nor a negative "is_impossible" one)
        labels = [
            label for label in labels if not (label.is_correct_document is True and label.is_correct_answer is False)
        ]

    export_data = []
    for label in labels:
        document = DOCUMENT_STORE.get_document_by_id(label.document_id)
        if document is None:
            raise HTTPException(
                status_code=500,
                detail=f"Could not find document with id {label.document_id} for label id {label.id}",
            )

        if full_document_context:
            context = document.text
            answer_start = label.offset_start_in_doc
        else:
            start_pos, end_pos = _context_window(
                document.text, label.offset_start_in_doc, len(label.answer), context_size
            )
            context = document.text[start_pos:end_pos]
            # Offsets in the export are relative to the exported context, not the full document.
            answer_start = label.offset_start_in_doc - start_pos

        is_impossible = label.is_correct_answer is False and label.is_correct_document is False
        if is_impossible:
            # No answer
            answers = []
        else:
            # Quality check: the offsets must reproduce the answer string exactly. It only applies
            # to answerable labels, since "is_impossible" ones carry no answer to verify.
            extracted = context[answer_start : answer_start + len(label.answer)]
            if extracted != label.answer:
                logger.error(
                    f"Skipping invalid squad label as string via offsets "
                    f"('{extracted}') does not match answer string ('{label.answer}') "
                )
                continue
            answers = [{"text": label.answer, "answer_start": answer_start}]

        export_data.append(
            {
                "paragraphs": [
                    {
                        "context": context,
                        "id": label.document_id,
                        "qas": [
                            {
                                "question": label.question,
                                "id": label.id,
                                "is_impossible": is_impossible,
                                "answers": answers,
                            }
                        ],
                    }
                ]
            }
        )

    export = {"data": export_data}

    # Write the same structure that is returned, so the file has the top-level "data" key.
    with open("feedback_squad_direct.json", "w", encoding="utf8") as f:
        json.dump(export, f, ensure_ascii=False, sort_keys=True, indent=4)
    return export