Fix UI demo feedback (#1816)

* Fix the feedback function of the demo with a workaround

* Some docstring

* Update tests and rename methods in feedback.py

* Fix tests

* Remove operation_ids

* Add a couple of status code checks
Sara Zan authored on 2021-11-29 17:03:54 +01:00 (committed by GitHub)
parent 84147edcca
commit c29f960c47
4 changed files with 111 additions and 91 deletions

rest_api/controller/feedback.py

@@ -11,21 +11,21 @@ router = APIRouter()
 logger = logging.getLogger(__name__)
 
 
-@router.post("/feedback", operation_id="post_feedback")
-def user_feedback(feedback: LabelSerialized):
+@router.post("/feedback")
+def post_feedback(feedback: LabelSerialized):
     if feedback.origin is None:
         feedback.origin = "user-feedback"
     DOCUMENT_STORE.write_labels([feedback])
 
 
-@router.get("/feedback", operation_id="get_feedback")
-def user_feedback():
+@router.get("/feedback")
+def get_feedback():
     labels = DOCUMENT_STORE.get_all_labels()
     return labels
 
 
-@router.post("/eval-feedback", operation_id="get_feedback_metrics")
-def eval_extractive_qa_feedback(filters: FilterRequest = None):
+@router.post("/eval-feedback")
+def get_feedback_metrics(filters: FilterRequest = None):
     """
     Return basic accuracy metrics based on the user feedback.
     Which ratio of answers was correct? Which ratio of documents was correct?
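
For orientation, the renamed handlers keep the same routes, so a client only needs plain HTTP calls. A minimal sketch, assuming the REST API listens on http://localhost:8000; the label payload here is illustrative (see the FEEDBACK fixture in the tests below for the full shape), and the body shape for FilterRequest is also an assumption:

import requests

API = "http://localhost:8000"  # assumed default address of the REST API

# POST /feedback (handler post_feedback): store one user label.
# Illustrative payload only; the real LabelSerialized schema has more fields.
label = {"query": "Who made the PDF specification?", "is_correct_answer": True,
         "is_correct_document": True, "origin": "user-feedback"}
response = requests.post(f"{API}/feedback", json=label)
response.raise_for_status()

# GET /feedback (handler get_feedback): fetch all stored labels.
labels = requests.get(f"{API}/feedback").json()

# POST /eval-feedback (handler get_feedback_metrics): accuracy metrics over the labels.
metrics = requests.post(f"{API}/eval-feedback", json={"filters": None}).json()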

rest_api/test/test_rest_api.py

@@ -3,7 +3,6 @@ from pathlib import Path
 
 import pytest
 from fastapi.testclient import TestClient
-from haystack.schema import Label
 
 from rest_api.application import app
@@ -16,7 +15,9 @@ FEEDBACK={
         "content_type": "text",
         "score": None,
         "id": "fc18c987a8312e72a47fb1524f230bb0",
-        "meta": {}
+        "meta": {},
+        "embedding": None,
+        "id_hash_keys": None
     },
     "answer":
     {
@@ -25,7 +26,9 @@ FEEDBACK={
         "context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye",
         "offsets_in_context": [{"start": 60, "end": 73}],
         "offsets_in_document": [{"start": 60, "end": 73}],
-        "document_id": "fc18c987a8312e72a47fb1524f230bb0"
+        "document_id": "fc18c987a8312e72a47fb1524f230bb0",
+        "meta": {},
+        "score": None
     },
     "is_correct_answer": True,
     "is_correct_document": True,
@@ -138,13 +141,12 @@ def test_delete_documents():
     assert 200 == response.status_code
     response_json = response.json()
     assert len(response_json) == 1
 
 
 def test_file_upload(client: TestClient):
     file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')}
-    response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'})
+    response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value", "non-existing-field": "wrong-value"}'})
     assert 200 == response.status_code
-    client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
 
 
 def test_query_with_no_filter(populated_client: TestClient):
@@ -204,12 +206,16 @@ def test_write_feedback(populated_client: TestClient):
 
 
 def test_get_feedback(client: TestClient):
     response = client.post(url="/feedback", json=FEEDBACK)
-    resp = client.get(url="/feedback")
-    labels = [Label.from_dict(i) for i in resp.json()]
+    assert response.status_code == 200
+    response = client.get(url="/feedback")
+    assert response.status_code == 200
+    json_response = response.json()
+    for response_item, expected_item in [(json_response[0][key], value) for key, value in FEEDBACK.items()]:
+        assert response_item == expected_item
 
 
-def test_export_feedback(populated_client: TestClient):
-    response = populated_client.post(url="/feedback", json=FEEDBACK)
+def test_export_feedback(client: TestClient):
+    response = client.post(url="/feedback", json=FEEDBACK)
     assert 200 == response.status_code
@@ -218,10 +224,16 @@ def test_export_feedback(populated_client: TestClient):
         "/export-feedback?full_document_context=false&context_size=50000",
     ]
     for url in feedback_urls:
-        response = populated_client.get(url=url, json=FEEDBACK)
+        response = client.get(url=url, json=FEEDBACK)
         response_json = response.json()
         context = response_json["data"][0]["paragraphs"][0]["context"]
         answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
         answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
         assert context[answer_start:answer_start+len(answer)] == answer
+
+
+def test_get_feedback_malformed_query(client: TestClient):
+    feedback = FEEDBACK.copy()
+    feedback["unexpected_field"] = "misplaced-value"
+    response = client.post(url="/feedback", json=feedback)
+    assert response.status_code == 422
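
The new test_get_feedback_malformed_query expects FastAPI to answer 422 when a label carries an unknown field. That behavior comes from request-body validation; the sketch below reproduces the mechanism with a hypothetical stand-in model (StrictLabel is not Haystack's actual LabelSerialized, which may configure this differently):

from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel, Extra

app = FastAPI()

class StrictLabel(BaseModel):
    # Hypothetical stand-in for LabelSerialized
    query: str
    is_correct_answer: bool

    class Config:
        extra = Extra.forbid  # unknown fields fail validation instead of being dropped

@app.post("/feedback")
def post_feedback(feedback: StrictLabel):
    return {"ok": True}

client = TestClient(app)
payload = {"query": "q", "is_correct_answer": True, "unexpected_field": "misplaced-value"}
# Validation fails on the extra key, so FastAPI answers 422 Unprocessable Entity
assert client.post("/feedback", json=payload).status_code == 422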

ui/utils.py

@@ -1,9 +1,12 @@
+from typing import List, Dict, Any, Tuple
+
 import os
 import logging
 import requests
+from uuid import uuid4
 
 import streamlit as st
 
 API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
 STATUS = "initialized"
 HS_VERSION = "hs_version"
@@ -13,6 +16,9 @@ DOC_UPLOAD = "file-upload"
 
 
 def haystack_is_ready():
+    """
+    Used to show the "Haystack is loading..." message
+    """
     url = f"{API_ENDPOINT}/{STATUS}"
     try:
         if requests.get(url).json():
@@ -21,68 +27,80 @@ def haystack_is_ready():
         logging.exception(e)
     return False
 
 
 @st.cache
 def haystack_version():
+    """
+    Get the Haystack version from the REST API
+    """
     url = f"{API_ENDPOINT}/{HS_VERSION}"
     return requests.get(url, timeout=0.1).json()["hs_version"]
 
 
-def retrieve_doc(query, filters={}, top_k_reader=5, top_k_retriever=5):
-    # Query Haystack API
+def query(query, filters={}, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
+    """
+    Send a query to the REST API and parse the answer.
+    Returns both a ready-to-use representation of the results and the raw JSON.
+    """
     url = f"{API_ENDPOINT}/{DOC_REQUEST}"
     params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
     req = {"query": query, "params": params}
-    response_raw = requests.post(url, json=req).json()
+    response_raw = requests.post(url, json=req)
+    if response_raw.status_code >= 400:
+        raise Exception(f"{response_raw}")
 
-    # Format response
-    result = []
-    if "errors" in response_raw:
-        raise Exception(", ".join(response_raw["errors"]))
+    response = response_raw.json()
+    if "errors" in response:
+        raise Exception(", ".join(response["errors"]))
 
-    answers = response_raw["answers"]
-    for i in range(len(answers)):
-        answer = answers[i]
-        answer_text = answer.get("answer", None)
-        if answer_text:
-            result.append(
+    # Format response
+    results = []
+    answers = response["answers"]
+    for answer in answers:
+        if answer.get("answer", None):
+            results.append(
                 {
                     "context": "..." + answer["context"] + "...",
-                    "answer": answer_text,
+                    "answer": answer.get("answer", None),
                     "source": answer["meta"]["name"],
                     "relevance": round(answer["score"] * 100, 2),
-                    "document_id": answer["document_id"],
+                    "document": [doc for doc in response["documents"] if doc["id"] == answer["document_id"]][0],
                     "offset_start_in_doc": answer["offsets_in_document"][0]["start"],
+                    "_raw": answer
                 }
             )
         else:
-            result.append(
+            results.append(
                 {
                     "context": None,
                     "answer": None,
+                    "document": None,
                     "relevance": round(answer["score"] * 100, 2),
+                    "_raw": answer,
                 }
             )
-    return result, response_raw
+    return results, response
 
 
-def feedback_doc(question, is_correct_answer, document_id, model_id, is_correct_document, answer, offset_start_in_doc):
-    # Feedback Haystack API
-    try:
-        url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
-        # TODO adjust after Label refactoring
-        req = {
-            "question": question,
-            "is_correct_answer": is_correct_answer,
-            "document_id": document_id,
-            "model_id": model_id,
-            "is_correct_document": is_correct_document,
-            "answer": answer,
-            "offset_start_in_doc": offset_start_in_doc,
-        }
-        response_raw = requests.post(url, json=req).json()
-        return response_raw
-    except Exception as e:
-        logging.exception(e)
+def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None:
+    """
+    Send a feedback (label) to the REST API
+    """
+    url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
+    req = {
+        "id": str(uuid4()),
+        "query": query,
+        "document": document,
+        "is_correct_answer": is_correct_answer,
+        "is_correct_document": is_correct_document,
+        "origin": "user-feedback",
+        "answer": answer_obj,
+    }
+    response_raw = requests.post(url, json=req)
+    if response_raw.status_code >= 400:
+        raise ValueError(f"An error was returned [code {response_raw.status_code}]: {response_raw.json()}")
 
 
 def upload_doc(file):
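
Taken together, the two reworked helpers cover the whole feedback loop of the UI. A minimal usage sketch, assuming the REST API is running with documents already indexed; the question string is only an example:

from utils import query, send_feedback

question = "Who made the PDF specification available?"  # example, not from the demo
results, raw_json = query(question, top_k_reader=3, top_k_retriever=5)

best = results[0]
print(best["answer"], best["relevance"])

# answer_obj must be the raw answer dict ("_raw") and document the matching
# document dict, exactly as query() returns them, so the REST API can store
# a complete label.
send_feedback(
    query=question,
    answer_obj=best["_raw"],
    is_correct_answer=True,
    is_correct_document=True,
    document=best["document"],
)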

ui/webapp.py

@@ -1,7 +1,6 @@
 import os
 import sys
-import html
 import logging
 import pandas as pd
 from json import JSONDecodeError
@@ -9,13 +8,12 @@ from pathlib import Path
 
 import streamlit as st
 from annotated_text import annotation
 from markdown import markdown
-from htbuilder import H
 
 # streamlit does not support any states out of the box. On every button click, streamlit reloads the whole page
 # and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
 # here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
 import SessionState
-from utils import HS_VERSION, feedback_doc, haystack_is_ready, retrieve_doc, upload_doc, haystack_version
+from utils import HS_VERSION, haystack_is_ready, query, send_feedback, upload_doc, haystack_version
 
 # Adjust to a question that you would like users to see in the search bar when they load the UI:
@@ -154,7 +152,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
     "Check out the docs: https://haystack.deepset.ai/usage/optimization "
 ):
     try:
-        state.results, state.raw_json = retrieve_doc(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
+        state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
     except JSONDecodeError as je:
         st.error("👓    An error occurred reading the results. Is the document store working?")
         return
@@ -163,20 +161,19 @@ Ask any question on this topic and see if Haystack can find the correct answer t
         if "The server is busy processing requests" in str(e):
             st.error("🧑‍🌾    All our workers are busy! Try again later.")
         else:
-            st.error("🐞    An error occurred during the request. Check the logs in the console to know more.")
+            st.error("🐞    An error occurred during the request.")
         return
 
 if state.results:
 
     # Show the gold answer if we use a question of the given set
-    if question == state.random_question and eval_mode:
+    if question == state.random_question and eval_mode and state.random_answer:
         st.write("## Correct answers:")
         st.write(state.random_answer)
 
     st.write("## Results:")
-    count = 0
-    for result in state.results:
+    # Make every button key unique
+    for count, result in enumerate(state.results):
         if result["answer"]:
             answer, context = result["answer"], result["context"]
             start_idx = context.find(answer)
@@ -191,43 +188,36 @@ Ask any question on this topic and see if Haystack can find the correct answer t
             if eval_mode:
                 # Define columns for buttons
+                is_correct_answer = None
+                is_correct_document = None
                 button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6])
                 if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"):
-                    feedback_doc(
-                        question=question,
-                        is_correct_answer="true",
-                        document_id=result.get("document_id", None),
-                        model_id=1,
-                        is_correct_document="true",
-                        answer=result["answer"],
-                        offset_start_in_doc=result.get("offset_start_in_doc", None)
-                    )
-                    st.success("✨    Thanks for your feedback!    ✨")
+                    is_correct_answer = True
+                    is_correct_document = True
                 if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"):
-                    feedback_doc(
-                        question=question,
-                        is_correct_answer="false",
-                        document_id=result.get("document_id", None),
-                        model_id=1,
-                        is_correct_document="false",
-                        answer=result["answer"],
-                        offset_start_in_doc=result.get("offset_start_in_doc", None)
-                    )
-                    st.success("✨    Thanks for your feedback!    ✨")
+                    is_correct_answer = False
+                    is_correct_document = False
                 if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"):
-                    feedback_doc(
-                        question=question,
-                        is_correct_answer="false",
-                        document_id=result.get("document_id", None),
-                        model_id=1,
-                        is_correct_document="true",
-                        answer=result["answer"],
-                        offset_start_in_doc=result.get("offset_start_in_doc", None)
-                    )
-                    st.success("✨    Thanks for your feedback!    ✨")
-                count += 1
+                    is_correct_answer = False
+                    is_correct_document = True
+
+                if is_correct_answer is not None and is_correct_document is not None:
+                    try:
+                        send_feedback(
+                            query=question,
+                            answer_obj=result["_raw"],
+                            is_correct_answer=is_correct_answer,
+                            is_correct_document=is_correct_document,
+                            document=result["document"]
+                        )
+                        st.success("✨    Thanks for your feedback!    ✨")
+                    except Exception as e:
+                        logging.exception(e)
+                        st.error("🐞    An error occurred while submitting your feedback!")
 
             st.write("___")
 
             if debug:
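
The restructured button handling works around Streamlit's execution model: every click reruns the whole script, so each handler only records the verdict and a single trailing block performs the submission once. The same pattern in isolation (a generic sketch, not the demo's exact code):

import streamlit as st

verdict = None  # reset on every rerun of the script

if st.button("👍", key="thumbs_up"):
    verdict = (True, True)    # correct answer, correct document
if st.button("👎", key="thumbs_down"):
    verdict = (False, False)  # wrong answer, wrong document

# One submission path, regardless of which button fired on this rerun
if verdict is not None:
    is_correct_answer, is_correct_document = verdict
    # Stand-in for send_feedback(...)
    st.write(f"submitting: answer={is_correct_answer}, document={is_correct_document}")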