Update frontend UI app (#174)

Co-authored-by: dorbaker <dorbaker@microsoft.com>
This commit is contained in:
Josh Bradley 2024-09-19 01:09:26 -04:00 committed by GitHub
parent 38096c8e86
commit 680cfc055e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 3657 additions and 700 deletions

View File

@ -12,15 +12,16 @@
],
"remoteUser": "vscode",
"remoteEnv": {
// We add the .venv to the beginning of the path env in the Dockerfile
// so that we use the proper python, however vscode rewrites/overwrites
// the PATH in the image and puts /usr/local/bin in front of our .venv
// path. This fixes that issue.
"PATH": "${containerEnv:PATH}",
// Add src folder to PYTHONPATH so that we can import modules that
// are in the source dir
"PYTHONPATH": "/graphrag-accelerator/backend/:$PATH"
},
// We add the .venv to the beginning of the path env in the Dockerfile
// so that we use the proper python, however vscode rewrites/overwrites
// the PATH in the image and puts /usr/local/bin in front of our .venv
// path. This fixes that issue.
"PATH": "${containerEnv:PATH}",
// Add src folder to PYTHONPATH so that we can import modules that
// are in the source dir
"PYTHONPATH": "/graphrag-accelerator/backend/:$PATH",
"AZURE_CLI_DISABLE_CONNECTION_VERIFICATION": "1"
},
"mounts": [
// NOTE: we reference both HOME and USERPROFILE environment variables to simultaneously support both Windows and Unix environments
// in most default situations, only one variable will exist (Windows has USERPROFILE and unix has HOME) and a reference to the other variable will result in an empty string

View File

@ -16,6 +16,7 @@ from kubernetes import (
client,
config,
)
from src.api.azure_clients import AzureStorageClientManager
from src.api.common import sanitize_name
from src.models import PipelineJob

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,6 @@ authors = [
"Gabriel Nieves <gnievesponce@microsoft.com>",
"Douglas Orbaker <dorbaker@microsoft.com>",
"Chris Sanchez <chrissanchez@microsoft.com>",
"Katy Smith <katysmith@microsoft.com>",
"Shane Solomon <shane.solomon@microsoft.com>",
]
readme = "README.md"
@ -33,7 +32,7 @@ pytest = ">=8.2.1"
wikipedia = ">=1.4.0"
[tool.poetry.group.backend.dependencies]
adlfs = ">=2023.10.0"
adlfs = ">=2024.7.0"
applicationinsights = ">=0.11.10"
attrs = ">=23.2.0"
azure-core = ">=1.30.1"
@ -66,12 +65,6 @@ tiktoken = ">=0.6.0"
uvicorn = ">=0.23.2"
urllib3 = ">=2.2.2"
[tool.poetry.group.frontend.dependencies]
python-dotenv = ">=0.19.1"
requests = "*"
streamlit = ">=0.88.0"
streamlit-nested-layout = "==0.1.3"
[tool.ruff]
target-version = "py310"
line-length = 88

View File

@ -127,7 +127,9 @@ async def setup_indexing_pipeline(
)
# if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
_delete_k8s_job(f"indexing-job-{sanitized_index_name}", os.environ["AKS_NAMESPACE"])
_delete_k8s_job(
f"indexing-job-{sanitized_index_name}", os.environ["AKS_NAMESPACE"]
)
# reset the pipeline job details
existing_job._status = PipelineJobState.SCHEDULED
existing_job._percent_complete = 0

View File

@ -68,11 +68,15 @@ async def lifespan(app: FastAPI):
] = pod.spec.service_account_name
# retrieve list of existing cronjobs
batch_v1 = client.BatchV1Api()
namespace_cronjobs = batch_v1.list_namespaced_cron_job(namespace=os.environ["AKS_NAMESPACE"])
namespace_cronjobs = batch_v1.list_namespaced_cron_job(
namespace=os.environ["AKS_NAMESPACE"]
)
cronjob_names = [cronjob.metadata.name for cronjob in namespace_cronjobs.items]
# create cronjob if it does not exist
if manifest["metadata"]["name"] not in cronjob_names:
batch_v1.create_namespaced_cron_job(namespace=os.environ["AKS_NAMESPACE"], body=manifest)
batch_v1.create_namespaced_cron_job(
namespace=os.environ["AKS_NAMESPACE"], body=manifest
)
except Exception as e:
print("Failed to create graphrag cronjob.")
reporter = ReporterSingleton().get_instance()

View File

@ -10,6 +10,7 @@ from azure.cosmos import (
)
from azure.storage.blob import BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient as BlobServiceClientAsync
from src.api.azure_clients import AzureStorageClientManager

View File

@ -11,11 +11,11 @@ ENV PIP_DISABLE_PIP_VERSION_CHECK=1
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
ENV PYTHONPATH=/backend
COPY poetry.lock pyproject.toml /
COPY backend /backend
RUN pip install poetry \
RUN cd backend \
&& pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install --without frontend
&& poetry install
# download all nltk data that graphrag requires
RUN python -m nltk.downloader punkt averaged_perceptron_tagger maxent_ne_chunker words wordnet

View File

@ -7,11 +7,11 @@ ENV PIP_ROOT_USER_ACTION=ignore
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
COPY poetry.lock pyproject.toml /
COPY frontend /frontend
RUN pip install poetry \
RUN cd frontend \
&& pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install --without backend
&& poetry install
WORKDIR /frontend
EXPOSE 8080

View File

@ -4,6 +4,7 @@
import os
import streamlit as st
from src.components import tabs
from src.components.index_pipeline import IndexPipeline
from src.enums import EnvVars
@ -16,29 +17,28 @@ st.session_state["initialized"] = True if initialized else False
def graphrag_app(initialized: bool):
# main entry point for app interface
st.title("Microsoft GraphRAG Copilot")
main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs(
[
"**Intro**",
"**1. Prompt Generation**",
"**2. Prompt Configuration**",
"**3. Index**",
"**4. Query**",
]
)
main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs([
"**Intro**",
"**1. Prompt Generation**",
"**2. Prompt Configuration**",
"**3. Index**",
"**4. Query**",
])
# display only the main tab if a connection to an existing APIM has not been initialized
with main_tab:
tabs.get_main_tab(initialized)
# if not initialized, only main tab is displayed
if initialized:
# assign API request information
# setup API request information
COLUMN_WIDTHS = [0.275, 0.45, 0.275]
api_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
apim_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
apim_key = st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
client = GraphragAPI(api_url, apim_key)
# perform health check to verify connectivity
client = GraphragAPI(apim_url, apim_key)
if not client.health_check_passed():
st.error("APIM Connection Error")
st.stop()
indexPipe = IndexPipeline(client, COLUMN_WIDTHS)
# display tabs
with prompt_gen_tab:
tabs.get_prompt_generation_tab(client, COLUMN_WIDTHS)
@ -48,9 +48,7 @@ def graphrag_app(initialized: bool):
tabs.get_index_tab(indexPipe)
with query_tab:
tabs.get_query_tab(client)
deployer_email = os.getenv("DEPLOYER_EMAIL", "deployer@email.com")
footer = f"""
<div class="footer">
<p> Responses may be inaccurate; please review all responses for accuracy. Learn more about Azure OpenAI code of conduct <a href="https://learn.microsoft.com/en-us/legal/cognitive-services/openai/code-of-conduct"> here</a>. </br> For feedback, email us at <a href="mailto:{deployer_email}">{deployer_email}</a>.</p>

3109
frontend/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

55
frontend/pyproject.toml Normal file
View File

@ -0,0 +1,55 @@
[tool.poetry]
name = "graphrag-solution-accelerator"
version = "0.1.1"
description = ""
authors = [
"Josh Bradley <joshbradley@microsoft.com>",
"Newman Cheng <newmancheng@microsoft.com>",
"Christine DiFonzo <cdifonzo@microsoft.com>",
"Gabriel Nieves <gnievesponce@microsoft.com>",
"Douglas Orbaker <dorbaker@microsoft.com>",
"Chris Sanchez <chrissanchez@microsoft.com>",
"Shane Solomon <shane.solomon@microsoft.com>",
]
readme = "README.md"
license = "MIT"
package-mode = false
[tool.poetry.dependencies]
python = "~3.10"
[tool.poetry.group.dev.dependencies]
detect-secrets = ">=1.5.0"
devtools = ">=0.12.2"
flake8 = ">=6.1.0"
ipython = "*"
jupyter = "*"
pre-commit = ">=3.6.0"
ruff = ">=0.1.13"
[tool.poetry.group.test.dependencies]
pytest = ">=8.2.1"
wikipedia = ">=1.4.0"
[tool.poetry.group.frontend.dependencies]
requests = "*"
streamlit = ">=1.38.0"
streamlit-nested-layout = "*"
[tool.ruff]
target-version = "py310"
line-length = 88
indent-width = 4
[tool.ruff.format]
preview = true
quote-style = "double"
[tool.ruff.lint]
preview = true
select = ["E", "F", "I"]
ignore = ["E402", "E501", "F821"]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -50,7 +50,6 @@ class IndexPipeline:
key_prefix="index",
disable_other_input=disable_other_input,
)
if select_storage_name != "":
disable_other_input = True
@ -138,7 +137,6 @@ class IndexPipeline:
divider=True,
help="Select an index to check the status of what stage indexing is in. Indexing must be complete in order to be able to execute queries.",
)
options_indexes = self.client.get_index_names()
# create logic for defaulting to running job index if one exists
new_index_name = st.session_state["index-name-input"]

View File

@ -26,8 +26,7 @@ def login():
form_submit = st.form_submit_button("Login")
if form_submit:
client = GraphragAPI(apim_url, apim_sub_key)
status_code = client.health_check()
if status_code == 200:
if client.health_check_passed():
st.success("Login Successful")
st.session_state[EnvVars.DEPLOYMENT_URL.value] = apim_url
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = apim_sub_key
@ -36,4 +35,3 @@ def login():
else:
st.error("Login Failed")
st.error("Please check the APIM Gateway URL and Subscription Key")
return status_code

View File

@ -20,7 +20,7 @@ class GraphQuery:
def search(
self,
query_type: Literal["Global Streaming", "Global", "Local"],
query_type: Literal["Global Streaming", "Local Streaming", "Global", "Local"],
search_index: str | list[str],
query: str,
) -> None:
@ -51,6 +51,8 @@ class GraphQuery:
match query_type:
case "Global Streaming":
_ = self.global_streaming_search(search_index, query)
case "Local Streaming":
_ = self.local_streaming_search(search_index, query)
case "Global":
_ = self.global_search(search_index, query)
case "Local":
@ -109,7 +111,73 @@ class GraphQuery:
"Double-click on content to expand text", "red", False
)
)
self._build_st_dataframe(context_list)
self._build_st_dataframe(
context_list[0]["reports"], drop_columns=[]
)
else:
print(query_response.reason, query_response.content)
raise Exception("Received unexpected response from server")
def local_streaming_search(self, search_index: str | list[str], query: str) -> None:
"""
Executes a local streaming query on the specified index.
Handles the response and displays the results in the Streamlit app.
"""
query_response = self.client.local_streaming_query(search_index, query)
assistant_response = ""
context_list = []
if query_response.status_code == 200:
text_placeholder = st.empty()
for chunk in query_response.iter_lines(
# allow up to 256KB to avoid excessively many reads
chunk_size=256 * GraphQuery.KILOBYTE,
decode_unicode=True,
):
try:
payload = json.loads(chunk)
except json.JSONDecodeError as e:
# In the event that a chunk is not a complete JSON object,
# document it for further analysis.
print(chunk)
raise e
token = payload["token"]
context = payload["context"]
if (token != "<EOM>") and (context is None):
assistant_response += token
text_placeholder.write(assistant_response)
elif (token == "<EOM>") and (context is not None):
context_list.append(context)
if not assistant_response:
st.write(
self.format_md_text(
"Not enough contextual data to support your query: No results found.\tTry another query.",
"red",
True,
)
)
return
else:
with self._create_section_expander("Query Context"):
st.write(
self.format_md_text(
"Double-click on content to expand text", "red", False
)
)
self._build_st_dataframe(
context_list[0]["reports"], drop_columns=[]
)
self._build_st_dataframe(
context_list[0]["entities"], drop_columns=[]
)
self._build_st_dataframe(
context_list[0]["relationships"], drop_columns=[]
)
self._build_st_dataframe(
context_list[0]["sources"], drop_columns=[]
)
else:
print(query_response.reason, query_response.content)
raise Exception("Received unexpected response from server")
@ -216,14 +284,12 @@ class GraphQuery:
drop_columns: list[str] = ["id", "index_id", "index_name", "in_context"],
entity_df: bool = False,
rel_df: bool = False,
) -> st.dataframe:
) -> st.dataframe: # type: ignore
df_context = (
data if isinstance(data, pd.DataFrame) else pd.DataFrame.from_records(data)
)
if any(drop_columns):
for column in drop_columns:
if column in df_context.columns:
df_context = df_context.drop(column, axis=1)
df_context.drop(columns=drop_columns, inplace=True, axis=1, errors="ignore")
if entity_df:
return st.dataframe(
df_context,
@ -267,7 +333,7 @@ class GraphQuery:
def _create_section_expander(
self, title: str, color: str = "blue", bold: bool = True, expanded: bool = False
) -> st.expander:
) -> st.expander: # type: ignore
"""
Creates an expander in the Streamlit app with the specified title and content.
"""

View File

@ -63,8 +63,6 @@ def get_prompt_generation_tab(
"""
Displays content of Prompt Generation Tab
"""
# hard set limit to 5 files to reduce overly long processing times and to reduce over sampling errors.
num_chunks = num_chunks if num_chunks <= 5 else 5
_, col2, _ = st.columns(column_widths)
with col2:
st.header(
@ -133,24 +131,10 @@ def get_prompt_generation_tab(
"Prompts generated successfully! Move on to the next tab to configure the prompts."
)
else:
# assume limit parameter is too high
# limit parameter was too high
st.warning(
"You do not have enough data to generate prompts. Retrying with a smaller sample size."
)
while num_chunks > 1:
num_chunks -= 1
generated = generate_and_extract_prompts(
client=client,
storage_name=select_prompt_storage,
limit=num_chunks,
)
if not isinstance(generated, Exception):
st.success(
"Prompts generated successfully! Move on to the next tab to configure the prompts."
)
break
else:
st.warning(f"Retrying with sample size: {num_chunks}")
def get_prompt_configuration_tab(
@ -237,7 +221,7 @@ def get_query_tab(client: GraphragAPI) -> None:
with col1:
query_type = st.selectbox(
"Query Type",
["Global Streaming", "Global", "Local"],
["Global Streaming", "Local Streaming", "Global", "Local"],
help="Select the query type - Each yields different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
)
with col2:
@ -253,7 +237,6 @@ def get_query_tab(client: GraphragAPI) -> None:
disabled = True if not any(select_index_search) else False
col3, col4 = st.columns([0.8, 0.2])
with col3:
search_bar = st.text_input("Query", key="search-query", disabled=disabled)
with col4:

View File

@ -4,30 +4,30 @@
from enum import Enum
class PromptKeys(Enum):
class PromptKeys(str, Enum):
ENTITY = "entity_extraction"
SUMMARY = "summarize_descriptions"
COMMUNITY = "community_report"
class PromptFileNames(Enum):
class PromptFileNames(str, Enum):
ENTITY = "entity_extraction_prompt.txt"
SUMMARY = "summarize_descriptions_prompt.txt"
COMMUNITY = "community_report_prompt.txt"
class PromptTextAreas(Enum):
class PromptTextAreas(str, Enum):
ENTITY = "entity_text_area"
SUMMARY = "summary_text_area"
COMMUNITY = "community_text_area"
class StorageIndexVars(Enum):
class StorageIndexVars(str, Enum):
SELECTED_STORAGE = "selected_storage"
INPUT_STORAGE = "input_storage"
SELECTED_INDEX = "selected_index"
class EnvVars(Enum):
class EnvVars(str, Enum):
APIM_SUBSCRIPTION_KEY = "APIM_SUBSCRIPTION_KEY"
DEPLOYMENT_URL = "DEPLOYMENT_URL"

View File

@ -7,7 +7,6 @@ from typing import Optional
from zipfile import ZipFile
import streamlit as st
from dotenv import find_dotenv, load_dotenv
from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI
@ -17,7 +16,7 @@ This module contains functions that are used across the Streamlit app.
"""
def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
def initialize_app(css_file: str = "style.css") -> bool:
"""
Initialize the Streamlit app with the necessary configurations.
"""
@ -31,10 +30,7 @@ def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
# initialize session state variables
set_session_state_variables()
# load environment variables
_ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)
# either load from .env file or from session state
# load settings from environment variables
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = os.getenv(
EnvVars.APIM_SUBSCRIPTION_KEY.value,
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value],

View File

@ -131,17 +131,16 @@ class GraphragAPI:
except Exception as e:
print(f"Error: {str(e)}")
def health_check(self) -> int | Response:
def health_check_passed(self) -> bool:
"""
Check the health of the APIM endpoint.
"""
url = self.api_url + "/health"
try:
response = requests.get(url, headers=self.headers)
return response.status_code
except Exception as e:
print(f"Error: {str(e)}")
return e
return response.ok
except Exception:
return False
def query_index(self, index_name: str | list[str], query_type: str, query: str):
"""
@ -174,7 +173,25 @@ class GraphragAPI:
"""
Returns a streaming response object for a global query.
"""
url = f"{self.api_url}/experimental/query/global/streaming"
url = f"{self.api_url}/query/streaming/global"
try:
query_response = requests.post(
url,
json={"index_name": index_name, "query": query},
headers=self.headers,
stream=True,
)
return query_response
except Exception as e:
print(f"Error: {str(e)}")
def local_streaming_query(
self, index_name: str | list[str], query: str
) -> Response | None:
"""
Returns a streaming response object for a local query.
"""
url = f"{self.api_url}/query/streaming/local"
try:
query_response = requests.post(
url,