# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""App logic module."""

import asyncio
import logging
from typing import TYPE_CHECKING

import streamlit as st
from knowledge_loader.data_sources.loader import (
    create_datasource,
    load_dataset_listing,
)
from knowledge_loader.model import load_model
from rag.typing import SearchResult, SearchType
from state.session_variables import SessionVariables
from ui.search import display_search_result

import graphrag.api as api

if TYPE_CHECKING:
    import pandas as pd

logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


def initialize() -> SessionVariables:
    """Initialize app logic."""
    if "session_variables" not in st.session_state:
        st.set_page_config(
            layout="wide",
            initial_sidebar_state="collapsed",
            page_title="GraphRAG",
        )
        sv = SessionVariables()
        datasets = load_dataset_listing()
        sv.datasets.value = datasets
        sv.dataset.value = (
            st.query_params["dataset"].lower()
            if "dataset" in st.query_params
            else datasets[0].key
        )
        load_dataset(sv.dataset.value, sv)
        st.session_state["session_variables"] = sv
    return st.session_state["session_variables"]
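
# Illustrative usage (a sketch, not part of this module): a Streamlit page is
# assumed to call initialize() once per rerun and then drive the search helpers
# with the user's query, e.g.
#
#   sv = initialize()
#   query = st.text_input("Ask a question about the dataset")
#   if query:
#       results = asyncio.run(run_all_searches(query, sv))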


def load_dataset(dataset: str, sv: SessionVariables):
    """Load dataset from the dropdown."""
    sv.dataset.value = dataset
    sv.dataset_config.value = next(
        (d for d in sv.datasets.value if d.key == dataset), None
    )
    if sv.dataset_config.value is not None:
        sv.datasource.value = create_datasource(f"{sv.dataset_config.value.path}")  # type: ignore
        sv.graphrag_config.value = sv.datasource.value.read_settings("settings.yaml")
        load_knowledge_model(sv)
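
# Note: each entry returned by load_dataset_listing() is assumed to expose at least
# `key`, `name`, `path`, and `community_level` attributes; those are the only fields
# this module reads from the dataset configuration.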


def dataset_name(key: str, sv: SessionVariables) -> str:
    """Get dataset name."""
    return next((d for d in sv.datasets.value if d.key == key), None).name  # type: ignore
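
# Caveat: dataset_name() assumes `key` exists in the listing; otherwise
# next(..., None) yields None and the `.name` access raises AttributeError
# (hence the type: ignore).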


async def run_all_searches(query: str, sv: SessionVariables) -> list[SearchResult]:
    """Run all search engines and return the results."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    tasks = []
    if sv.include_drift_search.value:
        tasks.append(
            run_drift_search(
                query=query,
                sv=sv,
            )
        )

    if sv.include_basic_rag.value:
        tasks.append(
            run_basic_search(
                query=query,
                sv=sv,
            )
        )
    if sv.include_local_search.value:
        tasks.append(
            run_local_search(
                query=query,
                sv=sv,
            )
        )
    if sv.include_global_search.value:
        tasks.append(
            run_global_search(
                query=query,
                sv=sv,
            )
        )

    return await asyncio.gather(*tasks)
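
# asyncio.gather preserves task order, so the returned list holds one SearchResult
# per *enabled* search, in the order the tasks were appended above
# (drift, basic, local, global). Illustrative sketch with only local and global
# search enabled:
#
#   local_result, global_result = await run_all_searches(query, sv)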


async def run_generate_questions(query: str, sv: SessionVariables):
    """Run global search to generate questions for the dataset."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    tasks = []

    tasks.append(
        run_global_search_question_generation(
            query=query,
            sv=sv,
        )
    )

    return await asyncio.gather(*tasks)
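
# run_generate_questions() gathers a single task, so callers receive a one-element
# list and typically unpack it as `results[0]`.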


async def run_global_search_question_generation(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run global search question generation process."""
    empty_context_data: dict[str, pd.DataFrame] = {}

    response, context_data = await api.global_search(
        config=sv.graphrag_config.value,
        entities=sv.entities.value,
        communities=sv.communities.value,
        community_reports=sv.community_reports.value,
        dynamic_community_selection=True,
        response_type="Single paragraph",
        community_level=sv.dataset_config.value.community_level,
        query=query,
    )

    # return the response and reference context for the caller to display
    return SearchResult(
        search_type=SearchType.Global,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )
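
# The run_local_search / run_global_search / run_drift_search / run_basic_search
# helpers below share one structure: look up the "<search>_response_placeholder"
# and "<search>_container" entries in st.session_state (assumed to be created by
# the UI layer before a query is submitted), call the matching graphrag.api
# function under a spinner, render the result with display_search_result(), and
# record the result under st.session_state["response_lengths"].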


async def run_local_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run local search."""
    print(f"Local search query: {query}")  # noqa T201

    # build local search engine
    response_placeholder = st.session_state[
        f"{SearchType.Local.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Local.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using local search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.local_search(
            config=sv.graphrag_config.value,
            communities=sv.communities.value,
            entities=sv.entities.value,
            community_reports=sv.community_reports.value,
            text_units=sv.text_units.value,
            relationships=sv.relationships.value,
            covariates=sv.covariates.value,
            community_level=sv.dataset_config.value.community_level,
            response_type="Multiple Paragraphs",
            query=query,
        )

    print(f"Local Response: {response}")  # noqa T201
    print(f"Context data: {context_data}")  # noqa T201

    # display response and reference context to UI
    search_result = SearchResult(
        search_type=SearchType.Local,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )

    display_search_result(
        container=response_container, result=search_result, stats=None
    )

    if "response_lengths" not in st.session_state:
        st.session_state.response_lengths = []

    st.session_state["response_lengths"].append({
        "result": search_result,
        "search": SearchType.Local.value.lower(),
    })

    return search_result


async def run_global_search(query: str, sv: SessionVariables) -> SearchResult:
    """Run global search."""
    print(f"Global search query: {query}")  # noqa T201

    # build global search engine
    response_placeholder = st.session_state[
        f"{SearchType.Global.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[
        f"{SearchType.Global.value.lower()}_container"
    ]

    response_placeholder.empty()
    with response_placeholder, st.spinner("Generating answer using global search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.global_search(
            config=sv.graphrag_config.value,
            entities=sv.entities.value,
            communities=sv.communities.value,
            community_reports=sv.community_reports.value,
            dynamic_community_selection=False,
            response_type="Multiple Paragraphs",
            community_level=sv.dataset_config.value.community_level,
            query=query,
        )

    print(f"Context data: {context_data}")  # noqa T201
    print(f"Global Response: {response}")  # noqa T201

    # display response and reference context to UI
    search_result = SearchResult(
        search_type=SearchType.Global,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )

    display_search_result(
        container=response_container, result=search_result, stats=None
    )

    if "response_lengths" not in st.session_state:
        st.session_state.response_lengths = []

    st.session_state["response_lengths"].append({
        "result": search_result,
        "search": SearchType.Global.value.lower(),
    })

    return search_result


async def run_drift_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run drift search."""
    print(f"Drift search query: {query}")  # noqa T201

    # build drift search engine
    response_placeholder = st.session_state[
        f"{SearchType.Drift.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Drift.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using drift search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.drift_search(
            config=sv.graphrag_config.value,
            entities=sv.entities.value,
            communities=sv.communities.value,
            community_reports=sv.community_reports.value,
            text_units=sv.text_units.value,
            relationships=sv.relationships.value,
            community_level=sv.dataset_config.value.community_level,
            response_type="Multiple Paragraphs",
            query=query,
        )

    print(f"Drift Response: {response}")  # noqa T201
    print(f"Context data: {context_data}")  # noqa T201

    # display response and reference context to UI
    search_result = SearchResult(
        search_type=SearchType.Drift,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )

    display_search_result(
        container=response_container, result=search_result, stats=None
    )

    if "response_lengths" not in st.session_state:
        st.session_state.response_lengths = []

    st.session_state["response_lengths"].append({
        "result": search_result,
        "search": SearchType.Drift.value.lower(),
    })

    return search_result


async def run_basic_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run basic search."""
    print(f"Basic search query: {query}")  # noqa T201

    # build basic search engine
    response_placeholder = st.session_state[
        f"{SearchType.Basic.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Basic.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using basic RAG..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.basic_search(
            config=sv.graphrag_config.value,
            text_units=sv.text_units.value,
            query=query,
        )

    print(f"Basic Response: {response}")  # noqa T201
    print(f"Context data: {context_data}")  # noqa T201

    # display response and reference context to UI
    search_result = SearchResult(
        search_type=SearchType.Basic,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )

    display_search_result(
        container=response_container, result=search_result, stats=None
    )

    if "response_lengths" not in st.session_state:
        st.session_state.response_lengths = []

    st.session_state["response_lengths"].append({
        "search": SearchType.Basic.value.lower(),
        "result": search_result,
    })

    return search_result


def load_knowledge_model(sv: SessionVariables):
    """Load knowledge model from the datasource."""
    print("Loading knowledge model...", sv.dataset.value, sv.dataset_config.value)  # noqa T201
    model = load_model(sv.dataset.value, sv.datasource.value)

    sv.generated_questions.value = []
    sv.selected_question.value = ""
    sv.entities.value = model.entities
    sv.relationships.value = model.relationships
    sv.covariates.value = model.covariates
    sv.community_reports.value = model.community_reports
    sv.communities.value = model.communities
    sv.text_units.value = model.text_units

    return sv
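
# The object returned by load_model() is assumed to expose the dataframes this app
# reads: entities, relationships, covariates, community_reports, communities, and
# text_units.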