mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-23 07:58:36 +00:00
Fix UI demo feedback (#1816)
* Fix the feedback function of the demo with a workaround * Some docstring * Update tests and rename methods in feedback.py * Fix tests * Remove operation_ids * Add a couple of status code checks
This commit is contained in:
parent
84147edcca
commit
c29f960c47
@ -11,21 +11,21 @@ router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/feedback", operation_id="post_feedback")
|
||||
def user_feedback(feedback: LabelSerialized):
|
||||
@router.post("/feedback")
|
||||
def post_feedback(feedback: LabelSerialized):
|
||||
if feedback.origin is None:
|
||||
feedback.origin = "user-feedback"
|
||||
DOCUMENT_STORE.write_labels([feedback])
|
||||
|
||||
|
||||
@router.get("/feedback", operation_id="get_feedback")
|
||||
def user_feedback():
|
||||
@router.get("/feedback")
|
||||
def get_feedback():
|
||||
labels = DOCUMENT_STORE.get_all_labels()
|
||||
return labels
|
||||
|
||||
|
||||
@router.post("/eval-feedback", operation_id="get_feedback_metrics")
|
||||
def eval_extractive_qa_feedback(filters: FilterRequest = None):
|
||||
@router.post("/eval-feedback")
|
||||
def get_feedback_metrics(filters: FilterRequest = None):
|
||||
"""
|
||||
Return basic accuracy metrics based on the user feedback.
|
||||
Which ratio of answers was correct? Which ratio of documents was correct?
|
||||
|
@ -3,7 +3,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from haystack.schema import Label
|
||||
|
||||
|
||||
from rest_api.application import app
|
||||
@ -16,7 +15,9 @@ FEEDBACK={
|
||||
"content_type": "text",
|
||||
"score": None,
|
||||
"id": "fc18c987a8312e72a47fb1524f230bb0",
|
||||
"meta": {}
|
||||
"meta": {},
|
||||
"embedding": None,
|
||||
"id_hash_keys": None
|
||||
},
|
||||
"answer":
|
||||
{
|
||||
@ -25,7 +26,9 @@ FEEDBACK={
|
||||
"context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye",
|
||||
"offsets_in_context": [{"start": 60, "end": 73}],
|
||||
"offsets_in_document": [{"start": 60, "end": 73}],
|
||||
"document_id": "fc18c987a8312e72a47fb1524f230bb0"
|
||||
"document_id": "fc18c987a8312e72a47fb1524f230bb0",
|
||||
"meta": {},
|
||||
"score": None
|
||||
},
|
||||
"is_correct_answer": True,
|
||||
"is_correct_document": True,
|
||||
@ -138,13 +141,12 @@ def test_delete_documents():
|
||||
assert 200 == response.status_code
|
||||
response_json = response.json()
|
||||
assert len(response_json) == 1
|
||||
|
||||
|
||||
|
||||
def test_file_upload(client: TestClient):
|
||||
file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')}
|
||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'})
|
||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value", "non-existing-field": "wrong-value"}'})
|
||||
assert 200 == response.status_code
|
||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
||||
|
||||
|
||||
def test_query_with_no_filter(populated_client: TestClient):
|
||||
@ -204,12 +206,16 @@ def test_write_feedback(populated_client: TestClient):
|
||||
|
||||
def test_get_feedback(client: TestClient):
|
||||
response = client.post(url="/feedback", json=FEEDBACK)
|
||||
resp = client.get(url="/feedback")
|
||||
labels = [Label.from_dict(i) for i in resp.json()]
|
||||
assert response.status_code == 200
|
||||
response = client.get(url="/feedback")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
for response_item, expected_item in [(json_response[0][key], value) for key, value in FEEDBACK.items()]:
|
||||
assert response_item == expected_item
|
||||
|
||||
|
||||
def test_export_feedback(populated_client: TestClient):
|
||||
response = populated_client.post(url="/feedback", json=FEEDBACK)
|
||||
def test_export_feedback(client: TestClient):
|
||||
response = client.post(url="/feedback", json=FEEDBACK)
|
||||
assert 200 == response.status_code
|
||||
|
||||
feedback_urls = [
|
||||
@ -218,10 +224,16 @@ def test_export_feedback(populated_client: TestClient):
|
||||
"/export-feedback?full_document_context=false&context_size=50000",
|
||||
]
|
||||
for url in feedback_urls:
|
||||
response = populated_client.get(url=url, json=FEEDBACK)
|
||||
response = client.get(url=url, json=FEEDBACK)
|
||||
response_json = response.json()
|
||||
context = response_json["data"][0]["paragraphs"][0]["context"]
|
||||
answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
|
||||
answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
|
||||
assert context[answer_start:answer_start+len(answer)] == answer
|
||||
|
||||
|
||||
def test_get_feedback_malformed_query(client: TestClient):
|
||||
feedback = FEEDBACK.copy()
|
||||
feedback["unexpected_field"] = "misplaced-value"
|
||||
response = client.post(url="/feedback", json=feedback)
|
||||
assert response.status_code == 422
|
||||
|
88
ui/utils.py
88
ui/utils.py
@ -1,9 +1,12 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
from uuid import uuid4
|
||||
import streamlit as st
|
||||
|
||||
|
||||
API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
|
||||
STATUS = "initialized"
|
||||
HS_VERSION = "hs_version"
|
||||
@ -13,6 +16,9 @@ DOC_UPLOAD = "file-upload"
|
||||
|
||||
|
||||
def haystack_is_ready():
|
||||
"""
|
||||
Used to show the "Haystack is loading..." message
|
||||
"""
|
||||
url = f"{API_ENDPOINT}/{STATUS}"
|
||||
try:
|
||||
if requests.get(url).json():
|
||||
@ -21,68 +27,80 @@ def haystack_is_ready():
|
||||
logging.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
@st.cache
|
||||
def haystack_version():
|
||||
"""
|
||||
Get the Haystack version from the REST API
|
||||
"""
|
||||
url = f"{API_ENDPOINT}/{HS_VERSION}"
|
||||
return requests.get(url, timeout=0.1).json()["hs_version"]
|
||||
|
||||
def retrieve_doc(query, filters={}, top_k_reader=5, top_k_retriever=5):
|
||||
# Query Haystack API
|
||||
|
||||
def query(query, filters={}, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
|
||||
"""
|
||||
Send a query to the REST API and parse the answer.
|
||||
Returns both a ready-to-use representation of the results and the raw JSON.
|
||||
"""
|
||||
|
||||
url = f"{API_ENDPOINT}/{DOC_REQUEST}"
|
||||
params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
|
||||
req = {"query": query, "params": params}
|
||||
response_raw = requests.post(url, json=req).json()
|
||||
|
||||
# Format response
|
||||
result = []
|
||||
if response_raw.status_code >= 400:
|
||||
raise Exception(f"{response_raw}")
|
||||
|
||||
if "errors" in response_raw:
|
||||
raise Exception(", ".join(response_raw["errors"]))
|
||||
|
||||
answers = response_raw["answers"]
|
||||
for i in range(len(answers)):
|
||||
answer = answers[i]
|
||||
answer_text = answer.get("answer", None)
|
||||
if answer_text:
|
||||
result.append(
|
||||
response = requests.post(url, json=req).json()
|
||||
if "errors" in response:
|
||||
raise Exception(", ".join(response["errors"]))
|
||||
|
||||
# Format response
|
||||
results = []
|
||||
answers = response["answers"]
|
||||
for answer in answers:
|
||||
if answer.get("answer", None):
|
||||
results.append(
|
||||
{
|
||||
"context": "..." + answer["context"] + "...",
|
||||
"answer": answer_text,
|
||||
"answer": answer.get("answer", None),
|
||||
"source": answer["meta"]["name"],
|
||||
"relevance": round(answer["score"] * 100, 2),
|
||||
"document_id": answer["document_id"],
|
||||
"document": [doc for doc in response_raw["documents"] if doc["id"] == answer["document_id"]][0],
|
||||
"offset_start_in_doc": answer["offsets_in_document"][0]["start"],
|
||||
"_raw": answer
|
||||
}
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
results.append(
|
||||
{
|
||||
"context": None,
|
||||
"answer": None,
|
||||
"document": None,
|
||||
"relevance": round(answer["score"] * 100, 2),
|
||||
"_raw": answer,
|
||||
}
|
||||
)
|
||||
return result, response_raw
|
||||
return results, response_raw
|
||||
|
||||
|
||||
def feedback_doc(question, is_correct_answer, document_id, model_id, is_correct_document, answer, offset_start_in_doc):
|
||||
# Feedback Haystack API
|
||||
try:
|
||||
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
|
||||
#TODO adjust after Label refactoring
|
||||
req = {
|
||||
"question": question,
|
||||
"is_correct_answer": is_correct_answer,
|
||||
"document_id": document_id,
|
||||
"model_id": model_id,
|
||||
"is_correct_document": is_correct_document,
|
||||
"answer": answer,
|
||||
"offset_start_in_doc": offset_start_in_doc,
|
||||
def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None:
|
||||
"""
|
||||
Send a feedback (label) to the REST API
|
||||
"""
|
||||
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
|
||||
req = {
|
||||
"id": str(uuid4()),
|
||||
"query": query,
|
||||
"document": document,
|
||||
"is_correct_answer": is_correct_answer,
|
||||
"is_correct_document": is_correct_document,
|
||||
"origin": "user-feedback",
|
||||
"answer": answer_obj
|
||||
}
|
||||
response_raw = requests.post(url, json=req).json()
|
||||
return response_raw
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
response_raw = requests.post(url, json=req)
|
||||
if response_raw.status_code >= 400:
|
||||
raise ValueError(f"An error was returned [code {response_raw.status_code}]: {response_raw.json()}")
|
||||
|
||||
|
||||
def upload_doc(file):
|
||||
|
68
ui/webapp.py
68
ui/webapp.py
@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import html
|
||||
import logging
|
||||
import pandas as pd
|
||||
from json import JSONDecodeError
|
||||
@ -9,13 +8,12 @@ from pathlib import Path
|
||||
import streamlit as st
|
||||
from annotated_text import annotation
|
||||
from markdown import markdown
|
||||
from htbuilder import H
|
||||
|
||||
# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
|
||||
# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
|
||||
# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
|
||||
import SessionState
|
||||
from utils import HS_VERSION, feedback_doc, haystack_is_ready, retrieve_doc, upload_doc, haystack_version
|
||||
from utils import HS_VERSION, haystack_is_ready, query, send_feedback, upload_doc, haystack_version
|
||||
|
||||
|
||||
# Adjust to a question that you would like users to see in the search bar when they load the UI:
|
||||
@ -154,7 +152,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
"Check out the docs: https://haystack.deepset.ai/usage/optimization "
|
||||
):
|
||||
try:
|
||||
state.results, state.raw_json = retrieve_doc(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
|
||||
state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
|
||||
except JSONDecodeError as je:
|
||||
st.error("👓 An error occurred reading the results. Is the document store working?")
|
||||
return
|
||||
@ -163,20 +161,19 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
if "The server is busy processing requests" in str(e):
|
||||
st.error("🧑🌾 All our workers are busy! Try again later.")
|
||||
else:
|
||||
st.error("🐞 An error occurred during the request. Check the logs in the console to know more.")
|
||||
st.error("🐞 An error occurred during the request.")
|
||||
return
|
||||
|
||||
if state.results:
|
||||
|
||||
# Show the gold answer if we use a question of the given set
|
||||
if question == state.random_question and eval_mode:
|
||||
if question == state.random_question and eval_mode and state.random_answer:
|
||||
st.write("## Correct answers:")
|
||||
st.write(state.random_answer)
|
||||
|
||||
st.write("## Results:")
|
||||
count = 0 # Make every button key unique
|
||||
|
||||
for result in state.results:
|
||||
for count, result in enumerate(state.results):
|
||||
if result["answer"]:
|
||||
answer, context = result["answer"], result["context"]
|
||||
start_idx = context.find(answer)
|
||||
@ -191,43 +188,36 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
|
||||
if eval_mode:
|
||||
# Define columns for buttons
|
||||
is_correct_answer = None
|
||||
is_correct_document = None
|
||||
|
||||
button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6])
|
||||
if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"):
|
||||
feedback_doc(
|
||||
question=question,
|
||||
is_correct_answer="true",
|
||||
document_id=result.get("document_id", None),
|
||||
model_id=1,
|
||||
is_correct_document="true",
|
||||
answer=result["answer"],
|
||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
||||
)
|
||||
st.success("✨ Thanks for your feedback! ✨")
|
||||
is_correct_answer=True
|
||||
is_correct_document=True
|
||||
|
||||
if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"):
|
||||
feedback_doc(
|
||||
question=question,
|
||||
is_correct_answer="false",
|
||||
document_id=result.get("document_id", None),
|
||||
model_id=1,
|
||||
is_correct_document="false",
|
||||
answer=result["answer"],
|
||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
||||
)
|
||||
st.success("✨ Thanks for your feedback! ✨")
|
||||
is_correct_answer=False
|
||||
is_correct_document=False
|
||||
|
||||
if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"):
|
||||
feedback_doc(
|
||||
question=question,
|
||||
is_correct_answer="false",
|
||||
document_id=result.get("document_id", None),
|
||||
model_id=1,
|
||||
is_correct_document="true",
|
||||
answer=result["answer"],
|
||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
||||
)
|
||||
st.success("✨ Thanks for your feedback! ✨")
|
||||
count += 1
|
||||
is_correct_answer=False
|
||||
is_correct_document=True
|
||||
|
||||
if is_correct_answer is not None and is_correct_document is not None:
|
||||
try:
|
||||
send_feedback(
|
||||
query=question,
|
||||
answer_obj=result["_raw"],
|
||||
is_correct_answer=is_correct_answer,
|
||||
is_correct_document=is_correct_document,
|
||||
document=result["document"]
|
||||
)
|
||||
st.success("✨ Thanks for your feedback! ✨")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
st.error("🐞 An error occurred while submitting your feedback!")
|
||||
|
||||
st.write("___")
|
||||
|
||||
if debug:
|
||||
|
Loading…
x
Reference in New Issue
Block a user