diff --git a/ui/Dockerfile b/ui/Dockerfile index bce12a4fc..c7a256c32 100644 --- a/ui/Dockerfile +++ b/ui/Dockerfile @@ -12,7 +12,6 @@ RUN pip install -r requirements.txt COPY utils.py /home/user/ COPY webapp.py /home/user/ COPY eval_labels_example.csv /home/user/ -COPY SessionState.py /home/user/ EXPOSE 8501 diff --git a/ui/SessionState.py b/ui/SessionState.py deleted file mode 100644 index f8b03aec9..000000000 --- a/ui/SessionState.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Hack to add per-session state to Streamlit. -Usage ------ ->>> import SessionState ->>> ->>> session_state = SessionState.get(user_name='', favorite_color='black') ->>> session_state.user_name -'' ->>> session_state.user_name = 'Mary' ->>> session_state.favorite_color -'black' -Since you set user_name above, next time your script runs this will be the -result: ->>> session_state = get(user_name='', favorite_color='black') ->>> session_state.user_name -'Mary' -""" -try: - import streamlit.ReportThread as ReportThread - from streamlit.server.Server import Server -except Exception: - # Streamlit >= 0.65.0 - import streamlit.report_thread as ReportThread - from streamlit.server.server import Server - - -class SessionState(object): - def __init__(self, **kwargs): - """A new SessionState object. - Parameters - ---------- - **kwargs : any - Default values for the session state. - Example - ------- - >>> session_state = SessionState(user_name='', favorite_color='black') - >>> session_state.user_name = 'Mary' - '' - >>> session_state.favorite_color - 'black' - """ - for key, val in kwargs.items(): - setattr(self, key, val) - - -def get(**kwargs): - """Gets a SessionState object for the current session. - Creates a new object if necessary. - Parameters - ---------- - **kwargs : any - Default values you want to add to the session state, if we're creating a - new one. - Example - ------- - >>> session_state = get(user_name='', favorite_color='black') - >>> session_state.user_name - '' - >>> session_state.user_name = 'Mary' - >>> session_state.favorite_color - 'black' - Since you set user_name above, next time your script runs this will be the - result: - >>> session_state = get(user_name='', favorite_color='black') - >>> session_state.user_name - 'Mary' - """ - # Hack to get the session object from Streamlit. - - ctx = ReportThread.get_report_ctx() - - this_session = None - - current_server = Server.get_current() - if hasattr(current_server, '_session_infos'): - # Streamlit < 0.56 - session_infos = Server.get_current()._session_infos.values() - else: - session_infos = Server.get_current()._session_info_by_id.values() - - for session_info in session_infos: - s = session_info.session - if ( - # Streamlit < 0.54.0 - (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg) - or - # Streamlit >= 0.54.0 - (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue) - or - # Streamlit >= 0.65.2 - (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr) - ): - this_session = s - - if this_session is None: - raise RuntimeError( - "Oh noes. Couldn't get your Streamlit Session object. " - 'Are you doing something fancy with threads?') - - # Got the session object! Now let's attach some state into it. - - if not hasattr(this_session, '_custom_session_state'): - this_session._custom_session_state = SessionState(**kwargs) - - return this_session._custom_session_state \ No newline at end of file diff --git a/ui/webapp.py b/ui/webapp.py index cb4f86afa..5ce7c3ea7 100644 --- a/ui/webapp.py +++ b/ui/webapp.py @@ -9,10 +9,6 @@ import streamlit as st from annotated_text import annotation from markdown import markdown -# streamlit does not support any states out of the box. On every button click, streamlit reload the whole page -# and every value gets lost. To keep track of our feedback state we use the official streamlit gist mentioned -# here https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92 -import SessionState from utils import haystack_is_ready, query, send_feedback, upload_doc, haystack_version, get_backlink @@ -31,24 +27,27 @@ EVAL_LABELS = os.getenv("EVAL_FILE", Path(__file__).parent / "eval_labels_exampl DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD")) +def set_state_if_absent(key, value): + if key not in st.session_state: + st.session_state[key] = value + def main(): st.set_page_config(page_title='Haystack Demo', page_icon="https://haystack.deepset.ai/img/HaystackIcon.png") # Persistent state - state = SessionState.get( - question=DEFAULT_QUESTION_AT_STARTUP, - answer=DEFAULT_ANSWER_AT_STARTUP, - results=None, - raw_json=None, - random_question_requested=False - ) + set_state_if_absent('question', DEFAULT_QUESTION_AT_STARTUP) + set_state_if_absent('answer', DEFAULT_ANSWER_AT_STARTUP) + set_state_if_absent('results', None) + set_state_if_absent('raw_json', None) + set_state_if_absent('random_question_requested', False) + # Small callback to reset the interface in case the text of the question changes def reset_results(*args): - state.answer = None - state.results = None - state.raw_json = None + st.session_state.answer = None + st.session_state.results = None + st.session_state.raw_json = None # Title st.write("# Haystack Demo - Explore the world") @@ -133,7 +132,7 @@ Ask any question on this topic and see if Haystack can find the correct answer t # Search bar question = st.text_input("", - value=state.question, + value=st.session_state.question, max_chars=100, on_change=reset_results ) @@ -148,18 +147,18 @@ Ask any question on this topic and see if Haystack can find the correct answer t if col2.button("Random question"): reset_results() new_row = df.sample(1) - while new_row["Question Text"].values[0] == state.question: # Avoid picking the same question twice (the change is not visible on the UI) + while new_row["Question Text"].values[0] == st.session_state.question: # Avoid picking the same question twice (the change is not visible on the UI) new_row = df.sample(1) - state.question = new_row["Question Text"].values[0] - state.answer = new_row["Answer"].values[0] - state.random_question_requested = True + st.session_state.question = new_row["Question Text"].values[0] + st.session_state.answer = new_row["Answer"].values[0] + st.session_state.random_question_requested = True # Re-runs the script setting the random question as the textbox value # Unfortunately necessary as the Random Question button is _below_ the textbox raise st.script_runner.RerunException(st.script_request_queue.RerunData(None)) else: - state.random_question_requested = False + st.session_state.random_question_requested = False - run_query = (run_pressed or question != state.question) and not state.random_question_requested + run_query = (run_pressed or question != st.session_state.question) and not st.session_state.random_question_requested # Check the connection with st.spinner("⌛️    Haystack is starting..."): @@ -171,14 +170,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t # Get results for query if run_query and question: reset_results() - state.question = question + st.session_state.question = question + with st.spinner( "🧠    Performing neural search on documents... \n " "Do you want to optimize speed or accuracy? \n" "Check out the docs: https://haystack.deepset.ai/usage/optimization " ): try: - state.results, state.raw_json = query(question, top_k_reader=top_k_reader, top_k_retriever=top_k_retriever) + st.session_state.results, st.session_state.raw_json = query(question, top_k_reader=top_k_reader, + top_k_retriever=top_k_retriever) except JSONDecodeError as je: st.error("👓    An error occurred reading the results. Is the document store working?") return @@ -190,16 +191,16 @@ Ask any question on this topic and see if Haystack can find the correct answer t st.error("🐞    An error occurred during the request.") return - if state.results: + if st.session_state.results: # Show the gold answer if we use a question of the given set - if eval_mode and state.answer: + if eval_mode and st.session_state.answer: st.write("## Correct answer:") - st.write(state.answer) + st.write(st.session_state.answer) st.write("## Results:") - for count, result in enumerate(state.results): + for count, result in enumerate(st.session_state.results): if result["answer"]: answer, context = result["answer"], result["context"] start_idx = context.find(answer) @@ -254,6 +255,6 @@ Ask any question on this topic and see if Haystack can find the correct answer t if debug: st.subheader("REST API JSON response") - st.write(state.raw_json) + st.write(st.session_state.raw_json) main()