2022-02-09 18:27:12 +01:00
|
|
|
from typing import List, Dict, Any, Tuple, Optional
|
2021-04-07 17:53:32 +02:00
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
import os
|
2021-09-27 16:40:25 +02:00
|
|
|
import logging
|
2020-12-27 18:06:09 +05:30
|
|
|
import requests
|
2021-11-30 18:11:54 +01:00
|
|
|
from time import sleep
|
2021-11-29 17:03:54 +01:00
|
|
|
from uuid import uuid4
|
2022-02-09 17:35:18 +01:00
|
|
|
import streamlit as st
|
2020-12-27 18:06:09 +05:30
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
|
2020-12-27 18:06:09 +05:30
|
|
|
# Base URL of the REST API; can be overridden with the API_ENDPOINT environment variable
API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:8000")

# Paths appended to API_ENDPOINT by the helpers in this module
STATUS = "initialized"  # polled by haystack_is_ready()
HS_VERSION = "hs_version"  # read by haystack_version()
DOC_REQUEST = "query"  # target of query()
DOC_FEEDBACK = "feedback"  # target of send_feedback()
DOC_UPLOAD = "file-upload"  # target of upload_doc()
|
|
|
|
|
2020-12-27 18:06:09 +05:30
|
|
|
|
2021-09-27 16:40:25 +02:00
|
|
|
def haystack_is_ready():
    """
    Used to show the "Haystack is loading..." message

    Performs a GET on the status endpoint and returns True when the reply has
    a status code below 400. Returns False for any other status code, or when
    the request itself raises (the exception is logged first).
    """
    endpoint = f"{API_ENDPOINT}/{STATUS}"
    try:
        status_code = requests.get(endpoint).status_code
    except Exception as error:
        logging.exception(error)
        sleep(1)  # To avoid spamming a non-existing endpoint at startup
        return False
    return status_code < 400
|
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
|
2021-11-19 11:34:32 +01:00
|
|
|
@st.cache
def haystack_version():
    """
    Get the Haystack version from the REST API

    The request uses a 0.1 second timeout and the value of the "hs_version"
    field of the JSON reply is returned.
    """
    endpoint = f"{API_ENDPOINT}/{HS_VERSION}"
    reply = requests.get(endpoint, timeout=0.1)
    payload = reply.json()
    return payload["hs_version"]
|
2021-11-19 11:34:32 +01:00
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
|
|
|
|
def _format_answer(answer, response):
    """
    Convert one raw answer of the REST response into the result dict returned by query().

    Answers without answer text only carry their relevance and the raw answer;
    all other fields are set to None.
    """
    if not answer.get("answer", None):
        return {
            "context": None,
            "answer": None,
            "document": None,
            "relevance": round(answer["score"] * 100, 2),
            "_raw": answer,
        }

    return {
        "context": "..." + answer["context"] + "...",
        "answer": answer.get("answer", None),
        "source": answer["meta"]["name"],
        "relevance": round(answer["score"] * 100, 2),
        # First document of the response whose id matches the answer's document_id;
        # raises IndexError when the response contains no such document.
        "document": [doc for doc in response["documents"] if doc["id"] == answer["document_id"]][0],
        "offset_start_in_doc": answer["offsets_in_document"][0]["start"],
        "_raw": answer,
    }


def query(query, filters=None, top_k_reader=5, top_k_retriever=5) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
    """
    Send a query to the REST API and parse the answer.
    Returns both a ready-to-use representation of the results and the raw JSON.

    :param query: sent as the "query" field of the request
    :param filters: sent as the "filters" request parameter; an empty dict is used when omitted
    :param top_k_reader: sent as "top_k" under the "Reader" request parameter
    :param top_k_retriever: sent as "top_k" under the "Retriever" request parameter
    :return: tuple of (list of formatted results, decoded JSON response)
    :raises Exception: if the status code is >= 400 (except 503), or if the
                       decoded response contains an "errors" entry
    """
    # Create the default per call instead of sharing one mutable dict across calls
    if filters is None:
        filters = {}

    url = f"{API_ENDPOINT}/{DOC_REQUEST}"
    params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}, "Reader": {"top_k": top_k_reader}}
    req = {"query": query, "params": params}
    response_raw = requests.post(url, json=req)

    # A 503 is not raised here: its JSON body is decoded below and any
    # "errors" entry it carries is raised from there.
    if response_raw.status_code >= 400 and response_raw.status_code != 503:
        raise Exception(f"{vars(response_raw)}")

    response = response_raw.json()
    if "errors" in response:
        raise Exception(", ".join(response["errors"]))

    # Format response
    results = [_format_answer(answer, response) for answer in response["answers"]]
    return results, response
|
2021-04-30 14:16:30 +05:30
|
|
|
|
|
|
|
|
2021-11-29 17:03:54 +01:00
|
|
|
def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None:
    """
    Send a feedback (label) to the REST API

    :param query: sent as the "query" field of the label
    :param answer_obj: sent as the "answer" field of the label
    :param is_correct_answer: sent as the "is_correct_answer" field
    :param is_correct_document: sent as the "is_correct_document" field
    :param document: sent as the "document" field of the label
    :raises ValueError: if the REST API replies with a status code >= 400
    """
    url = f"{API_ENDPOINT}/{DOC_FEEDBACK}"
    req = {
        "query": query,
        "document": document,
        "is_correct_answer": is_correct_answer,
        "is_correct_document": is_correct_document,
        "origin": "user-feedback",
        "answer": answer_obj,
    }
    response_raw = requests.post(url, json=req)

    if response_raw.status_code >= 400:
        # The error body may not be valid JSON; fall back to the raw text so the
        # status code is still reported instead of being hidden by a decode error.
        try:
            details = response_raw.json()
        except ValueError:
            details = response_raw.text
        raise ValueError(f"An error was returned [code {response_raw.status_code}]: {details}")
|
2021-04-30 14:16:30 +05:30
|
|
|
|
|
|
|
|
|
|
|
def upload_doc(file):
    """
    Upload a file to the REST API

    The file is posted as a multipart field named "files"; the decoded JSON
    body of the reply is returned.
    """
    endpoint = f"{API_ENDPOINT}/{DOC_UPLOAD}"
    multipart_fields = [("files", file)]
    reply = requests.post(endpoint, files=multipart_fields)
    return reply.json()
|
2021-12-02 13:37:23 +01:00
|
|
|
|
|
|
|
|
2022-02-09 18:27:12 +01:00
|
|
|
def get_backlink(result) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract the (url, title) pair stored in the metadata of a result's document

    Returns (None, None) unless result["document"] is a non-empty dict whose
    "meta" entry is a non-empty dict holding non-empty "url" and "title" values.
    """
    no_backlink = (None, None)

    document = result.get("document", None)
    if not document or not isinstance(document, dict):
        return no_backlink

    meta = document.get("meta", None)
    if not meta or not isinstance(meta, dict):
        return no_backlink

    url = meta.get("url", None)
    title = meta.get("title", None)
    if url and title:
        return url, title
    return no_backlink
|