Add frontend application (#68)

Co-authored-by: americanthinker <americanthinker@gmail.com>
Co-authored-by: Tim <timothymeyers@users.noreply.github.com>
Co-authored-by: Christine Caggiano <cdifonzo@microsoft.com>
Josh Bradley 2024-07-10 03:43:22 -04:00 committed by GitHub
parent 5dd5060d32
commit 0abbfb2a5f
18 changed files with 1629 additions and 1 deletion


@@ -52,3 +52,16 @@ jobs:
context: .
file: docker/Dockerfile-backend
push: false
build-frontend:
needs: [lint-check]
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Build docker image
uses: docker/build-push-action@v2
with:
context: .
file: docker/Dockerfile-frontend
push: false

2
.gitignore vendored

@@ -166,4 +166,4 @@ main.parameters.json
**/charts/*.tgz
**/Chart.lock
.history
.history

18
docker/Dockerfile-frontend Normal file

@@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
FROM python:3.10
ENV PIP_ROOT_USER_ACTION=ignore
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
COPY poetry.lock pyproject.toml /
COPY frontend /frontend
RUN pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install --without backend
WORKDIR /frontend
EXPOSE 8080
CMD ["streamlit", "run", "app.py", "--server.port", "8080"]

5
frontend/.streamlit/config.toml Normal file

@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
[server]
enableXsrfProtection = false

26
frontend/README.md Normal file

@@ -0,0 +1,26 @@
# Frontend Application Launch Instructions
A small frontend application, built with Streamlit, is provided to demonstrate how to build a UI on top of the solution accelerator API.
### 1. Deploy the GraphRAG solution accelerator
Follow instructions from the [deployment guide](../docs/DEPLOYMENT-GUIDE.md) to deploy a full instance of the solution accelerator.
### 2. (optional) Create a `.env` file:
| Variable Name | Required | Example | Description |
| :--- | :--- | :--- | :--- |
| DEPLOYMENT_URL | No | https://<my_apim>.azure-api.net | Base URL of the deployed graphrag API. Also referred to as the APIM Gateway URL. |
| APIM_SUBSCRIPTION_KEY | No | <subscription_key> | A [subscription key](https://learn.microsoft.com/en-us/azure/api-management/api-management-subscriptions) generated by APIM. |
| DEPLOYER_EMAIL | No | deployer@email.com | Email address of the person/organization that deployed the solution accelerator. |
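A minimal `.env` based on the variables above might look like this (all values are placeholders):
```
DEPLOYMENT_URL=https://<my_apim>.azure-api.net
APIM_SUBSCRIPTION_KEY=<subscription_key>
DEPLOYER_EMAIL=deployer@email.com
```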
### 3. Start the UI
The frontend application can be run locally as a Docker container. If a `.env` file is not provided, the UI will prompt the user for this information at runtime.
```shell
# cd to the root directory of the repo
> docker build -t graphrag:frontend -f docker/Dockerfile-frontend .
> docker run --env-file <env_file> -p 8080:8080 graphrag:frontend
```
To access the app, visit `localhost:8080` in your browser.
This UI application can also be hosted in Azure as a [Web App](https://azure.microsoft.com/en-us/products/app-service/web).
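One possible path, sketched here with placeholder resource names (adapt the commands to your environment), is to build the image into an Azure Container Registry and point a Web App at it:
```shell
# build the frontend image in ACR
> az acr build --registry <my_acr> --image graphrag:frontend --file docker/Dockerfile-frontend .
# create a Web App that runs the container
> az webapp create --resource-group <rg> --plan <plan> --name <app_name> --deployment-container-image-name <my_acr>.azurecr.io/graphrag:frontend
# App Service needs to know which port the container listens on
> az webapp config appsettings set --resource-group <rg> --name <app_name> --settings WEBSITES_PORT=8080
```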

63
frontend/app.py Normal file

@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import streamlit as st
from src.components import tabs
from src.components.index_pipeline import IndexPipeline
from src.enums import EnvVars
from src.functions import initialize_app
from src.graphrag_api import GraphragAPI
# Load environment variables
initialized = initialize_app()
st.session_state["initialized"] = True if initialized else False
def graphrag_app(initialized: bool):
# main entry point for app interface
st.title("Microsoft GraphRAG Copilot")
main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs(
[
"**Intro**",
"**1. Prompt Generation**",
"**2. Prompt Configuration**",
"**3. Index**",
"**4. Query**",
]
)
with main_tab:
tabs.get_main_tab(initialized)
# if not initialized, only main tab is displayed
if initialized:
# assign API request information
COLUMN_WIDTHS = [0.275, 0.45, 0.275]
api_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
apim_key = st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
client = GraphragAPI(api_url, apim_key)
indexPipe = IndexPipeline(client, COLUMN_WIDTHS)
# display tabs
with prompt_gen_tab:
tabs.get_prompt_generation_tab(client, COLUMN_WIDTHS)
with prompt_edit_tab:
tabs.get_prompt_configuration_tab()
with index_tab:
tabs.get_index_tab(indexPipe)
with query_tab:
tabs.get_query_tab(client)
deployer_email = os.getenv("DEPLOYER_EMAIL", "deployer@email.com")
footer = f"""
<div class="footer">
<p> Responses may be inaccurate; please review all responses for accuracy. Learn more about the Azure OpenAI code of conduct <a href="https://learn.microsoft.com/en-us/legal/cognitive-services/openai/code-of-conduct">here</a>. <br/> For feedback, email us at <a href="mailto:{deployer_email}">{deployer_email}</a>.</p>
</div>
"""
st.markdown(footer, unsafe_allow_html=True)
if __name__ == "__main__":
graphrag_app(st.session_state["initialized"])
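For local development outside of Docker, the app can presumably be launched the same way the container's CMD does, assuming the poetry environment from the Dockerfile has been installed:
```shell
> cd frontend
> streamlit run app.py --server.port 8080
```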

0
frontend/src/__init__.py Normal file

0
frontend/src/components/__init__.py Normal file

206
frontend/src/components/index_pipeline.py Normal file

@@ -0,0 +1,206 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from io import StringIO
import streamlit as st
from src.components.upload_files_component import upload_files
from src.enums import PromptKeys
from src.functions import GraphragAPI
class IndexPipeline:
def __init__(self, client: GraphragAPI, column_widths: list[float]) -> None:
self.client = client
self.containers = client.get_storage_container_names()
self.column_widths = column_widths
def storage_data_step(self):
"""
Builds the Storage Data Step for the Indexing Pipeline.
"""
disable_other_input = False
_, col2, _ = st.columns(self.column_widths)
with col2:
st.header(
"1. Data Storage",
divider=True,
help="Select a Data Storage Container to upload data to or select an existing container to use for indexing. The data will be processed by the LLM to create a Knowledge Graph.",
)
select_storage_name = st.selectbox(
label="Select an existing Storage Container.",
options=[""] + self.containers
if isinstance(self.containers, list)
else [],
key="index-storage",
index=0,
)
if select_storage_name != "":
disable_other_input = True
st.write("Or...")
with st.expander("Upload data to a storage container."):
# TODO: validate storage container name before uploading
# TODO: add user message that option not available while existing storage container is selected
upload_files(
self.client,
key_prefix="index",
disable_other_input=disable_other_input,
)
if select_storage_name != "":
disable_other_input = True
def build_index_step(self):
"""
Creates the Build Index Step for the Indexing Pipeline.
"""
_, col2, _ = st.columns(self.column_widths)
with col2:
st.header(
"2. Build Index",
divider=True,
help="Building an index will process the data from step 1 and create a Knowledge Graph suitable for querying. The LLM will use either the default prompt configuration or the prompts that you generated previously. To track the status of an indexing job, use the check index status below.",
)
# use data from either the selected storage container or the uploaded data
select_storage_name = st.session_state["index-storage"]
input_storage_name = (
st.session_state["index-storage-name-input"]
if st.session_state["index-upload-button"]
else ""
)
storage_selection = select_storage_name or input_storage_name
# Allow user to choose either default or custom prompts
custom_prompts = any([st.session_state[k.value] for k in PromptKeys])
prompt_options = ["Default", "Custom"] if custom_prompts else ["Default"]
prompt_choice = st.radio(
"Choose LLM Prompt Configuration",
options=prompt_options,
index=1 if custom_prompts else 0,
key="prompt-config-choice",
horizontal=True,
)
# Create new index name
index_name = st.text_input("Enter Index Name", key="index-name-input")
st.write(f"Selected Storage Container: **:blue[{storage_selection}]**")
if st.button(
"Build Index",
help="You must enter both an Index Name and Select a Storage Container to enable this button",
disabled=not index_name or not storage_selection,
):
entity_prompt = (
StringIO(st.session_state[PromptKeys.ENTITY.value])
if prompt_choice == "Custom"
else None
)
summarize_prompt = (
StringIO(st.session_state[PromptKeys.SUMMARY.value])
if prompt_choice == "Custom"
else None
)
community_prompt = (
StringIO(st.session_state[PromptKeys.COMMUNITY.value])
if prompt_choice == "Custom"
else None
)
response = self.client.build_index(
storage_name=storage_selection,
index_name=index_name,
entity_extraction_prompt_filepath=entity_prompt,
summarize_description_prompt_filepath=summarize_prompt,
community_prompt_filepath=community_prompt,
)
if response.status_code == 200:
st.success(
f"Job submitted successfully, using {prompt_choice} prompts!"
)
else:
st.error(
f"Failed to submit job.\nStatus: {response.json()['detail']}"
)
def check_status_step(self):
"""
Checks the progress of a running indexing job.
"""
_, col2, _ = st.columns(self.column_widths)
with col2:
st.header(
"3. Check Index Status",
divider=True,
help="Select an index to check the status of what stage indexing is in. Indexing must be complete in order to be able to execute queries.",
)
options_indexes = self.client.get_index_names()
# create logic for defaulting to running job index if one exists
new_index_name = st.session_state["index-name-input"]
default_index = (
options_indexes.index(new_index_name)
if new_index_name in options_indexes
else 0
)
index_name_select = st.selectbox(
label="Select an index to check its status.",
options=options_indexes if any(options_indexes) else [],
index=default_index,
)
progress_bar = st.progress(0, text="Index Job Progress")
if st.button("Check Status"):
status_response = self.client.check_index_status(index_name_select)
if status_response.status_code == 200:
status_response_text = status_response.json()
if status_response_text["status"] != "":
try:
# build status message
job_status = status_response_text["status"]
status_message = f"Status: {status_response_text['status']}"
st.success(status_message) if job_status in [
"running",
"complete",
] else st.warning(status_message)
except Exception as e:
print(e)
try:
# build percent complete message
percent_complete = status_response_text["percent_complete"]
progress_bar.progress(float(percent_complete) / 100)
completion_message = (
f"Percent Complete: {percent_complete}% "
)
st.warning(
completion_message
) if percent_complete < 100 else st.success(
completion_message
)
except Exception as e:
print(e)
try:
# build progress message
progress_status = status_response_text["progress"]
progress_status = (
progress_status if progress_status else "N/A"
)
progress_message = f"Progress: {progress_status}"
st.success(
progress_message
) if progress_status != "N/A" else st.warning(
progress_message
)
except Exception as e:
print(e)
else:
st.warning(
f"No status information available for this index: {index_name_select}"
)
else:
st.warning(
f"No workflow information available for this index: {index_name_select}"
)
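A condensed sketch of the API call this step issues (the URL, key, and names are placeholders; custom prompts are passed as `StringIO` buffers, mirroring the code above):
```python
from io import StringIO

from src.graphrag_api import GraphragAPI

client = GraphragAPI("https://<my_apim>.azure-api.net", "<subscription_key>")
response = client.build_index(
    storage_name="<storage_container>",
    index_name="<index_name>",
    entity_extraction_prompt_filepath=StringIO("custom entity extraction prompt ..."),
)
print(response.status_code)  # 200 indicates the job was submitted
```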

39
frontend/src/components/login_sidebar.py Normal file

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import streamlit as st
from src.enums import EnvVars
from src.graphrag_api import GraphragAPI
def login():
"""
Login component that displays in the sidebar. Requires the user to enter
the APIM Gateway URL and Subscription Key to login. After entering user
credentials, a simple health check call is made to the GraphRAG API.
"""
with st.sidebar:
st.title(
"Login",
help="Enter your APIM credentials to get started. Refreshing the browser will require you to login again.",
)
with st.form(key="login-form", clear_on_submit=True):
apim_url = st.text_input("APIM Gateway URL", key="apim-url")
apim_sub_key = st.text_input(
"APIM Subscription Key", key="subscription-key"
)
form_submit = st.form_submit_button("Login")
if form_submit:
client = GraphragAPI(apim_url, apim_sub_key)
status_code = client.health_check()
if status_code == 200:
st.success("Login Successful")
st.session_state[EnvVars.DEPLOYMENT_URL.value] = apim_url
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = apim_sub_key
st.session_state["initialized"] = True
st.rerun()
else:
st.error("Login Failed")
st.error("Please check the APIM Gateway URL and Subscription Key")
return status_code

89
frontend/src/components/prompt_configuration.py Normal file

@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import streamlit as st
from src.enums import PromptFileNames, PromptKeys, PromptTextAreas
from src.functions import zip_directory
SAVED_PROMPT_VAR = "saved_prompts"
def save_prompts(
local_dir: str = "./edited_prompts/", zip_file_path: str = "edited_prompts.zip"
):
"""
Save prompts in memory and on disk as a zip file
"""
st.session_state[SAVED_PROMPT_VAR] = True
st.session_state[PromptKeys.ENTITY.value] = st.session_state[
PromptTextAreas.ENTITY.value
]
st.session_state[PromptKeys.SUMMARY.value] = st.session_state[
PromptTextAreas.SUMMARY.value
]
st.session_state[PromptKeys.COMMUNITY.value] = st.session_state[
PromptTextAreas.COMMUNITY.value
]
os.makedirs(local_dir, exist_ok=True)
for key, filename in zip(PromptKeys, PromptFileNames):
outpath = os.path.join(local_dir, filename.value)
with open(outpath, "w", encoding="utf-8") as f:
f.write(st.session_state[key.value])
zip_directory(local_dir, zip_file_path)
def edit_prompts():
"""
Re-edit the prompts
"""
st.session_state[SAVED_PROMPT_VAR] = False
def prompt_editor(prompt_values: list[str]):
"""
Container for prompt configurations
"""
saved_prompts = st.session_state[SAVED_PROMPT_VAR]
entity_ext_prompt, summ_prompt, comm_report_prompt = prompt_values
with st.container(border=True):
tab_labels = [
"**Entity Extraction**",
"**Summarize Descriptions**",
"**Community Reports**",
]
# subheaders = [f"{tab_label} Prompt" for tab_label in tab_labels]
tab1, tab2, tab3 = st.tabs(tabs=tab_labels)
with tab1:
st.text_area(
label="Entity Prompt",
value=entity_ext_prompt,
max_chars=20000,
key="entity_text_area",
label_visibility="hidden",
disabled=saved_prompts,
)
with tab2:
st.text_area(
label="Summarize Prompt",
value=summ_prompt,
max_chars=20000,
key="summary_text_area",
label_visibility="hidden",
disabled=saved_prompts,
)
with tab3:
st.text_area(
label="Community Reports Prompt",
value=comm_report_prompt,
max_chars=20000,
key="community_text_area",
label_visibility="hidden",
disabled=saved_prompts,
)

274
frontend/src/components/query.py Normal file

@@ -0,0 +1,274 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from typing import Literal
import numpy as np
import pandas as pd
import requests
import streamlit as st
from src.graphrag_api import GraphragAPI
class GraphQuery:
KILOBYTE = 1024
def __init__(self, client: GraphragAPI):
self.client = client
def search(
self,
query_type: Literal["Global Streaming", "Global", "Local"],
search_index: str | list[str],
query: str,
) -> None:
idler_message_list = [
"Querying the graph...",
"Processing the query...",
"The graph is working hard...",
"Fetching the results...",
"Reticulating splines...",
"Almost there...",
"The report format is customizable, for this demo we report back in executive summary format. It's prompt driven to change as you like!",
"Just a few more seconds...",
"You probably know these messages are just for fun...",
"In the meantime, here's a fun fact: Did you know that the Microsoft GraphRAG Copilot is built on top of the Microsoft GraphRAG Solution Accelerator?",
"The average graph query processes several textbooks worth of information to get you your answer. I hope it was a good question!",
"Shamelessly buying time...",
"When the answer comes, make sure to check the context reports, the detail there is incredible!",
"When we ingest data into the graph, the structure of language itself is used to create the graph structure. It's like a language-based neural network, using neural networks to understand language to network. It's a network-ception!",
"The answers will come eventually, I promise. In the meantime, I recommend a doppio espresso, or a nice cup of tea. Or both! The GraphRAG team runs on caffeine.",
"The graph is a complex structure, but it's working hard to get you the answer you need.",
"GraphRAG is step one in a long journey of understanding the world through language. It's a big step, but there's so much more to come.",
"The results are on their way...",
]
message = np.random.choice(idler_message_list)
with st.spinner(text=message):
try:
match query_type:
case "Global Streaming":
_ = self.global_streaming_search(search_index, query)
case "Global":
_ = self.global_search(search_index, query)
case "Local":
_ = self.local_search(search_index, query)
except requests.exceptions.RequestException as e:
st.error(f"Error with query {query_type}: {str(e)}")
def global_streaming_search(
self, search_index: str | list[str], query: str
) -> None:
"""
Executes a global streaming query on the specified index.
Handles the response and displays the results in the Streamlit app.
"""
query_response = self.client.global_streaming_query(search_index, query)
assistant_response = ""
context_list = []
if query_response.status_code == 200:
text_placeholder = st.empty()
for chunk in query_response.iter_lines(
# allow chunks up to 256KB to avoid excessively many reads
chunk_size=256 * GraphQuery.KILOBYTE,
decode_unicode=True,
):
try:
payload = json.loads(chunk)
except json.JSONDecodeError as e:
# In the event that a chunk is not a complete JSON object,
# document it for further analysis.
print(chunk)
raise e
token = payload["token"]
context = payload["context"]
if (token != "<EOM>") and (context is None):
assistant_response += token
text_placeholder.write(assistant_response)
elif (token == "<EOM>") and (context is not None):
context_list.append(context)
if not assistant_response:
st.write(
self.format_md_text(
"Not enough contextual data to support your query: No results found.\tTry another query.",
"red",
True,
)
)
return
else:
with self._create_section_expander("Query Context"):
st.write(
self.format_md_text(
"Double-click on content to expand text", "red", False
)
)
self._build_st_dataframe(context_list)
else:
print(query_response.reason, query_response.content)
raise Exception("Received unexpected response from server")
def global_search(self, search_index: str | list[str], query: str) -> None:
query_response = self.client.query_index(
index_name=search_index, query_type="Global", query=query
)
if query_response["result"] != "":
with self._create_section_expander("Query Response", "black", True, True):
st.write(query_response["result"])
with self._create_section_expander("Query Context"):
st.write(
self.format_md_text(
"Double-click on content to expand text", "red", False
)
)
self._build_st_dataframe(query_response["context_data"]["reports"])
def local_search(self, search_index: str | list[str], query: str) -> None:
query_response = self.client.query_index(
index_name=search_index, query_type="Local", query=query
)
results = query_response["result"]
if results != "":
with self._create_section_expander("Query Response", "black", True, True):
st.write(results)
context_data = query_response["context_data"]
reports = context_data["reports"]
entities = context_data["entities"]
relationships = context_data["relationships"]
# sources = context_data["sources"]
if any(reports):
with self._create_section_expander("Query Context"):
st.write(
self.format_md_text(
"Double-click on content to expand text", "red", False
)
)
self._build_st_dataframe(reports)
if any(entities):
with st.spinner("Loading context entities..."):
with self._create_section_expander("Context Entities"):
df_entities = pd.DataFrame(entities)
self._build_st_dataframe(df_entities, entity_df=True)
# TODO: Fix the next portion of code to provide a more granular entity view
# for report in entities:
# entity_response = get_source_entity(
# report["index_name"], report["id"], self.api_url, self.headers
# )
# for unit in entity_response["text_units"]:
# response = requests.get(
# f"{self.api_url}/source/text/{report['index_name']}/{unit}",
# headers=self.headers,
# )
# text_info = response.json()
# if text_info is not None:
# with st.expander(
# f" Entity: {report['entity']} - Source Document: {text_info['source_document']} "
# ):
# st.write(text_info["text"])
if any(relationships):
with st.spinner("Loading context relationships..."):
with self._create_section_expander("Context Relationships"):
df_relationships = pd.DataFrame(relationships)
self._build_st_dataframe(df_relationships, rel_df=True)
# TODO: Fix the next portion of code to provide a more granular relationship view
# for report in query_response["context_data"][
# "relationships"
# ][:15]:
# # with st.expander(
# # f"Source: {report['source']} Target: {report['target']} Rank: {report['rank']}"
# # ):
# # st.write(report["description"])
# relationship_data = requests.get(
# f"{self.api_url}/source/relationship/{report['index_name']}/{report['id']}",
# headers=self.headers,
# )
# relationship_data = relationship_data.json()
# for unit in relationship_data["text_units"]:
# response = requests.get(
# f"{self.api_url}/source/text/{report['index_name']}/{unit}",
# headers=self.headers,
# )
# text_info_rel = response.json()
# df_textinfo_rel = pd.DataFrame([text_info_rel])
# with st.expander(
# f"Source: {report['source']} Target: {report['target']} - Source Document: {sources['source_document']} "
# ):
# st.write(sources["text"])
# st.dataframe(
# df_textinfo_rel, use_container_width=True
# )
def _build_st_dataframe(
self,
data: dict | pd.DataFrame,
drop_columns: list[str] = ["id", "index_id", "index_name", "in_context"],
entity_df: bool = False,
rel_df: bool = False,
) -> st.dataframe:
df_context = (
data if isinstance(data, pd.DataFrame) else pd.DataFrame.from_records(data)
)
if any(drop_columns):
for column in drop_columns:
if column in df_context.columns:
df_context = df_context.drop(column, axis=1)
if entity_df:
return st.dataframe(
df_context,
use_container_width=True,
column_config={
"entity": "Entity",
"description": "Description",
"number of relationships": "Number of Relationships",
},
)
if rel_df:
return st.dataframe(
df_context,
use_container_width=True,
column_config={
"source": "Source",
"target": "Target",
"description": "Description",
"weight": "Weight",
"rank": "Rank",
"links": "Links",
},
)
return st.dataframe(
df_context,
use_container_width=True,
column_config={
"title": "Report Title",
"content": "Report Content",
"rank": "Rank",
},
)
def format_md_text(self, text: str, color: str, bold: bool) -> str:
"""
Formats text for display in Streamlit app using Markdown syntax.
"""
if bold:
return f":{color}[**{text}**]"
return f":{color}[{text}]"
def _create_section_expander(
self, title: str, color: str = "blue", bold: bool = True, expanded: bool = False
) -> st.expander:
"""
Creates an expander in the Streamlit app with the specified title and content.
"""
return st.expander(self.format_md_text(title, color, bold), expanded=expanded)
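Based on the parsing logic in `global_streaming_search`, each streamed line decodes to a JSON object with a `token` and a `context` field, and the `<EOM>` sentinel marks the switch from answer tokens to context reports. A minimal sketch of that apparent protocol (the chunk values below are illustrative, not captured server output):
```python
import json

chunks = [
    '{"token": "GraphRAG ", "context": null}',  # answer tokens stream first
    '{"token": "combines RAG with graphs.", "context": null}',
    '{"token": "<EOM>", "context": {"title": "Report 1"}}',  # context follows the sentinel
]
answer = ""
for chunk in chunks:
    payload = json.loads(chunk)
    if payload["token"] != "<EOM>" and payload["context"] is None:
        answer += payload["token"]
    elif payload["token"] == "<EOM>" and payload["context"] is not None:
        print("context:", payload["context"])
print("answer:", answer)
```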

275
frontend/src/components/tabs.py Normal file

@@ -0,0 +1,275 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from time import sleep
import streamlit as st
from src.components.index_pipeline import IndexPipeline
from src.components.login_sidebar import login
from src.components.prompt_configuration import (
edit_prompts,
prompt_editor,
save_prompts,
)
from src.components.query import GraphQuery
from src.components.upload_files_component import upload_files
from src.enums import PromptKeys
from src.functions import generate_and_extract_prompts
from src.graphrag_api import GraphragAPI
def get_main_tab(initialized: bool) -> None:
"""
Displays content of Main Tab
"""
url = "https://github.com/Azure-Samples/graphrag-accelerator/blob/main/TRANSPARENCY.md"
content = f"""
## Welcome to GraphRAG!
Diving into complex information and uncovering semantic relationships using generative AI has never been easier.
Here's how you can get started with just a few clicks:
- **PROMPT GENERATION:** (*Optional Step*)
1. Generate fine-tuned prompts for graphrag customized to your data and domain.
2. Select an existing Storage Container and click "Generate Prompts".
- **PROMPT CONFIGURATION:** (*Optional Step*)
1. Edit the generated prompts to best suit your needs.
2. Once you are finished editing, click the "Save Prompts" button.
3. Saving the prompts will store them for use in the follow-on Indexing step.
4. You can also download the edited prompts for future reference.
- **INDEXING:**
1. Select an existing data storage container, or upload new data to index.
2. Name your index and select "Build Index" to begin building a GraphRAG Index.
3. Check the status of the index as the job progresses.
- **QUERYING:**
1. Choose an existing index
2. Specify a query type
3. Click "Query" button to search and view insights.
[GraphRAG]({url}) combines the power of RAG with a Graph structure, giving you insights at your fingertips.
"""
# Display text in the gray box
st.markdown(content, unsafe_allow_html=False)
if not initialized:
login()
def get_prompt_generation_tab(
client: GraphragAPI,
column_widths: list[float],
num_chunks: int = 5,
) -> None:
"""
Displays content of Prompt Generation Tab
"""
# hard-set limit of 5 chunks to reduce overly long processing times and oversampling errors
num_chunks = num_chunks if num_chunks <= 5 else 5
_, col2, _ = st.columns(column_widths)
with col2:
st.header(
"Generate Prompts (optional)",
divider=True,
help="Generate fine-tuned prompts for graphrag tailored to your data and domain.",
)
st.write(
"Select a storage container that contains your data. GraphRAG will use this data to generate domain-specific prompts for follow-on indexing."
)
storage_containers = client.get_storage_container_names()
# if no storage containers, allow user to upload files
if isinstance(storage_containers, list) and not (any(storage_containers)):
st.warning(
"No existing Storage Containers found. Please upload data to continue."
)
uploaded = upload_files(client, key_prefix="prompts-upload-1")
if uploaded:
# brief pause to allow success message to display
sleep(1.5)
st.rerun()
else:
select_prompt_storage = st.selectbox(
"Select an existing Storage Container.",
options=[""] + storage_containers
if isinstance(storage_containers, list)
else [],
key="prompt-storage",
index=0,
)
disable_other_input = True if select_prompt_storage != "" else False
with st.expander("I want to upload new data...", expanded=False):
new_upload = upload_files(
client,
key_prefix="prompts-upload-2",
disable_other_input=disable_other_input,
)
if new_upload:
# brief pause to allow success message to display
st.session_state["new_upload"] = True
sleep(1.5)
st.rerun()
if st.session_state["new_upload"] and not select_prompt_storage:
st.warning(
"Please select the newly uploaded Storage Container to continue."
)
st.write(f"**Selected Storage Container:** :blue[{select_prompt_storage}]")
triggered = st.button(
label="Generate Prompts",
key="prompt-generation",
help="Select either an existing Storage Container or upload new data to enable this button.\n\
Then, click to generate custom prompts for the LLM.",
disabled=not select_prompt_storage,
)
if triggered:
with st.spinner("Generating LLM prompts for GraphRAG..."):
generated = generate_and_extract_prompts(
client=client,
storage_name=select_prompt_storage,
limit=num_chunks,
)
if not isinstance(generated, Exception):
st.success(
"Prompts generated successfully! Move on to the next tab to configure the prompts."
)
else:
# assume limit parameter is too high
st.warning(
"You do not have enough data to generate prompts. Retrying with a smaller sample size."
)
while num_chunks > 1:
num_chunks -= 1
generated = generate_and_extract_prompts(
client=client,
storage_name=select_prompt_storage,
limit=num_chunks,
)
if not isinstance(generated, Exception):
st.success(
"Prompts generated successfully! Move on to the next tab to configure the prompts."
)
break
else:
st.warning(f"Retrying with sample size: {num_chunks}")
def get_prompt_configuration_tab(
download_file_name: str = "edited_prompts.zip",
) -> None:
"""
Displays content of Prompt Configuration Tab
"""
st.header(
"Configure Prompts (optional)",
divider=True,
help="Generate fine tuned prompts for the LLM specific to your data and domain.",
)
prompt_values = [st.session_state[k.value] for k in PromptKeys]
if any(prompt_values):
prompt_editor([prompt_values[0], prompt_values[1], prompt_values[2]])
col1, col2, col3 = st.columns(3, gap="large")
with col1:
clicked = st.button(
"Save Prompts",
help="Save the edited prompts for use with the follow-on indexing step. This button must be clicked to enable downloading the prompts.",
type="primary",
key="save-prompt-button",
on_click=save_prompts,
kwargs={"zip_file_path": download_file_name},
)
with col2:
st.button(
"Edit Prompts",
help="Allows user to re-edit the prompts after saving.",
type="primary",
key="edit-prompt-button",
on_click=edit_prompts,
)
with col3:
if os.path.exists(download_file_name):
with open(download_file_name, "rb") as fp:
st.download_button(
"Download Prompts",
data=fp,
file_name=download_file_name,
help="Downloads the saved prompts as a zip file containing three LLM prompts in .txt format.",
mime="application/zip",
type="primary",
disabled=not st.session_state["saved_prompts"],
key="download-prompt-button",
)
if clicked:
st.success(
"Prompts saved successfully! Downloading prompts is now enabled."
)
def get_index_tab(indexPipe: IndexPipeline) -> None:
"""
Displays content of Index tab
"""
indexPipe.storage_data_step()
indexPipe.build_index_step()
indexPipe.check_status_step()
def execute_query(
query_engine: GraphQuery, query_type: str, search_index: str | list[str], query: str
) -> None:
"""
Executes the query on the selected index
"""
if query:
query_engine.search(
query_type=query_type, search_index=search_index, query=query
)
else:
return st.warning("Please enter a query to search.")
def get_query_tab(client: GraphragAPI) -> None:
"""
Displays content of Query Tab
"""
gquery = GraphQuery(client)
col1, col2 = st.columns(2)
with col1:
query_type = st.selectbox(
"Query Type",
["Global Streaming", "Global", "Local"],
help="Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
)
with col2:
search_indexes = client.get_index_names()
if not any(search_indexes):
st.warning("No indexes found. Please build an index to continue.")
select_index_search = st.multiselect(
label="Index",
options=search_indexes if any(search_indexes) else [],
key="multiselect-index-search",
help="Select the index(es) to query. The selected index(es) must have a complete status in order to yield query results without error. Use Check Index Status to confirm status.",
)
disabled = True if not any(select_index_search) else False
col3, col4 = st.columns([0.8, 0.2])
with col3:
search_bar = st.text_input("Query", key="search-query", disabled=disabled)
with col4:
search_button = st.button("QUERY", type="primary", disabled=disabled)
# defining a query variable enables the use of either the search bar or the search button to trigger the query
query = st.session_state["search-query"]
if len(query) > 5:
if (search_bar and search_button) and any(select_index_search):
execute_query(
query_engine=gquery,
query_type=query_type,
search_index=select_index_search,
query=query,
)
else:
col1, col2 = st.columns([0.3, 0.7])
with col1:
st.warning("Cannot submit queries less than 6 characters in length.")

56
frontend/src/components/upload_files_component.py Normal file

@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import streamlit as st
from src.graphrag_api import GraphragAPI
UPLOAD_HELP_MESSAGE = """
This functionality is disabled while an existing Storage Container is selected.
Please deselect the existing Storage Container to upload new data.
"""
def upload_files(
client: GraphragAPI, key_prefix: str, disable_other_input: bool = False
):
"""
Reusable component to upload files to Blob Storage Container
"""
input_storage_name = st.text_input(
label="Enter Storage Name",
key=f"{key_prefix}-storage-name-input",
disabled=disable_other_input,
help=UPLOAD_HELP_MESSAGE,
)
file_upload = st.file_uploader(
"Upload Data",
type=["txt"],
key=f"{key_prefix}-file-uploader",
accept_multiple_files=True,
disabled=disable_other_input,
)
uploaded = st.button(
"Upload Files",
key=f"{key_prefix}-upload-button",
disabled=disable_other_input or input_storage_name == "",
)
if uploaded:
if file_upload and input_storage_name != "":
file_payloads = []
for file in file_upload:
file_payload = (
"files",
(file.name, file.read(), file.type),
)
file_payloads.append(file_payload)
response = client.upload_files(file_payloads, input_storage_name)
if response.status_code == 200:
st.success("Files uploaded successfully!")
else:
st.error(f"Error: {json.loads(response.text)}")
return uploaded
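The `file_payloads` list above follows the multipart format that `requests` uses for repeated form fields: `(field_name, (filename, data, content_type))` tuples. A small sketch (the container name and file contents are hypothetical):
```python
file_payloads = [
    ("files", ("doc1.txt", b"some text", "text/plain")),
    ("files", ("doc2.txt", b"more text", "text/plain")),
]
# client.upload_files(file_payloads, "my-storage-container")
```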

33
frontend/src/enums.py Normal file

@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from enum import Enum
class PromptKeys(Enum):
ENTITY = "entity_extraction"
SUMMARY = "summarize_descriptions"
COMMUNITY = "community_report"
class PromptFileNames(Enum):
ENTITY = "entity_extraction_prompt.txt"
SUMMARY = "summarize_descriptions_prompt.txt"
COMMUNITY = "community_report_prompt.txt"
class PromptTextAreas(Enum):
ENTITY = "entity_text_area"
SUMMARY = "summary_text_area"
COMMUNITY = "community_text_area"
class StorageIndexVars(Enum):
SELECTED_STORAGE = "selected_storage"
INPUT_STORAGE = "input_storage"
SELECTED_INDEX = "selected_index"
class EnvVars(Enum):
APIM_SUBSCRIPTION_KEY = "APIM_SUBSCRIPTION_KEY"
DEPLOYMENT_URL = "DEPLOYMENT_URL"

175
frontend/src/functions.py Normal file

@@ -0,0 +1,175 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Optional
from zipfile import ZipFile
import streamlit as st
from dotenv import find_dotenv, load_dotenv
from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI
"""
This module contains functions that are used across the Streamlit app.
"""
def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
"""
Initialize the Streamlit app with the necessary configurations.
"""
# set page configuration
st.set_page_config(initial_sidebar_state="expanded", layout="wide")
# set custom CSS
with open(css_file) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# initialize session state variables
set_session_state_variables()
# load environment variables
_ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)
# either load from .env file or from session state
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = os.getenv(
EnvVars.APIM_SUBSCRIPTION_KEY.value,
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value],
)
st.session_state[EnvVars.DEPLOYMENT_URL.value] = os.getenv(
EnvVars.DEPLOYMENT_URL.value, st.session_state[EnvVars.DEPLOYMENT_URL.value]
)
if (
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
and st.session_state[EnvVars.DEPLOYMENT_URL.value]
):
st.session_state["headers"] = {
"Ocp-Apim-Subscription-Key": st.session_state[
EnvVars.APIM_SUBSCRIPTION_KEY.value
],
"Content-Type": "application/json",
}
st.session_state["headers_upload"] = {
"Ocp-Apim-Subscription-Key": st.session_state[
EnvVars.APIM_SUBSCRIPTION_KEY.value
]
}
return True
else:
return False
def set_session_state_variables() -> None:
"""
Initializes most session state variables for the app.
"""
for key in PromptKeys:
value = key.value
if value not in st.session_state:
st.session_state[value] = ""
for key in StorageIndexVars:
value = key.value
if value not in st.session_state:
st.session_state[value] = ""
for key in EnvVars:
value = key.value
if value not in st.session_state:
st.session_state[value] = ""
if "saved_prompts" not in st.session_state:
st.session_state["saved_prompts"] = False
if "initialized" not in st.session_state:
st.session_state["initialized"] = False
if "new_upload" not in st.session_state:
st.session_state["new_upload"] = False
def update_session_state_prompt_vars(
entity_extract: Optional[str] = None,
summarize: Optional[str] = None,
community: Optional[str] = None,
initial_setting: bool = False,
prompt_dir: str = "./prompts",
) -> None:
"""
Updates the session state variables for the LLM prompts.
"""
if initial_setting:
entity_extract, summarize, community = get_prompts(prompt_dir)
if entity_extract:
st.session_state[PromptKeys.ENTITY.value] = entity_extract
if summarize:
st.session_state[PromptKeys.SUMMARY.value] = summarize
if community:
st.session_state[PromptKeys.COMMUNITY.value] = community
def generate_and_extract_prompts(
client: GraphragAPI,
storage_name: str,
zip_file_name: str = "prompts.zip",
limit: int = 5,
) -> None | Exception:
"""
Makes API call to generate LLM prompts, extracts prompts from zip file,
and updates the prompt session state variables.
"""
try:
client.generate_prompts(
storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
)
_extract_prompts_from_zip(zip_file_name)
update_session_state_prompt_vars(initial_setting=True)
return
except Exception as e:
return e
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
with ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall()
def open_file(file_path: str | Path):
with open(file_path, "r", encoding="utf-8") as file:
text = file.read()
return text
def zip_directory(directory_path: str, zip_path: str):
"""
Zips all contents of a directory into a single zip file.
Parameters:
- directory_path: str, the path of the directory to zip
- zip_path: str, the path where the zip file will be created
"""
root_dir_name = os.path.basename(directory_path.rstrip("/"))
with ZipFile(zip_path, "w") as zipf:
for root, _, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
relpath = os.path.relpath(file_path, start=directory_path)
arcname = os.path.join(root_dir_name, relpath)
zipf.write(file_path, arcname)
def get_prompts(prompt_dir: str = "./prompts"):
"""
Extract text from generated prompts. Assumes file names comply with pregenerated file name standards.
"""
prompt_paths = [
prompt for prompt in Path(prompt_dir).iterdir() if prompt.name.endswith(".txt")
]
entity_ext_prompt = [
open_file(path) for path in prompt_paths if path.name.startswith("entity")
][0]
summ_prompt = [
open_file(path) for path in prompt_paths if path.name.startswith("summ")
][0]
comm_report_prompt = [
open_file(path) for path in prompt_paths if path.name.startswith("community")
][0]
return entity_ext_prompt, summ_prompt, comm_report_prompt

214
frontend/src/graphrag_api.py Normal file

@@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from io import StringIO
import requests
import streamlit as st
from requests import Response
"""
This module contains the GraphRAG API class for making all external API calls
presumably to a GraphRAG instance deployed on Azure.
"""
class GraphragAPI:
"""
Primary interface for making REST API calls to the GraphRAG API.
"""
def __init__(self, api_url: str, apim_key: str):
self.api_url = api_url
self.apim_key = apim_key
self.headers = {
"Ocp-Apim-Subscription-Key": self.apim_key,
"Content-Type": "application/json",
}
self.upload_headers = {"Ocp-Apim-Subscription-Key": self.apim_key}
def get_storage_container_names(
self, storage_name_key: str = "storage_name"
) -> list[str] | Response | Exception:
"""
GET request to GraphRAG API for Azure Blob Storage Container names.
"""
try:
response = requests.get(f"{self.api_url}/data", headers=self.headers)
if response.status_code == 200:
return response.json()[storage_name_key]
else:
print(f"Error: {response.status_code}")
return response
except Exception as e:
print(f"Error: {str(e)}")
return e
def upload_files(self, file_payloads: dict, input_storage_name: str):
"""
Upload files to Azure Blob Storage Container.
"""
try:
response = requests.post(
self.api_url + "/data",
headers=self.upload_headers,
files=file_payloads,
params={"storage_name": input_storage_name},
)
if response.status_code == 200:
return response
except Exception as e:
print(f"Error: {str(e)}")
def get_index_names(
self, index_name_key: str = "index_name"
) -> list | Response | None:
"""
GET request to GraphRAG API for existing indexes.
"""
try:
response = requests.get(f"{self.api_url}/index", headers=self.headers)
if response.status_code == 200:
return response.json()[index_name_key]
else:
print(f"Error: {response.status_code}")
return response
except Exception as e:
print(f"Error: {str(e)}")
def build_index(
self,
storage_name: str,
index_name: str,
entity_extraction_prompt_filepath: str | StringIO | None = None,
community_prompt_filepath: str | StringIO | None = None,
summarize_description_prompt_filepath: str | StringIO | None = None,
) -> requests.Response:
"""
Create a search index.
This function kicks off a job that builds a knowledge graph (KG)
index from files located in a blob storage container.
"""
url = self.api_url + "/index"
prompt_files = dict()
if entity_extraction_prompt_filepath:
prompt_files["entity_extraction_prompt"] = (
open(entity_extraction_prompt_filepath, "r", encoding="utf-8")
if isinstance(entity_extraction_prompt_filepath, str)
else entity_extraction_prompt_filepath
)
if community_prompt_filepath:
prompt_files["community_report_prompt"] = (
open(community_prompt_filepath, "r", encoding="utf-8")
if isinstance(community_prompt_filepath, str)
else community_prompt_filepath
)
if summarize_description_prompt_filepath:
prompt_files["summarize_descriptions_prompt"] = (
open(summarize_description_prompt_filepath, "r", encoding="utf-8")
if isinstance(summarize_description_prompt_filepath, str)
else summarize_description_prompt_filepath
)
return requests.post(
url,
files=prompt_files if len(prompt_files) > 0 else None,
params={"index_name": index_name, "storage_name": storage_name},
headers=self.headers,
)
def check_index_status(self, index_name: str) -> Response | None:
"""
Check the status of a running index job.
"""
url = self.api_url + f"/index/status/{index_name}"
try:
response = requests.get(url, headers=self.headers)
if response.status_code == 200:
return response
else:
print(f"Error: {response.status_code}")
return response
except Exception as e:
print(f"Error: {str(e)}")
def health_check(self) -> int | Response:
"""
Check the health of the APIM endpoint.
"""
url = self.api_url + "/health"
try:
response = requests.get(url, headers=self.headers)
return response.status_code
except Exception as e:
print(f"Error: {str(e)}")
return e
def query_index(self, index_name: str | list[str], query_type: str, query: str):
"""
Submit a query to the GraphRAG API using a specific index and query type.
"""
try:
request = {
"index_name": index_name,
"query": query,
"reformat_context_data": True,
}
response = requests.post(
f"{self.api_url}/query/{query_type.lower()}",
headers=self.headers,
json=request,
)
if response.status_code == 200:
return response.json()
else:
st.error(
f"Error with {query_type} search: {response.status_code} {response.json()}"
)
except Exception as e:
st.error(f"Error with {query_type} search: {str(e)}")
def global_streaming_query(
self, index_name: str | list[str], query: str
) -> Response | None:
"""
Returns a streaming response object for a global query.
"""
url = f"{self.api_url}/experimental/query/global/streaming"
try:
query_response = requests.post(
url,
json={"index_name": index_name, "query": query},
headers=self.headers,
stream=True,
)
return query_response
except Exception as e:
print(f"Error: {str(e)}")
def get_source_entity(self, index_name: str, entity_id: str) -> dict | None:
try:
response = requests.get(
f"{self.api_url}/source/entity/{index_name}/{entity_id}",
headers=self.headers,
)
if response.status_code == 200:
return response.json()
else:
return response
except Exception as e:
print(f"Error: {str(e)}")
def generate_prompts(
self, storage_name: str, zip_file_name: str = "prompts.zip", limit: int = 1
) -> None:
"""
Generate graphrag prompts using data provided in a specific storage container.
"""
url = self.api_url + "/index/config/prompts"
params = {"storage_name": storage_name, "limit": limit}
with requests.get(url, params=params, headers=self.headers, stream=True) as r:
r.raise_for_status()
with open(zip_file_name, "wb") as f:
for chunk in r.iter_content():
f.write(chunk)
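A quick usage sketch of this client (the URL, key, and index name are placeholders):
```python
from src.graphrag_api import GraphragAPI

client = GraphragAPI("https://<my_apim>.azure-api.net", "<subscription_key>")
if client.health_check() == 200:
    print(client.get_index_names())  # list existing indexes
    status = client.check_index_status("<index_name>")
    if status is not None and status.status_code == 200:
        print(status.json())
```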

142
frontend/style.css Normal file

@@ -0,0 +1,142 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi8 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi5 > div > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn5 > div:nth-child(4) > div > div > div > div > div{
margin-top: 1.6em;
}
[data-testid="stHeadingDivider"] {
background-color: #3d9df3; /* Set your desired color */
height: 1px;
}
#microsoft-graphrag-copilot > div > span {
text-align: center;
margin-top: -1em;
}
/* Tooltip container */
.tooltip {
position: relative;
display: inline-block;
border-bottom: 1px dotted black; /* If you want dots under the hoverable text */
}
/* Tooltip text */
.tooltip .tooltiptext {
visibility: hidden;
width: 120px;
background-color: #555;
color: #fff;
text-align: center;
border-radius: 6px;
padding: 5px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
margin-left: -60px;
opacity: 0;
transition: opacity 0.3s;
}
/* Show the tooltip text when you hover over the tooltip container */
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
.gray-box {
background-color: #ffffff;
padding: 10px;
width: 80%;
}
.center-container {
margin-top: -10em;
display: flex;
align-items: center;
justify-content: center;
height: 100vh;
}
.footer {
display: flex;
justify-content: center;
align-items: center;
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #f1f1f1;
text-align: center;
padding: 5px;
z-index: 1000;
}
.footer p{
font-size: 12px;
}
/* CSS */
button[kind="primary"] {
background-color: #1d9445;
border: 0;
border-radius: 56px;
color: #fff;
cursor: pointer;
display: inline-block;
font-family: system-ui,-apple-system,system-ui,"Segoe UI",Roboto,Ubuntu,"Helvetica Neue",sans-serif;
font-size: 58px;
font-weight: 600;
outline: 0;
padding: 16px 21px;
position: relative;
text-align: center;
text-decoration: none;
transition: all .3s;
user-select: none;
-webkit-user-select: none;
touch-action: manipulation;
}
button[kind="primary"]:before {
background-color: initial;
background-image: linear-gradient(#fff 0, rgba(255, 255, 255, 0) 100%);
border-radius: 125px;
content: "";
height: 50%;
left: 4%;
opacity: .5;
position: absolute;
top: 0;
transition: all .3s;
width: 62%;
}
button[kind="primary"]:hover {
box-shadow: rgba(255, 255, 255, .2) 0 3px 15px inset, rgba(0, 0, 0, .1) 0 3px 5px, rgba(0, 0, 0, .1) 0 10px 13px;
transform: scale(1.05);
}
@media (min-width: 768px) {
button[kind="primary"] {
padding: 15px 34px;
margin: 20px auto;
}
}
.element-container:has(>.stTextArea), .stTextArea {
display: block;
margin-left: auto;
margin-right: auto;
}
.stTextArea textarea {
height: 500px;
/*background-color: #a7b0a4;*/
}