Fix UI demo feedback (#1816)

* Fix the feedback function of the demo with a workaround

* Some docstring

* Update tests and rename methods in feedback.py

* Fix tests

* Remove operation_ids

* Add a couple of status code checks
Sara Zan authored on 2021-11-29 17:03:54 +01:00, committed by GitHub
parent 84147edcca
commit c29f960c47
4 changed files with 111 additions and 91 deletions

rest_api/controller/feedback.py

@@ -11,21 +11,21 @@ router = APIRouter()
 logger = logging.getLogger(__name__)
 
 
-@router.post("/feedback", operation_id="post_feedback")
-def user_feedback(feedback: LabelSerialized):
+@router.post("/feedback")
+def post_feedback(feedback: LabelSerialized):
     if feedback.origin is None:
         feedback.origin = "user-feedback"
     DOCUMENT_STORE.write_labels([feedback])
 
 
-@router.get("/feedback", operation_id="get_feedback")
-def user_feedback():
+@router.get("/feedback")
+def get_feedback():
     labels = DOCUMENT_STORE.get_all_labels()
     return labels
 
 
-@router.post("/eval-feedback", operation_id="get_feedback_metrics")
-def eval_extractive_qa_feedback(filters: FilterRequest = None):
+@router.post("/eval-feedback")
+def get_feedback_metrics(filters: FilterRequest = None):
     """
     Return basic accuracy metrics based on the user feedback.
     Which ratio of answers was correct? Which ratio of documents was correct?
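For reference, this is how a client can exercise the renamed routes — a minimal sketch assuming the REST API runs on the demo default http://localhost:8000 and that `label` holds a serialized Label such as the FEEDBACK fixture in the tests below. Without explicit operation_id arguments, FastAPI derives the OpenAPI operation ids from the function names (post_feedback, get_feedback, get_feedback_metrics):

    import requests

    API = "http://localhost:8000"  # demo default; adjust to your deployment

    label = {}  # fill with a serialized Label, e.g. the FEEDBACK fixture from the tests below

    # Store a label (POST /feedback)
    requests.post(f"{API}/feedback", json=label).raise_for_status()

    # Read back all stored labels (GET /feedback)
    labels = requests.get(f"{API}/feedback").json()

    # Aggregate accuracy metrics over the stored labels (POST /eval-feedback);
    # the empty filter set is an assumption about FilterRequest's shape.
    metrics = requests.post(f"{API}/eval-feedback", json={"filters": {}}).json()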

rest_api/test/test_rest_api.py

@@ -3,7 +3,6 @@ from pathlib import Path
 
 import pytest
 from fastapi.testclient import TestClient
-from haystack.schema import Label
 
 from rest_api.application import app
@@ -16,7 +15,9 @@ FEEDBACK={
         "content_type": "text",
         "score": None,
         "id": "fc18c987a8312e72a47fb1524f230bb0",
-        "meta": {}
+        "meta": {},
+        "embedding": None,
+        "id_hash_keys": None
     },
     "answer":
     {
@@ -25,7 +26,9 @@ FEEDBACK={
         "context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye",
         "offsets_in_context": [{"start": 60, "end": 73}],
         "offsets_in_document": [{"start": 60, "end": 73}],
-        "document_id": "fc18c987a8312e72a47fb1524f230bb0"
+        "document_id": "fc18c987a8312e72a47fb1524f230bb0",
+        "meta": {},
+        "score": None
     },
     "is_correct_answer": True,
     "is_correct_document": True,
@@ -138,13 +141,12 @@ def test_delete_documents():
     assert 200 == response.status_code
     response_json = response.json()
     assert len(response_json) == 1
 
 
 def test_file_upload(client: TestClient):
     file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')}
-    response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'})
+    response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value", "non-existing-field": "wrong-value"}'})
     assert 200 == response.status_code
     client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
 
 
 def test_query_with_no_filter(populated_client: TestClient):
@@ -204,12 +206,16 @@ def test_write_feedback(populated_client: TestClient):
 
 def test_get_feedback(client: TestClient):
     response = client.post(url="/feedback", json=FEEDBACK)
-    resp = client.get(url="/feedback")
-    labels = [Label.from_dict(i) for i in resp.json()]
+    assert response.status_code == 200
+    response = client.get(url="/feedback")
+    assert response.status_code == 200
+    json_response = response.json()
+    for response_item, expected_item in [(json_response[0][key], value) for key, value in FEEDBACK.items()]:
+        assert response_item == expected_item
 
 
-def test_export_feedback(populated_client: TestClient):
-    response = populated_client.post(url="/feedback", json=FEEDBACK)
+def test_export_feedback(client: TestClient):
+    response = client.post(url="/feedback", json=FEEDBACK)
     assert 200 == response.status_code
 
     feedback_urls = [
@@ -218,10 +224,16 @@ def test_export_feedback(populated_client: TestClient):
         "/export-feedback?full_document_context=false&context_size=50000",
     ]
     for url in feedback_urls:
-        response = populated_client.get(url=url, json=FEEDBACK)
+        response = client.get(url=url, json=FEEDBACK)
         response_json = response.json()
         context = response_json["data"][0]["paragraphs"][0]["context"]
         answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
         answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
         assert context[answer_start:answer_start+len(answer)] == answer
+
+
+def test_get_feedback_malformed_query(client: TestClient):
+    feedback = FEEDBACK.copy()
+    feedback["unexpected_field"] = "misplaced-value"
+    response = client.post(url="/feedback", json=feedback)
+    assert response.status_code == 422
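The new test_get_feedback_malformed_query expects FastAPI to answer 422 when the payload carries a field the model does not declare. That behaviour requires the request model to forbid extra fields; a minimal self-contained sketch of the mechanism (LabelIn and its fields are illustrative, not the actual LabelSerialized definition):

    from fastapi import FastAPI
    from pydantic import BaseModel, Extra

    app = FastAPI()

    class LabelIn(BaseModel):
        # illustrative subset of a serialized Label
        query: str
        is_correct_answer: bool

        class Config:
            extra = Extra.forbid  # unknown fields fail validation -> HTTP 422

    @app.post("/feedback")
    def post_feedback(feedback: LabelIn):
        return {"ok": True}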

ui/utils.py

@@ -1,9 +1,12 @@
-import os
+from typing import List, Dict, Any, Tuple
+
+import os
 import logging
 import requests
+from uuid import uuid4
 import streamlit as st
 
 
 API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
 STATUS = "initialized"
 HS_VERSION = "hs_version"
@@ -13,6 +16,9 @@ DOC_UPLOAD = "file-upload"
 
 
 def haystack_is_ready():
+    """
+    Used to show the "Haystack is loading..." message
+    """
     url = f"{API_ENDPOINT}/{STATUS}"
     try:
         if requests.get(url).json():
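A small usage sketch of this readiness probe — the blocking loop is illustrative; the demo UI itself calls haystack_is_ready() behind a Streamlit spinner instead:

    import time

    from utils import haystack_is_ready  # the helper defined above

    while not haystack_is_ready():
        print("Haystack is loading...")
        time.sleep(1)
    print("Ready to accept queries.")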
@@ -21,68 +27,80 @@ def haystack_is_ready():
         logging.exception(e)
     return False
 
 
 @st.cache
 def haystack_version():
+    """
+    Get the Haystack version from the REST API
+    """
     url = f"{API_ENDPOINT}/{HS_VERSION}"
     return requests.get(url, timeout=0.1).json()["hs_version"]
 
 
-def retrieve_doc(query, filters={}, top_k_reader=5, top_k_retriever=5):
-    # Query Haystack API
+def query(query, filters={}, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
+    """
+    Send a query to the REST API and parse the answer.
+    Returns both a ready-to-use representation of the results and the raw JSON.
+    """
     url = f"{API_ENDPOINT}/{DOC_REQUEST}"
     params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
     req = {"query": query, "params": params}
-    response_raw = requests.post(url, json=req).json()
-
-    # Format response
-    result = []
-    if response_raw.status_code >= 400:
-        raise Exception(f"{response_raw}")
-    if "errors" in response_raw:
-        raise Exception(", ".join(response_raw["errors"]))
-    answers = response_raw["answers"]
-    for i in range(len(answers)):
-        answer = answers[i]
-        answer_text = answer.get("answer", None)
-        if answer_text:
-            result.append(
+    response = requests.post(url, json=req).json()
+    if "errors" in response:
+        raise Exception(", ".join(response["errors"]))
+
+    # Format response
+    results = []
+    answers = response["answers"]
+    for answer in answers:
+        if answer.get("answer", None):
+            results.append(
                 {
                     "context": "..." + answer["context"] + "...",
-                    "answer": answer_text,
+                    "answer": answer.get("answer", None),
                     "source": answer["meta"]["name"],
                     "relevance": round(answer["score"] * 100, 2),
-                    "document_id": answer["document_id"],
+                    "document": [doc for doc in response_raw["documents"] if doc["id"] == answer["document_id"]][0],
                     "offset_start_in_doc": answer["offsets_in_document"][0]["start"],
                     "_raw": answer
                 }
             )
         else:
-            result.append(
+            results.append(
                 {
                     "context": None,
                     "answer": None,
+                    "document": None,
                     "relevance": round(answer["score"] * 100, 2),
                     "_raw": answer,
                 }
            )
-    return result, response_raw
+    return results, response_raw
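A usage sketch of the reworked helper, assuming a running REST API (the question string is arbitrary). Entries without an answer carry None placeholders, so the guard below is needed before touching keys like "source"; "_raw" and "document" hold the untouched answer and matching document dicts that the feedback call relies on:

    results, raw_json = query("Who created the PDF specification?", top_k_reader=3, top_k_retriever=5)
    for result in results:
        if result["answer"]:
            print(f'{result["relevance"]}%: {result["answer"]} (source: {result["source"]})')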
-def feedback_doc(question, is_correct_answer, document_id, model_id, is_correct_document, answer, offset_start_in_doc):
-    # Feedback Haystack API
-    try:
-        url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
-        #TODO adjust after Label refactoring
-        req = {
-            "question": question,
-            "is_correct_answer": is_correct_answer,
-            "document_id": document_id,
-            "model_id": model_id,
-            "is_correct_document": is_correct_document,
-            "answer": answer,
-            "offset_start_in_doc": offset_start_in_doc,
+def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None:
+    """
+    Send a feedback (label) to the REST API
+    """
+    url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
+    req = {
+        "id": str(uuid4()),
+        "query": query,
+        "document": document,
+        "is_correct_answer": is_correct_answer,
+        "is_correct_document": is_correct_document,
+        "origin": "user-feedback",
+        "answer": answer_obj
+    }
-        response_raw = requests.post(url, json=req).json()
-        return response_raw
-    except Exception as e:
-        logging.exception(e)
+    response_raw = requests.post(url, json=req)
+    if response_raw.status_code >= 400:
+        raise ValueError(f"An error was returned [code {response_raw.status_code}]: {response_raw.json()}")
 
 
 def upload_doc(file):
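Since send_feedback now raises instead of swallowing errors, callers decide how to surface failures. Continuing the query() sketch above (the indexing into results is illustrative):

    import logging

    try:
        send_feedback(
            query="Who created the PDF specification?",
            answer_obj=results[0]["_raw"],      # raw answer dict returned by query()
            is_correct_answer=True,
            is_correct_document=True,
            document=results[0]["document"],    # matching document dict returned by query()
        )
    except ValueError as error:
        logging.exception(error)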

ui/webapp.py

@@ -1,7 +1,6 @@
 import os
 import sys
-import html
 import logging
 import pandas as pd
 from json import JSONDecodeError
@@ -9,13 +8,12 @@ from pathlib import Path
 
 import streamlit as st
 from annotated_text import annotation
 from markdown import markdown
-from htbuilder import H
 
 # streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
 # and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
 # here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
 import SessionState
-from utils import HS_VERSION, feedback_doc, haystack_is_ready, retrieve_doc, upload_doc, haystack_version
+from utils import HS_VERSION, haystack_is_ready, query, send_feedback, upload_doc, haystack_version
 
 
 # Adjust to a question that you would like users to see in the search bar when they load the UI:
@@ -154,7 +152,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
         "Check out the docs: https://haystack.deepset.ai/usage/optimization "
     ):
         try:
-            state.results, state.raw_json = retrieve_doc(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
+            state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
         except JSONDecodeError as je:
             st.error("👓    An error occurred reading the results. Is the document store working?")
             return
@@ -163,20 +161,19 @@ Ask any question on this topic and see if Haystack can find the correct answer t
             if "The server is busy processing requests" in str(e):
                 st.error("🧑‍🌾    All our workers are busy! Try again later.")
             else:
-                st.error("🐞    An error occurred during the request. Check the logs in the console to know more.")
+                st.error("🐞    An error occurred during the request.")
             return
 
     if state.results:
 
         # Show the gold answer if we use a question of the given set
-        if question == state.random_question and eval_mode:
+        if question == state.random_question and eval_mode and state.random_answer:
            st.write("## Correct answers:")
            st.write(state.random_answer)
 
         st.write("## Results:")
-        count = 0  # Make every button key unique
-        for result in state.results:
+        for count, result in enumerate(state.results):
            if result["answer"]:
                answer, context = result["answer"], result["context"]
                start_idx = context.find(answer)
@@ -191,43 +188,36 @@ Ask any question on this topic and see if Haystack can find the correct answer t
            if eval_mode:
                # Define columns for buttons
+               is_correct_answer = None
+               is_correct_document = None
                button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6])
                if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"):
-                   feedback_doc(
-                       question=question,
-                       is_correct_answer="true",
-                       document_id=result.get("document_id", None),
-                       model_id=1,
-                       is_correct_document="true",
-                       answer=result["answer"],
-                       offset_start_in_doc=result.get("offset_start_in_doc", None)
-                   )
-                   st.success("✨    Thanks for your feedback!    ✨")
+                   is_correct_answer = True
+                   is_correct_document = True
                if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"):
-                   feedback_doc(
-                       question=question,
-                       is_correct_answer="false",
-                       document_id=result.get("document_id", None),
-                       model_id=1,
-                       is_correct_document="false",
-                       answer=result["answer"],
-                       offset_start_in_doc=result.get("offset_start_in_doc", None)
-                   )
-                   st.success("✨    Thanks for your feedback!    ✨")
+                   is_correct_answer = False
+                   is_correct_document = False
                if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"):
-                   feedback_doc(
-                       question=question,
-                       is_correct_answer="false",
-                       document_id=result.get("document_id", None),
-                       model_id=1,
-                       is_correct_document="true",
-                       answer=result["answer"],
-                       offset_start_in_doc=result.get("offset_start_in_doc", None)
-                   )
-                   st.success("✨    Thanks for your feedback!    ✨")
-                   count += 1
+                   is_correct_answer = False
+                   is_correct_document = True
+               if is_correct_answer is not None and is_correct_document is not None:
+                   try:
+                       send_feedback(
+                           query=question,
+                           answer_obj=result["_raw"],
+                           is_correct_answer=is_correct_answer,
+                           is_correct_document=is_correct_document,
+                           document=result["document"]
+                       )
+                       st.success("✨    Thanks for your feedback!    ✨")
+                   except Exception as e:
+                       logging.exception(e)
+                       st.error("🐞    An error occurred while submitting your feedback!")
 
            st.write("___")
 
        if debug:
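The restructuring works around Streamlit's execution model: every button click reruns the whole script, and each st.button() returns True only during the rerun triggered by its own click, so calling the API inside three separate button branches duplicated code and state handling. Recording the verdict first and submitting once keeps a single code path. A distilled, standalone sketch of the pattern (outside the demo):

    import streamlit as st

    verdict = None
    thumbs_up, thumbs_down = st.columns(2)
    if thumbs_up.button("👍"):
        verdict = (True, True)    # correct answer, correct document
    if thumbs_down.button("👎"):
        verdict = (False, False)  # wrong answer, wrong document

    # At most one button fired in this rerun; submit feedback once, at the end.
    if verdict is not None:
        is_correct_answer, is_correct_document = verdict
        st.success("✨    Thanks for your feedback!    ✨")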