mirror of
https://github.com/microsoft/graphrag.git
synced 2025-08-02 22:02:34 +00:00

* unified search app added to graphrag repository * ignore print statements * update words for unified-search * fix lint errors * fix lint error * fix module name --------- Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
261 lines
9.1 KiB
Python
261 lines
9.1 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
"""Home Page module."""
|
|
|
|
import asyncio
|
|
|
|
import streamlit as st
|
|
from app_logic import dataset_name, initialize, run_all_searches, run_generate_questions
|
|
from rag.typing import SearchType
|
|
from st_tabs import TabBar
|
|
from state.session_variables import SessionVariables
|
|
from ui.full_graph import create_full_graph_ui
|
|
from ui.questions_list import create_questions_list_ui
|
|
from ui.report_details import create_report_details_ui
|
|
from ui.report_list import create_report_list_ui
|
|
from ui.search import display_citations, format_suggested_questions, init_search_ui
|
|
from ui.sidebar import create_side_bar
|
|
|
|
|
|
def _search_flags(sv: "SessionVariables") -> list:
    """Return the four search include flags in display order.

    The order (basic, local, global, drift) is the left-to-right order in
    which the result columns are laid out on the page.
    """
    return [
        sv.include_basic_rag.value,
        sv.include_local_search.value,
        sv.include_global_search.value,
        sv.include_drift_search.value,
    ]


def _assign_columns(flags: list) -> list:
    """Create one Streamlit column per enabled flag.

    Args:
        flags: include flags, one per search type, in display order.

    Returns:
        A list parallel to ``flags`` holding the column assigned to each
        enabled entry and ``None`` for each disabled entry. No columns are
        created when nothing is enabled.
    """
    slots = [None] * len(flags)
    count = sum(flags)
    if count > 0:
        columns = st.columns(count)
        index = 0
        for position, enabled in enumerate(flags):
            if enabled:
                slots[position] = columns[index]
                index += 1
    return slots


def _render_search_panels(containers: list) -> None:
    """Initialise the answer panel of every search type that has a column.

    Args:
        containers: columns for basic, local, global and drift search (in
            that order); ``None`` entries are skipped.
    """
    panels = [
        (
            SearchType.Basic,
            "##### GraphRAG: Basic RAG",
            "###### Answer context: Fixed number of text chunks of raw documents",
        ),
        (
            SearchType.Local,
            "##### GraphRAG: Local Search",
            "###### Answer context: Graph index query results with relevant document text chunks",
        ),
        (
            SearchType.Global,
            "##### GraphRAG: Global Search",
            "###### Answer context: AI-generated network reports covering all input documents",
        ),
        (
            SearchType.Drift,
            "##### GraphRAG: Drift Search",
            "###### Answer context: Includes community information",
        ),
    ]
    with st.container():
        for container, (search_type, title, caption) in zip(containers, panels):
            if container:
                with container:
                    init_search_ui(
                        container=container,
                        search_type=search_type,
                        title=title,
                        caption=caption,
                    )


def _render_citation_placeholders(containers: list) -> None:
    """Put an empty placeholder element in every assigned citation column."""
    with st.container():
        for container in containers:
            if container:
                with container:
                    st.empty()


def _display_all_citations(containers: list) -> None:
    """Route each entry of ``st.session_state.response_lengths`` to its column.

    Args:
        containers: citation columns for basic, local, global and drift
            search (in that order); entries may be ``None``.

    Entries whose ``"search"`` value matches none of the four search types
    are ignored.
    """
    basic, local, global_, drift = containers
    targets = {
        SearchType.Basic.value.lower(): basic,
        SearchType.Local.value.lower(): local,
        SearchType.Global.value.lower(): global_,
        SearchType.Drift.value.lower(): drift,
    }

    if "response_lengths" not in st.session_state:
        st.session_state.response_lengths = []

    for entry in st.session_state.response_lengths:
        search_key = entry["search"]
        if search_key in targets:
            display_citations(
                container=targets[search_key],
                result=entry["result"],
            )


def _render_search_tab(sv: "SessionVariables", question: str) -> str:
    """Lay out the Search tab and return the question that should be answered.

    A question picked from the suggested list takes precedence over the typed
    one and is written back to ``sv.question``.
    """
    if sv.question.value.strip():
        question = sv.question.value

    if sv.selected_question.value != "":
        question = sv.selected_question.value
        sv.question.value = question

    if question:
        st.write(f"##### Answering the question: *{question}*")

    answer_flags = _search_flags(sv)
    answer_columns = _assign_columns(answer_flags)
    if not sum(answer_flags) > 0:
        st.write("Please select at least one search option from the sidebar.")

    _render_search_panels(answer_columns)
    return question


async def main():
    """Render the home page of the app.

    Draws the sidebar and dataset header, the question input (typed or chosen
    from generated suggestions), and a two-tab body: "Search", which runs the
    enabled searches side by side with their citations, and "Graph Explorer",
    which shows community reports next to the entity graph.

    Exceptions raised while generating questions or running searches are
    printed and written to the page instead of being propagated.
    """
    sv = initialize()

    create_side_bar(sv)

    st.markdown(
        "#### GraphRAG: A Novel Knowledge Graph-based Approach to Retrieval Augmented Generation (RAG)"
    )
    st.markdown("##### Dataset selected: " + dataset_name(sv.dataset.value, sv))
    st.markdown(sv.dataset_config.value.description)

    # Session-state key of the question text box; read back in on_change.
    question_input = "question_input"

    def on_click_reset(sv: SessionVariables):
        # Drop the suggestions and show the free-text input again.
        sv.generated_questions.value = []
        sv.selected_question.value = ""
        sv.show_text_input.value = True

    def on_change(sv: SessionVariables):
        # Copy the widget's current text into the session variable.
        sv.question.value = st.session_state[question_input]

    generate_questions = st.button("Suggest some questions")

    question = ""
    if sv.question.value.strip():
        question = sv.question.value

    if generate_questions:
        with st.spinner("Generating suggested questions..."):
            try:
                suggestions = await run_generate_questions(
                    query=f"Generate numbered list only with the top {sv.suggested_questions.value} most important questions of this dataset (numbered list only without titles or anything extra)",
                    sv=sv,
                )
                for suggestion in suggestions:
                    sv.generated_questions.value = format_suggested_questions(
                        suggestion.response
                    )
                    sv.show_text_input.value = False
            except Exception as e:  # noqa: BLE001
                print(f"Search exception: {e}")  # noqa T201
                st.write(e)

    if sv.show_text_input.value is True:
        st.text_input(
            "Ask a question to compare the results",
            key=question_input,
            on_change=on_change,
            value=question,
            kwargs={"sv": sv},
        )

    if len(sv.generated_questions.value) != 0:
        create_questions_list_ui(sv)

    if sv.show_text_input.value is False:
        st.button(label="Reset", on_click=on_click_reset, kwargs={"sv": sv})

    tab_id = TabBar(
        tabs=["Search", "Graph Explorer"],
        color="#fc9e9e",
        activeColor="#ff4b4b",
        default=0,
    )

    if tab_id == 0:
        question = _render_search_tab(sv, question)

        # The citations get their own row of columns below the answers; the
        # flags are re-read so this row mirrors the current selection.
        citation_columns = _assign_columns(_search_flags(sv))
        _render_citation_placeholders(citation_columns)

        # Only search when there is a question that has not already been
        # started; the marker is set before the search begins.
        if question != "" and question != sv.question_in_progress.value:
            sv.question_in_progress.value = question
            try:
                await run_all_searches(query=question, sv=sv)
                _display_all_citations(citation_columns)
            except Exception as e:  # noqa: BLE001
                print(f"Search exception: {e}")  # noqa T201
                st.write(e)

    if tab_id == 1:
        report_list, graph, report_content = st.columns([0.20, 0.55, 0.25])

        with report_list:
            st.markdown("##### Community Reports")
            create_report_list_ui(sv)

        with graph:
            title, dropdown = st.columns([0.80, 0.20])
            title.markdown("##### Entity Graph (All entities)")
            dropdown.selectbox(
                "Community level", options=[0, 1], key=sv.graph_community_level.key
            )
            create_full_graph_ui(sv)

        with report_content:
            st.markdown("##### Selected Report")
            create_report_details_ui(sv)
|
|
|
|
|
|
if __name__ == "__main__":
    # main() is a coroutine function, so it needs an event loop to run;
    # asyncio.run drives it to completion when this file runs as a script.
    asyncio.run(main())
|