From c29f960c4792c88f5f69324912d40fa0649208f7 Mon Sep 17 00:00:00 2001 From: Sara Zan Date: Mon, 29 Nov 2021 17:03:54 +0100 Subject: [PATCH] Fix UI demo feedback (#1816) * Fix the feedback function of the demo with a workaround * Some docstring * Update tests and rename methods in feedback.py * Fix tests * Remove operation_ids * Add a couple of status code checks --- rest_api/controller/feedback.py | 12 ++--- test/test_rest_api.py | 34 ++++++++----- ui/utils.py | 88 ++++++++++++++++++++------------- ui/webapp.py | 68 +++++++++++-------------- 4 files changed, 111 insertions(+), 91 deletions(-) diff --git a/rest_api/controller/feedback.py b/rest_api/controller/feedback.py index e9ad820ad..715143080 100644 --- a/rest_api/controller/feedback.py +++ b/rest_api/controller/feedback.py @@ -11,21 +11,21 @@ router = APIRouter() logger = logging.getLogger(__name__) -@router.post("/feedback", operation_id="post_feedback") -def user_feedback(feedback: LabelSerialized): +@router.post("/feedback") +def post_feedback(feedback: LabelSerialized): if feedback.origin is None: feedback.origin = "user-feedback" DOCUMENT_STORE.write_labels([feedback]) -@router.get("/feedback", operation_id="get_feedback") -def user_feedback(): +@router.get("/feedback") +def get_feedback(): labels = DOCUMENT_STORE.get_all_labels() return labels -@router.post("/eval-feedback", operation_id="get_feedback_metrics") -def eval_extractive_qa_feedback(filters: FilterRequest = None): +@router.post("/eval-feedback") +def get_feedback_metrics(filters: FilterRequest = None): """ Return basic accuracy metrics based on the user feedback. Which ratio of answers was correct? Which ratio of documents was correct? diff --git a/test/test_rest_api.py b/test/test_rest_api.py index 4b2e56c33..37d28bdf7 100644 --- a/test/test_rest_api.py +++ b/test/test_rest_api.py @@ -3,7 +3,6 @@ from pathlib import Path import pytest from fastapi.testclient import TestClient -from haystack.schema import Label from rest_api.application import app @@ -16,7 +15,9 @@ FEEDBACK={ "content_type": "text", "score": None, "id": "fc18c987a8312e72a47fb1524f230bb0", - "meta": {} + "meta": {}, + "embedding": None, + "id_hash_keys": None }, "answer": { @@ -25,7 +26,9 @@ FEEDBACK={ "context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye", "offsets_in_context": [{"start": 60, "end": 73}], "offsets_in_document": [{"start": 60, "end": 73}], - "document_id": "fc18c987a8312e72a47fb1524f230bb0" + "document_id": "fc18c987a8312e72a47fb1524f230bb0", + "meta": {}, + "score": None }, "is_correct_answer": True, "is_correct_document": True, @@ -138,13 +141,12 @@ def test_delete_documents(): assert 200 == response.status_code response_json = response.json() assert len(response_json) == 1 - + def test_file_upload(client: TestClient): file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')} - response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'}) + response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value", "non-existing-field": "wrong-value"}'}) assert 200 == response.status_code - client.post(url="/documents/delete_by_filters", data='{"filters": {}}') def test_query_with_no_filter(populated_client: TestClient): @@ -204,12 +206,16 @@ def test_write_feedback(populated_client: TestClient): def test_get_feedback(client: TestClient): response = client.post(url="/feedback", json=FEEDBACK) - resp = client.get(url="/feedback") - labels = [Label.from_dict(i) for i in resp.json()] + assert response.status_code == 200 + response = client.get(url="/feedback") + assert response.status_code == 200 + json_response = response.json() + for response_item, expected_item in [(json_response[0][key], value) for key, value in FEEDBACK.items()]: + assert response_item == expected_item -def test_export_feedback(populated_client: TestClient): - response = populated_client.post(url="/feedback", json=FEEDBACK) +def test_export_feedback(client: TestClient): + response = client.post(url="/feedback", json=FEEDBACK) assert 200 == response.status_code feedback_urls = [ @@ -218,10 +224,16 @@ def test_export_feedback(populated_client: TestClient): "/export-feedback?full_document_context=false&context_size=50000", ] for url in feedback_urls: - response = populated_client.get(url=url, json=FEEDBACK) + response = client.get(url=url, json=FEEDBACK) response_json = response.json() context = response_json["data"][0]["paragraphs"][0]["context"] answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"] answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"] assert context[answer_start:answer_start+len(answer)] == answer + +def test_get_feedback_malformed_query(client: TestClient): + feedback = FEEDBACK.copy() + feedback["unexpected_field"] = "misplaced-value" + response = client.post(url="/feedback", json=feedback) + assert response.status_code == 422 diff --git a/ui/utils.py b/ui/utils.py index 6cbb0a695..fbeb34d69 100644 --- a/ui/utils.py +++ b/ui/utils.py @@ -1,9 +1,12 @@ -import os +from typing import List, Dict, Any, Tuple +import os import logging import requests +from uuid import uuid4 import streamlit as st + API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000") STATUS = "initialized" HS_VERSION = "hs_version" @@ -13,6 +16,9 @@ DOC_UPLOAD = "file-upload" def haystack_is_ready(): + """ + Used to show the "Haystack is loading..." message + """ url = f"{API_ENDPOINT}/{STATUS}" try: if requests.get(url).json(): @@ -21,68 +27,80 @@ def haystack_is_ready(): logging.exception(e) return False + @st.cache def haystack_version(): + """ + Get the Haystack version from the REST API + """ url = f"{API_ENDPOINT}/{HS_VERSION}" return requests.get(url, timeout=0.1).json()["hs_version"] -def retrieve_doc(query, filters={}, top_k_reader=5, top_k_retriever=5): - # Query Haystack API + +def query(query, filters={}, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + """ + Send a query to the REST API and parse the answer. + Returns both a ready-to-use representation of the results and the raw JSON. + """ + url = f"{API_ENDPOINT}/{DOC_REQUEST}" params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}} req = {"query": query, "params": params} response_raw = requests.post(url, json=req).json() - # Format response - result = [] + if response_raw.status_code >= 400: + raise Exception(f"{response_raw}") - if "errors" in response_raw: - raise Exception(", ".join(response_raw["errors"])) - - answers = response_raw["answers"] - for i in range(len(answers)): - answer = answers[i] - answer_text = answer.get("answer", None) - if answer_text: - result.append( + response = requests.post(url, json=req).json() + if "errors" in response: + raise Exception(", ".join(response["errors"])) + + # Format response + results = [] + answers = response["answers"] + for answer in answers: + if answer.get("answer", None): + results.append( { "context": "..." + answer["context"] + "...", - "answer": answer_text, + "answer": answer.get("answer", None), "source": answer["meta"]["name"], "relevance": round(answer["score"] * 100, 2), - "document_id": answer["document_id"], + "document": [doc for doc in response_raw["documents"] if doc["id"] == answer["document_id"]][0], "offset_start_in_doc": answer["offsets_in_document"][0]["start"], + "_raw": answer } ) else: - result.append( + results.append( { "context": None, "answer": None, + "document": None, "relevance": round(answer["score"] * 100, 2), + "_raw": answer, } ) - return result, response_raw + return results, response_raw -def feedback_doc(question, is_correct_answer, document_id, model_id, is_correct_document, answer, offset_start_in_doc): - # Feedback Haystack API - try: - url = f"{API_ENDPOINT}/{DOC_FEEDBACK}" - #TODO adjust after Label refactoring - req = { - "question": question, - "is_correct_answer": is_correct_answer, - "document_id": document_id, - "model_id": model_id, - "is_correct_document": is_correct_document, - "answer": answer, - "offset_start_in_doc": offset_start_in_doc, +def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None: + """ + Send a feedback (label) to the REST API + """ + url = f"{API_ENDPOINT}/{DOC_FEEDBACK}" + req = { + "id": str(uuid4()), + "query": query, + "document": document, + "is_correct_answer": is_correct_answer, + "is_correct_document": is_correct_document, + "origin": "user-feedback", + "answer": answer_obj } - response_raw = requests.post(url, json=req).json() - return response_raw - except Exception as e: - logging.exception(e) + response_raw = requests.post(url, json=req) + if response_raw.status_code >= 400: + raise ValueError(f"An error was returned [code {response_raw.status_code}]: {response_raw.json()}") def upload_doc(file): diff --git a/ui/webapp.py b/ui/webapp.py index 903fa51b4..28110d7f1 100644 --- a/ui/webapp.py +++ b/ui/webapp.py @@ -1,7 +1,6 @@ import os import sys -import html import logging import pandas as pd from json import JSONDecodeError @@ -9,13 +8,12 @@ from pathlib import Path import streamlit as st from annotated_text import annotation from markdown import markdown -from htbuilder import H # streamlit does not support any states out of the box. On every button click, streamlit reload the whole page # and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned # here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92 import SessionState -from utils import HS_VERSION, feedback_doc, haystack_is_ready, retrieve_doc, upload_doc, haystack_version +from utils import HS_VERSION, haystack_is_ready, query, send_feedback, upload_doc, haystack_version # Adjust to a question that you would like users to see in the search bar when they load the UI: @@ -154,7 +152,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t "Check out the docs: https://haystack.deepset.ai/usage/optimization " ): try: - state.results, state.raw_json = retrieve_doc(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever) + state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever) except JSONDecodeError as je: st.error("👓    An error occurred reading the results. Is the document store working?") return @@ -163,20 +161,19 @@ Ask any question on this topic and see if Haystack can find the correct answer t if "The server is busy processing requests" in str(e): st.error("🧑‍🌾    All our workers are busy! Try again later.") else: - st.error("🐞    An error occurred during the request. Check the logs in the console to know more.") + st.error("🐞    An error occurred during the request.") return if state.results: # Show the gold answer if we use a question of the given set - if question == state.random_question and eval_mode: + if question == state.random_question and eval_mode and state.random_answer: st.write("## Correct answers:") st.write(state.random_answer) st.write("## Results:") - count = 0 # Make every button key unique - for result in state.results: + for count, result in enumerate(state.results): if result["answer"]: answer, context = result["answer"], result["context"] start_idx = context.find(answer) @@ -191,43 +188,36 @@ Ask any question on this topic and see if Haystack can find the correct answer t if eval_mode: # Define columns for buttons + is_correct_answer = None + is_correct_document = None + button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6]) if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"): - feedback_doc( - question=question, - is_correct_answer="true", - document_id=result.get("document_id", None), - model_id=1, - is_correct_document="true", - answer=result["answer"], - offset_start_in_doc=result.get("offset_start_in_doc", None) - ) - st.success("✨    Thanks for your feedback!    ✨") + is_correct_answer=True + is_correct_document=True if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"): - feedback_doc( - question=question, - is_correct_answer="false", - document_id=result.get("document_id", None), - model_id=1, - is_correct_document="false", - answer=result["answer"], - offset_start_in_doc=result.get("offset_start_in_doc", None) - ) - st.success("✨    Thanks for your feedback!    ✨") + is_correct_answer=False + is_correct_document=False if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"): - feedback_doc( - question=question, - is_correct_answer="false", - document_id=result.get("document_id", None), - model_id=1, - is_correct_document="true", - answer=result["answer"], - offset_start_in_doc=result.get("offset_start_in_doc", None) - ) - st.success("✨    Thanks for your feedback!    ✨") - count += 1 + is_correct_answer=False + is_correct_document=True + + if is_correct_answer is not None and is_correct_document is not None: + try: + send_feedback( + query=question, + answer_obj=result["_raw"], + is_correct_answer=is_correct_answer, + is_correct_document=is_correct_document, + document=result["document"] + ) + st.success("✨    Thanks for your feedback!    ✨") + except Exception as e: + logging.exception(e) + st.error("🐞    An error occurred while submitting your feedback!") + st.write("___") if debug: