mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-24 00:18:30 +00:00
Fix UI demo feedback (#1816)
* Fix the feedback function of the demo with a workaround
* Some docstring
* Update tests and rename methods in feedback.py
* Fix tests
* Remove operation_ids
* Add a couple of status code checks
This commit is contained in:
parent
84147edcca
commit
c29f960c47
@ -11,21 +11,21 @@ router = APIRouter()
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/feedback", operation_id="post_feedback")
|
@router.post("/feedback")
|
||||||
def user_feedback(feedback: LabelSerialized):
|
def post_feedback(feedback: LabelSerialized):
|
||||||
if feedback.origin is None:
|
if feedback.origin is None:
|
||||||
feedback.origin = "user-feedback"
|
feedback.origin = "user-feedback"
|
||||||
DOCUMENT_STORE.write_labels([feedback])
|
DOCUMENT_STORE.write_labels([feedback])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/feedback", operation_id="get_feedback")
|
@router.get("/feedback")
|
||||||
def user_feedback():
|
def get_feedback():
|
||||||
labels = DOCUMENT_STORE.get_all_labels()
|
labels = DOCUMENT_STORE.get_all_labels()
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@router.post("/eval-feedback", operation_id="get_feedback_metrics")
|
@router.post("/eval-feedback")
|
||||||
def eval_extractive_qa_feedback(filters: FilterRequest = None):
|
def get_feedback_metrics(filters: FilterRequest = None):
|
||||||
"""
|
"""
|
||||||
Return basic accuracy metrics based on the user feedback.
|
Return basic accuracy metrics based on the user feedback.
|
||||||
Which ratio of answers was correct? Which ratio of documents was correct?
|
Which ratio of answers was correct? Which ratio of documents was correct?
|
||||||
|
@ -3,7 +3,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from haystack.schema import Label
|
|
||||||
|
|
||||||
|
|
||||||
from rest_api.application import app
|
from rest_api.application import app
|
||||||
@ -16,7 +15,9 @@ FEEDBACK={
|
|||||||
"content_type": "text",
|
"content_type": "text",
|
||||||
"score": None,
|
"score": None,
|
||||||
"id": "fc18c987a8312e72a47fb1524f230bb0",
|
"id": "fc18c987a8312e72a47fb1524f230bb0",
|
||||||
"meta": {}
|
"meta": {},
|
||||||
|
"embedding": None,
|
||||||
|
"id_hash_keys": None
|
||||||
},
|
},
|
||||||
"answer":
|
"answer":
|
||||||
{
|
{
|
||||||
@ -25,7 +26,9 @@ FEEDBACK={
|
|||||||
"context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye",
|
"context": "A sample PDF file\n\nHistory and standardization\nFormat (PDF) Adobe Systems made the PDF specification available free of charge in 1993. In the early ye",
|
||||||
"offsets_in_context": [{"start": 60, "end": 73}],
|
"offsets_in_context": [{"start": 60, "end": 73}],
|
||||||
"offsets_in_document": [{"start": 60, "end": 73}],
|
"offsets_in_document": [{"start": 60, "end": 73}],
|
||||||
"document_id": "fc18c987a8312e72a47fb1524f230bb0"
|
"document_id": "fc18c987a8312e72a47fb1524f230bb0",
|
||||||
|
"meta": {},
|
||||||
|
"score": None
|
||||||
},
|
},
|
||||||
"is_correct_answer": True,
|
"is_correct_answer": True,
|
||||||
"is_correct_document": True,
|
"is_correct_document": True,
|
||||||
@ -142,9 +145,8 @@ def test_delete_documents():
|
|||||||
|
|
||||||
def test_file_upload(client: TestClient):
|
def test_file_upload(client: TestClient):
|
||||||
file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')}
|
file_to_upload = {'files': (Path(__file__).parent / "samples"/"pdf"/"sample_pdf_1.pdf").open('rb')}
|
||||||
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value"}'})
|
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"meta_key": "meta_value", "non-existing-field": "wrong-value"}'})
|
||||||
assert 200 == response.status_code
|
assert 200 == response.status_code
|
||||||
client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
|
|
||||||
|
|
||||||
|
|
||||||
def test_query_with_no_filter(populated_client: TestClient):
|
def test_query_with_no_filter(populated_client: TestClient):
|
||||||
@ -204,12 +206,16 @@ def test_write_feedback(populated_client: TestClient):
|
|||||||
|
|
||||||
def test_get_feedback(client: TestClient):
|
def test_get_feedback(client: TestClient):
|
||||||
response = client.post(url="/feedback", json=FEEDBACK)
|
response = client.post(url="/feedback", json=FEEDBACK)
|
||||||
resp = client.get(url="/feedback")
|
assert response.status_code == 200
|
||||||
labels = [Label.from_dict(i) for i in resp.json()]
|
response = client.get(url="/feedback")
|
||||||
|
assert response.status_code == 200
|
||||||
|
json_response = response.json()
|
||||||
|
for response_item, expected_item in [(json_response[0][key], value) for key, value in FEEDBACK.items()]:
|
||||||
|
assert response_item == expected_item
|
||||||
|
|
||||||
|
|
||||||
def test_export_feedback(populated_client: TestClient):
|
def test_export_feedback(client: TestClient):
|
||||||
response = populated_client.post(url="/feedback", json=FEEDBACK)
|
response = client.post(url="/feedback", json=FEEDBACK)
|
||||||
assert 200 == response.status_code
|
assert 200 == response.status_code
|
||||||
|
|
||||||
feedback_urls = [
|
feedback_urls = [
|
||||||
@ -218,10 +224,16 @@ def test_export_feedback(populated_client: TestClient):
|
|||||||
"/export-feedback?full_document_context=false&context_size=50000",
|
"/export-feedback?full_document_context=false&context_size=50000",
|
||||||
]
|
]
|
||||||
for url in feedback_urls:
|
for url in feedback_urls:
|
||||||
response = populated_client.get(url=url, json=FEEDBACK)
|
response = client.get(url=url, json=FEEDBACK)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
context = response_json["data"][0]["paragraphs"][0]["context"]
|
context = response_json["data"][0]["paragraphs"][0]["context"]
|
||||||
answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
|
answer_start = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
|
||||||
answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
|
answer = response_json["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]["text"]
|
||||||
assert context[answer_start:answer_start+len(answer)] == answer
|
assert context[answer_start:answer_start+len(answer)] == answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_feedback_malformed_query(client: TestClient):
|
||||||
|
feedback = FEEDBACK.copy()
|
||||||
|
feedback["unexpected_field"] = "misplaced-value"
|
||||||
|
response = client.post(url="/feedback", json=feedback)
|
||||||
|
assert response.status_code == 422
|
||||||
|
80
ui/utils.py
80
ui/utils.py
@ -1,9 +1,12 @@
|
|||||||
import os
|
from typing import List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
|
from uuid import uuid4
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
|
||||||
API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
|
API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")
|
||||||
STATUS = "initialized"
|
STATUS = "initialized"
|
||||||
HS_VERSION = "hs_version"
|
HS_VERSION = "hs_version"
|
||||||
@ -13,6 +16,9 @@ DOC_UPLOAD = "file-upload"
|
|||||||
|
|
||||||
|
|
||||||
def haystack_is_ready():
|
def haystack_is_ready():
|
||||||
|
"""
|
||||||
|
Used to show the "Haystack is loading..." message
|
||||||
|
"""
|
||||||
url = f"{API_ENDPOINT}/{STATUS}"
|
url = f"{API_ENDPOINT}/{STATUS}"
|
||||||
try:
|
try:
|
||||||
if requests.get(url).json():
|
if requests.get(url).json():
|
||||||
@ -21,68 +27,80 @@ def haystack_is_ready():
|
|||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@st.cache
|
@st.cache
|
||||||
def haystack_version():
|
def haystack_version():
|
||||||
|
"""
|
||||||
|
Get the Haystack version from the REST API
|
||||||
|
"""
|
||||||
url = f"{API_ENDPOINT}/{HS_VERSION}"
|
url = f"{API_ENDPOINT}/{HS_VERSION}"
|
||||||
return requests.get(url, timeout=0.1).json()["hs_version"]
|
return requests.get(url, timeout=0.1).json()["hs_version"]
|
||||||
|
|
||||||
def retrieve_doc(query, filters={}, top_k_reader=5, top_k_retriever=5):
|
|
||||||
# Query Haystack API
|
def query(query, filters={}, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Send a query to the REST API and parse the answer.
|
||||||
|
Returns both a ready-to-use representation of the results and the raw JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
url = f"{API_ENDPOINT}/{DOC_REQUEST}"
|
url = f"{API_ENDPOINT}/{DOC_REQUEST}"
|
||||||
params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
|
params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
|
||||||
req = {"query": query, "params": params}
|
req = {"query": query, "params": params}
|
||||||
response_raw = requests.post(url, json=req).json()
|
response_raw = requests.post(url, json=req).json()
|
||||||
|
|
||||||
|
if response_raw.status_code >= 400:
|
||||||
|
raise Exception(f"{response_raw}")
|
||||||
|
|
||||||
|
response = requests.post(url, json=req).json()
|
||||||
|
if "errors" in response:
|
||||||
|
raise Exception(", ".join(response["errors"]))
|
||||||
|
|
||||||
# Format response
|
# Format response
|
||||||
result = []
|
results = []
|
||||||
|
answers = response["answers"]
|
||||||
if "errors" in response_raw:
|
for answer in answers:
|
||||||
raise Exception(", ".join(response_raw["errors"]))
|
if answer.get("answer", None):
|
||||||
|
results.append(
|
||||||
answers = response_raw["answers"]
|
|
||||||
for i in range(len(answers)):
|
|
||||||
answer = answers[i]
|
|
||||||
answer_text = answer.get("answer", None)
|
|
||||||
if answer_text:
|
|
||||||
result.append(
|
|
||||||
{
|
{
|
||||||
"context": "..." + answer["context"] + "...",
|
"context": "..." + answer["context"] + "...",
|
||||||
"answer": answer_text,
|
"answer": answer.get("answer", None),
|
||||||
"source": answer["meta"]["name"],
|
"source": answer["meta"]["name"],
|
||||||
"relevance": round(answer["score"] * 100, 2),
|
"relevance": round(answer["score"] * 100, 2),
|
||||||
"document_id": answer["document_id"],
|
"document": [doc for doc in response_raw["documents"] if doc["id"] == answer["document_id"]][0],
|
||||||
"offset_start_in_doc": answer["offsets_in_document"][0]["start"],
|
"offset_start_in_doc": answer["offsets_in_document"][0]["start"],
|
||||||
|
"_raw": answer
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"context": None,
|
"context": None,
|
||||||
"answer": None,
|
"answer": None,
|
||||||
|
"document": None,
|
||||||
"relevance": round(answer["score"] * 100, 2),
|
"relevance": round(answer["score"] * 100, 2),
|
||||||
|
"_raw": answer,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return result, response_raw
|
return results, response_raw
|
||||||
|
|
||||||
|
|
||||||
def feedback_doc(question, is_correct_answer, document_id, model_id, is_correct_document, answer, offset_start_in_doc):
|
def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None:
|
||||||
# Feedback Haystack API
|
"""
|
||||||
try:
|
Send a feedback (label) to the REST API
|
||||||
|
"""
|
||||||
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
|
url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
|
||||||
#TODO adjust after Label refactoring
|
|
||||||
req = {
|
req = {
|
||||||
"question": question,
|
"id": str(uuid4()),
|
||||||
|
"query": query,
|
||||||
|
"document": document,
|
||||||
"is_correct_answer": is_correct_answer,
|
"is_correct_answer": is_correct_answer,
|
||||||
"document_id": document_id,
|
|
||||||
"model_id": model_id,
|
|
||||||
"is_correct_document": is_correct_document,
|
"is_correct_document": is_correct_document,
|
||||||
"answer": answer,
|
"origin": "user-feedback",
|
||||||
"offset_start_in_doc": offset_start_in_doc,
|
"answer": answer_obj
|
||||||
}
|
}
|
||||||
response_raw = requests.post(url, json=req).json()
|
response_raw = requests.post(url, json=req)
|
||||||
return response_raw
|
if response_raw.status_code >= 400:
|
||||||
except Exception as e:
|
raise ValueError(f"An error was returned [code {response_raw.status_code}]: {response_raw.json()}")
|
||||||
logging.exception(e)
|
|
||||||
|
|
||||||
|
|
||||||
def upload_doc(file):
|
def upload_doc(file):
|
||||||
|
64
ui/webapp.py
64
ui/webapp.py
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import html
|
|
||||||
import logging
|
import logging
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
@ -9,13 +8,12 @@ from pathlib import Path
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from annotated_text import annotation
|
from annotated_text import annotation
|
||||||
from markdown import markdown
|
from markdown import markdown
|
||||||
from htbuilder import H
|
|
||||||
|
|
||||||
# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
|
# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
|
||||||
# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
|
# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
|
||||||
# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
|
# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
|
||||||
import SessionState
|
import SessionState
|
||||||
from utils import HS_VERSION, feedback_doc, haystack_is_ready, retrieve_doc, upload_doc, haystack_version
|
from utils import HS_VERSION, haystack_is_ready, query, send_feedback, upload_doc, haystack_version
|
||||||
|
|
||||||
|
|
||||||
# Adjust to a question that you would like users to see in the search bar when they load the UI:
|
# Adjust to a question that you would like users to see in the search bar when they load the UI:
|
||||||
@ -154,7 +152,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
|||||||
"Check out the docs: https://haystack.deepset.ai/usage/optimization "
|
"Check out the docs: https://haystack.deepset.ai/usage/optimization "
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
state.results, state.raw_json = retrieve_doc(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
|
state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
|
||||||
except JSONDecodeError as je:
|
except JSONDecodeError as je:
|
||||||
st.error("👓 An error occurred reading the results. Is the document store working?")
|
st.error("👓 An error occurred reading the results. Is the document store working?")
|
||||||
return
|
return
|
||||||
@ -163,20 +161,19 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
|||||||
if "The server is busy processing requests" in str(e):
|
if "The server is busy processing requests" in str(e):
|
||||||
st.error("🧑🌾 All our workers are busy! Try again later.")
|
st.error("🧑🌾 All our workers are busy! Try again later.")
|
||||||
else:
|
else:
|
||||||
st.error("🐞 An error occurred during the request. Check the logs in the console to know more.")
|
st.error("🐞 An error occurred during the request.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if state.results:
|
if state.results:
|
||||||
|
|
||||||
# Show the gold answer if we use a question of the given set
|
# Show the gold answer if we use a question of the given set
|
||||||
if question == state.random_question and eval_mode:
|
if question == state.random_question and eval_mode and state.random_answer:
|
||||||
st.write("## Correct answers:")
|
st.write("## Correct answers:")
|
||||||
st.write(state.random_answer)
|
st.write(state.random_answer)
|
||||||
|
|
||||||
st.write("## Results:")
|
st.write("## Results:")
|
||||||
count = 0 # Make every button key unique
|
|
||||||
|
|
||||||
for result in state.results:
|
for count, result in enumerate(state.results):
|
||||||
if result["answer"]:
|
if result["answer"]:
|
||||||
answer, context = result["answer"], result["context"]
|
answer, context = result["answer"], result["context"]
|
||||||
start_idx = context.find(answer)
|
start_idx = context.find(answer)
|
||||||
@ -191,43 +188,36 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
|||||||
|
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
# Define columns for buttons
|
# Define columns for buttons
|
||||||
|
is_correct_answer = None
|
||||||
|
is_correct_document = None
|
||||||
|
|
||||||
button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6])
|
button_col1, button_col2, button_col3, _ = st.columns([1, 1, 1, 6])
|
||||||
if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"):
|
if button_col1.button("👍", key=f"{result['context']}{count}1", help="Correct answer"):
|
||||||
feedback_doc(
|
is_correct_answer=True
|
||||||
question=question,
|
is_correct_document=True
|
||||||
is_correct_answer="true",
|
|
||||||
document_id=result.get("document_id", None),
|
|
||||||
model_id=1,
|
|
||||||
is_correct_document="true",
|
|
||||||
answer=result["answer"],
|
|
||||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
|
||||||
)
|
|
||||||
st.success("✨ Thanks for your feedback! ✨")
|
|
||||||
|
|
||||||
if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"):
|
if button_col2.button("👎", key=f"{result['context']}{count}2", help="Wrong answer and wrong passage"):
|
||||||
feedback_doc(
|
is_correct_answer=False
|
||||||
question=question,
|
is_correct_document=False
|
||||||
is_correct_answer="false",
|
|
||||||
document_id=result.get("document_id", None),
|
|
||||||
model_id=1,
|
|
||||||
is_correct_document="false",
|
|
||||||
answer=result["answer"],
|
|
||||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
|
||||||
)
|
|
||||||
st.success("✨ Thanks for your feedback! ✨")
|
|
||||||
|
|
||||||
if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"):
|
if button_col3.button("👎👍", key=f"{result['context']}{count}3", help="Wrong answer, but correct passage"):
|
||||||
feedback_doc(
|
is_correct_answer=False
|
||||||
question=question,
|
is_correct_document=True
|
||||||
is_correct_answer="false",
|
|
||||||
document_id=result.get("document_id", None),
|
if is_correct_answer is not None and is_correct_document is not None:
|
||||||
model_id=1,
|
try:
|
||||||
is_correct_document="true",
|
send_feedback(
|
||||||
answer=result["answer"],
|
query=question,
|
||||||
offset_start_in_doc=result.get("offset_start_in_doc", None)
|
answer_obj=result["_raw"],
|
||||||
|
is_correct_answer=is_correct_answer,
|
||||||
|
is_correct_document=is_correct_document,
|
||||||
|
document=result["document"]
|
||||||
)
|
)
|
||||||
st.success("✨ Thanks for your feedback! ✨")
|
st.success("✨ Thanks for your feedback! ✨")
|
||||||
count += 1
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
st.error("🐞 An error occurred while submitting your feedback!")
|
||||||
|
|
||||||
st.write("___")
|
st.write("___")
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user