Replace SessionState with Streamlit built-in (#2006)

* Replace SessionState with Streamlit built-in

* Set session state to default if absent

Co-authored-by: Yorick van Zweeden <git@yorickvanzweeden.nl>
This commit is contained in:
Yorick van Zweeden 2022-01-18 14:59:42 +01:00 committed by GitHub
parent 0cca2b97cd
commit ea10d011ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 134 deletions

View File

@ -12,7 +12,6 @@ RUN pip install -r requirements.txt
COPY utils.py /home/user/
COPY webapp.py /home/user/
COPY eval_labels_example.csv /home/user/
COPY SessionState.py /home/user/
EXPOSE 8501

View File

@ -1,105 +0,0 @@
"""Hack to add per-session state to Streamlit.
Usage
-----
>>> import SessionState
>>>
>>> session_state = SessionState.get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
try:
import streamlit.ReportThread as ReportThread
from streamlit.server.Server import Server
except Exception:
# Streamlit >= 0.65.0
import streamlit.report_thread as ReportThread
from streamlit.server.server import Server
class SessionState(object):
def __init__(self, **kwargs):
"""A new SessionState object.
Parameters
----------
**kwargs : any
Default values for the session state.
Example
-------
>>> session_state = SessionState(user_name='', favorite_color='black')
>>> session_state.user_name = 'Mary'
''
>>> session_state.favorite_color
'black'
"""
for key, val in kwargs.items():
setattr(self, key, val)
def get(**kwargs):
"""Gets a SessionState object for the current session.
Creates a new object if necessary.
Parameters
----------
**kwargs : any
Default values you want to add to the session state, if we're creating a
new one.
Example
-------
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
# Hack to get the session object from Streamlit.
ctx = ReportThread.get_report_ctx()
this_session = None
current_server = Server.get_current()
if hasattr(current_server, '_session_infos'):
# Streamlit < 0.56
session_infos = Server.get_current()._session_infos.values()
else:
session_infos = Server.get_current()._session_info_by_id.values()
for session_info in session_infos:
s = session_info.session
if (
# Streamlit < 0.54.0
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
or
# Streamlit >= 0.54.0
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
or
# Streamlit >= 0.65.2
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
):
this_session = s
if this_session is None:
raise RuntimeError(
"Oh noes. Couldn't get your Streamlit Session object. "
'Are you doing something fancy with threads?')
# Got the session object! Now let's attach some state into it.
if not hasattr(this_session, '_custom_session_state'):
this_session._custom_session_state = SessionState(**kwargs)
return this_session._custom_session_state

View File

@ -9,10 +9,6 @@ import streamlit as st
from annotated_text import annotation
from markdown import markdown
# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page
# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned
# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
import SessionState
from utils import haystack_is_ready, query, send_feedback, upload_doc, haystack_version, get_backlink
@ -31,24 +27,27 @@ EVAL_LABELS = os.getenv("EVAL_FILE", Path(__file__).parent / "eval_labels_exampl
DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD"))
def set_state_if_absent(key, value):
if key not in st.session_state:
st.session_state[key] = value
def main():
st.set_page_config(page_title='Haystack Demo', page_icon="https://haystack.deepset.ai/img/HaystackIcon.png")
# Persistent state
state = SessionState.get(
question=DEFAULT_QUESTION_AT_STARTUP,
answer=DEFAULT_ANSWER_AT_STARTUP,
results=None,
raw_json=None,
random_question_requested=False
)
set_state_if_absent('question', DEFAULT_QUESTION_AT_STARTUP)
set_state_if_absent('answer', DEFAULT_ANSWER_AT_STARTUP)
set_state_if_absent('results', None)
set_state_if_absent('raw_json', None)
set_state_if_absent('random_question_requested', False)
# Small callback to reset the interface in case the text of the question changes
def reset_results(*args):
state.answer = None
state.results = None
state.raw_json = None
st.session_state.answer = None
st.session_state.results = None
st.session_state.raw_json = None
# Title
st.write("# Haystack Demo - Explore the world")
@ -133,7 +132,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t
# Search bar
question = st.text_input("",
value=state.question,
value=st.session_state.question,
max_chars=100,
on_change=reset_results
)
@ -148,18 +147,18 @@ Ask any question on this topic and see if Haystack can find the correct answer t
if col2.button("Random question"):
reset_results()
new_row = df.sample(1)
while new_row["Question Text"].values[0] == state.question: # Avoid picking the same question twice (the change is not visible on the UI)
while new_row["Question Text"].values[0] == st.session_state.question: # Avoid picking the same question twice (the change is not visible on the UI)
new_row = df.sample(1)
state.question = new_row["Question Text"].values[0]
state.answer = new_row["Answer"].values[0]
state.random_question_requested = True
st.session_state.question = new_row["Question Text"].values[0]
st.session_state.answer = new_row["Answer"].values[0]
st.session_state.random_question_requested = True
# Re-runs the script setting the random question as the textbox value
# Unfortunately necessary as the Random Question button is _below_ the textbox
raise st.script_runner.RerunException(st.script_request_queue.RerunData(None))
else:
state.random_question_requested = False
st.session_state.random_question_requested = False
run_query = (run_pressed or question != state.question) and not state.random_question_requested
run_query = (run_pressed or question != st.session_state.question) and not st.session_state.random_question_requested
# Check the connection
with st.spinner("⌛️ &nbsp;&nbsp; Haystack is starting..."):
@ -171,14 +170,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t
# Get results for query
if run_query and question:
reset_results()
state.question = question
st.session_state.question = question
with st.spinner(
"🧠 &nbsp;&nbsp; Performing neural search on documents... \n "
"Do you want to optimize speed or accuracy? \n"
"Check out the docs: https://haystack.deepset.ai/usage/optimization "
):
try:
state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever)
st.session_state.results, st.session_state.raw_json = query(question, top_k_reader=top_k_reader,
top_k_retriever=top_k_retriever)
except JSONDecodeError as je:
st.error("👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
return
@ -190,16 +191,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t
st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
return
if state.results:
if st.session_state.results:
# Show the gold answer if we use a question of the given set
if eval_mode and state.answer:
if eval_mode and st.session_state.answer:
st.write("## Correct answer:")
st.write(state.answer)
st.write(st.session_state.answer)
st.write("## Results:")
for count, result in enumerate(state.results):
for count, result in enumerate(st.session_state.results):
if result["answer"]:
answer, context = result["answer"], result["context"]
start_idx = context.find(answer)
@ -254,6 +255,6 @@ Ask any question on this topic and see if Haystack can find the correct answer t
if debug:
st.subheader("REST API JSON response")
st.write(state.raw_json)
st.write(st.session_state.raw_json)
main()