mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
Replace SessionState with Streamlit built-in (#2006)
* Replace SessionState with Streamlit built-in * Set session state to default if absent Co-authored-by: Yorick van Zweeden <git@yorickvanzweeden.nl>
This commit is contained in:
parent
0cca2b97cd
commit
ea10d011ab
@ -12,7 +12,6 @@ RUN pip install -r requirements.txt
|
||||
COPY utils.py /home/user/
|
||||
COPY webapp.py /home/user/
|
||||
COPY eval_labels_example.csv /home/user/
|
||||
COPY SessionState.py /home/user/
|
||||
|
||||
EXPOSE 8501
|
||||
|
||||
|
||||
@ -1,105 +0,0 @@
|
||||
"""Hack to add per-session state to Streamlit.
|
||||
Usage
|
||||
-----
|
||||
>>> import SessionState
|
||||
>>>
|
||||
>>> session_state = SessionState.get(user_name='', favorite_color='black')
|
||||
>>> session_state.user_name
|
||||
''
|
||||
>>> session_state.user_name = 'Mary'
|
||||
>>> session_state.favorite_color
|
||||
'black'
|
||||
Since you set user_name above, next time your script runs this will be the
|
||||
result:
|
||||
>>> session_state = get(user_name='', favorite_color='black')
|
||||
>>> session_state.user_name
|
||||
'Mary'
|
||||
"""
|
||||
try:
|
||||
import streamlit.ReportThread as ReportThread
|
||||
from streamlit.server.Server import Server
|
||||
except Exception:
|
||||
# Streamlit >= 0.65.0
|
||||
import streamlit.report_thread as ReportThread
|
||||
from streamlit.server.server import Server
|
||||
|
||||
|
||||
class SessionState(object):
|
||||
def __init__(self, **kwargs):
|
||||
"""A new SessionState object.
|
||||
Parameters
|
||||
----------
|
||||
**kwargs : any
|
||||
Default values for the session state.
|
||||
Example
|
||||
-------
|
||||
>>> session_state = SessionState(user_name='', favorite_color='black')
|
||||
>>> session_state.user_name = 'Mary'
|
||||
''
|
||||
>>> session_state.favorite_color
|
||||
'black'
|
||||
"""
|
||||
for key, val in kwargs.items():
|
||||
setattr(self, key, val)
|
||||
|
||||
|
||||
def get(**kwargs):
|
||||
"""Gets a SessionState object for the current session.
|
||||
Creates a new object if necessary.
|
||||
Parameters
|
||||
----------
|
||||
**kwargs : any
|
||||
Default values you want to add to the session state, if we're creating a
|
||||
new one.
|
||||
Example
|
||||
-------
|
||||
>>> session_state = get(user_name='', favorite_color='black')
|
||||
>>> session_state.user_name
|
||||
''
|
||||
>>> session_state.user_name = 'Mary'
|
||||
>>> session_state.favorite_color
|
||||
'black'
|
||||
Since you set user_name above, next time your script runs this will be the
|
||||
result:
|
||||
>>> session_state = get(user_name='', favorite_color='black')
|
||||
>>> session_state.user_name
|
||||
'Mary'
|
||||
"""
|
||||
# Hack to get the session object from Streamlit.
|
||||
|
||||
ctx = ReportThread.get_report_ctx()
|
||||
|
||||
this_session = None
|
||||
|
||||
current_server = Server.get_current()
|
||||
if hasattr(current_server, '_session_infos'):
|
||||
# Streamlit < 0.56
|
||||
session_infos = Server.get_current()._session_infos.values()
|
||||
else:
|
||||
session_infos = Server.get_current()._session_info_by_id.values()
|
||||
|
||||
for session_info in session_infos:
|
||||
s = session_info.session
|
||||
if (
|
||||
# Streamlit < 0.54.0
|
||||
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
|
||||
or
|
||||
# Streamlit >= 0.54.0
|
||||
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
|
||||
or
|
||||
# Streamlit >= 0.65.2
|
||||
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
|
||||
):
|
||||
this_session = s
|
||||
|
||||
if this_session is None:
|
||||
raise RuntimeError(
|
||||
"Oh noes. Couldn't get your Streamlit Session object. "
|
||||
'Are you doing something fancy with threads?')
|
||||
|
||||
# Got the session object! Now let's attach some state into it.
|
||||
|
||||
if not hasattr(this_session, '_custom_session_state'):
|
||||
this_session._custom_session_state = SessionState(**kwargs)
|
||||
|
||||
return this_session._custom_session_state
|
||||
57
ui/webapp.py
57
ui/webapp.py
@ -9,10 +9,6 @@ import streamlit as st
|
||||
from annotated_text import annotation
|
||||
from markdown import markdown
|
||||
|
||||
# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
|
||||
# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
|
||||
# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
|
||||
import SessionState
|
||||
from utils import haystack_is_ready, query, send_feedback, upload_doc, haystack_version, get_backlink
|
||||
|
||||
|
||||
@ -31,24 +27,27 @@ EVAL_LABELS = os.getenv("EVAL_FILE", Path(__file__).parent / "eval_labels_exampl
|
||||
DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD"))
|
||||
|
||||
|
||||
def set_state_if_absent(key, value):
|
||||
if key not in st.session_state:
|
||||
st.session_state[key] = value
|
||||
|
||||
def main():
|
||||
|
||||
st.set_page_config(page_title='Haystack Demo', page_icon="https://haystack.deepset.ai/img/HaystackIcon.png")
|
||||
|
||||
# Persistent state
|
||||
state = SessionState.get(
|
||||
question=DEFAULT_QUESTION_AT_STARTUP,
|
||||
answer=DEFAULT_ANSWER_AT_STARTUP,
|
||||
results=None,
|
||||
raw_json=None,
|
||||
random_question_requested=False
|
||||
)
|
||||
set_state_if_absent('question', DEFAULT_QUESTION_AT_STARTUP)
|
||||
set_state_if_absent('answer', DEFAULT_ANSWER_AT_STARTUP)
|
||||
set_state_if_absent('results', None)
|
||||
set_state_if_absent('raw_json', None)
|
||||
set_state_if_absent('random_question_requested', False)
|
||||
|
||||
|
||||
# Small callback to reset the interface in case the text of the question changes
|
||||
def reset_results(*args):
|
||||
state.answer = None
|
||||
state.results = None
|
||||
state.raw_json = None
|
||||
st.session_state.answer = None
|
||||
st.session_state.results = None
|
||||
st.session_state.raw_json = None
|
||||
|
||||
# Title
|
||||
st.write("# Haystack Demo - Explore the world")
|
||||
@ -133,7 +132,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
|
||||
# Search bar
|
||||
question = st.text_input("",
|
||||
value=state.question,
|
||||
value=st.session_state.question,
|
||||
max_chars=100,
|
||||
on_change=reset_results
|
||||
)
|
||||
@ -148,18 +147,18 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
if col2.button("Random question"):
|
||||
reset_results()
|
||||
new_row = df.sample(1)
|
||||
while new_row["Question Text"].values[0] == state.question: # Avoid picking the same question twice (the change is not visible on the UI)
|
||||
while new_row["Question Text"].values[0] == st.session_state.question: # Avoid picking the same question twice (the change is not visible on the UI)
|
||||
new_row = df.sample(1)
|
||||
state.question = new_row["Question Text"].values[0]
|
||||
state.answer = new_row["Answer"].values[0]
|
||||
state.random_question_requested = True
|
||||
st.session_state.question = new_row["Question Text"].values[0]
|
||||
st.session_state.answer = new_row["Answer"].values[0]
|
||||
st.session_state.random_question_requested = True
|
||||
# Re-runs the script setting the random question as the textbox value
|
||||
# Unfortunately necessary as the Random Question button is _below_ the textbox
|
||||
raise st.script_runner.RerunException(st.script_request_queue.RerunData(None))
|
||||
else:
|
||||
state.random_question_requested = False
|
||||
st.session_state.random_question_requested = False
|
||||
|
||||
run_query = (run_pressed or question != state.question) and not state.random_question_requested
|
||||
run_query = (run_pressed or question != st.session_state.question) and not st.session_state.random_question_requested
|
||||
|
||||
# Check the connection
|
||||
with st.spinner("⌛️ Haystack is starting..."):
|
||||
@ -171,14 +170,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
# Get results for query
|
||||
if run_query and question:
|
||||
reset_results()
|
||||
state.question = question
|
||||
st.session_state.question = question
|
||||
|
||||
with st.spinner(
|
||||
"🧠 Performing neural search on documents... \n "
|
||||
"Do you want to optimize speed or accuracy? \n"
|
||||
"Check out the docs: https://haystack.deepset.ai/usage/optimization "
|
||||
):
|
||||
try:
|
||||
state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
|
||||
st.session_state.results, st.session_state.raw_json = query(question, top_k_reader=top_k_reader,
|
||||
top_k_retriever=top_k_retriever)
|
||||
except JSONDecodeError as je:
|
||||
st.error("👓 An error occurred reading the results. Is the document store working?")
|
||||
return
|
||||
@ -190,16 +191,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
st.error("🐞 An error occurred during the request.")
|
||||
return
|
||||
|
||||
if state.results:
|
||||
if st.session_state.results:
|
||||
|
||||
# Show the gold answer if we use a question of the given set
|
||||
if eval_mode and state.answer:
|
||||
if eval_mode and st.session_state.answer:
|
||||
st.write("## Correct answer:")
|
||||
st.write(state.answer)
|
||||
st.write(st.session_state.answer)
|
||||
|
||||
st.write("## Results:")
|
||||
|
||||
for count, result in enumerate(state.results):
|
||||
for count, result in enumerate(st.session_state.results):
|
||||
if result["answer"]:
|
||||
answer, context = result["answer"], result["context"]
|
||||
start_idx = context.find(answer)
|
||||
@ -254,6 +255,6 @@ Ask any question on this topic and see if Haystack can find the correct answer t
|
||||
|
||||
if debug:
|
||||
st.subheader("REST API JSON response")
|
||||
st.write(state.raw_json)
|
||||
st.write(st.session_state.raw_json)
|
||||
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user