Josh Bradley 680cfc055e
Update frontend UI app (#174)
Co-authored-by: dorbaker <dorbaker@microsoft.com>
2024-09-19 01:09:26 -04:00

172 lines
5.5 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Optional
from zipfile import ZipFile
import streamlit as st
from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI
"""
This module contains functions that are used across the Streamlit app.
"""
def initialize_app(css_file: str = "style.css") -> bool:
    """
    Initialize the Streamlit app with the necessary configurations.

    Sets the page layout, injects the custom CSS read from ``css_file``,
    seeds the session state variables, and loads the APIM subscription key
    and deployment URL from environment variables (falling back to whatever
    is already stored in session state when a variable is unset). When both
    values are available, request headers are stored in
    ``st.session_state["headers"]`` (JSON requests) and
    ``st.session_state["headers_upload"]`` (uploads).

    Parameters:
    - css_file: str, path of the stylesheet injected into the page

    Returns:
    - True if both the subscription key and the deployment URL are set,
      otherwise False.
    """
    # set page configuration
    st.set_page_config(initial_sidebar_state="expanded", layout="wide")
    # set custom CSS; decode explicitly as UTF-8 so the result does not
    # depend on the platform's default locale encoding
    with open(css_file, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    # initialize session state variables
    set_session_state_variables()
    # load settings from environment variables, keeping the existing
    # session value when the environment variable is not set
    for env_var in (EnvVars.APIM_SUBSCRIPTION_KEY, EnvVars.DEPLOYMENT_URL):
        st.session_state[env_var.value] = os.getenv(
            env_var.value, st.session_state[env_var.value]
        )
    subscription_key = st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
    if not (subscription_key and st.session_state[EnvVars.DEPLOYMENT_URL.value]):
        return False
    upload_headers = {"Ocp-Apim-Subscription-Key": subscription_key}
    # JSON requests share the subscription header and add a content type
    st.session_state["headers"] = {
        **upload_headers,
        "Content-Type": "application/json",
    }
    st.session_state["headers_upload"] = upload_headers
    return True
def set_session_state_variables() -> None:
    """
    Initializes most session state variables for the app.

    Every member value of the PromptKeys, StorageIndexVars and EnvVars enums
    is seeded with an empty string. The "saved_prompts", "initialized" and
    "new_upload" flags are seeded with False. Keys already present in the
    session state are left untouched.
    """
    for enum_cls in (PromptKeys, StorageIndexVars, EnvVars):
        for member in enum_cls:
            if member.value not in st.session_state:
                st.session_state[member.value] = ""
    for flag_name in ("saved_prompts", "initialized", "new_upload"):
        if flag_name not in st.session_state:
            st.session_state[flag_name] = False
def update_session_state_prompt_vars(
    entity_extract: Optional[str] = None,
    summarize: Optional[str] = None,
    community: Optional[str] = None,
    initial_setting: bool = False,
    prompt_dir: str = "./prompts",
) -> None:
    """
    Updates the session state variables for the LLM prompts.

    When ``initial_setting`` is True, the three prompt texts are loaded from
    ``prompt_dir`` with ``get_prompts`` and the prompt arguments are ignored.
    Only truthy prompt texts are written; None or empty values leave the
    corresponding session state entry unchanged.
    """
    if initial_setting:
        entity_extract, summarize, community = get_prompts(prompt_dir)
    pending = (
        (PromptKeys.ENTITY, entity_extract),
        (PromptKeys.SUMMARY, summarize),
        (PromptKeys.COMMUNITY, community),
    )
    for prompt_key, prompt_text in pending:
        if prompt_text:
            st.session_state[prompt_key.value] = prompt_text
def generate_and_extract_prompts(
    client: GraphragAPI,
    storage_name: str,
    zip_file_name: str = "prompts.zip",
    limit: int = 5,
) -> None | Exception:
    """
    Generate LLM prompts through the API and load them into session state.

    Asks ``client`` to generate prompts for ``storage_name`` into
    ``zip_file_name`` and extracts that archive into the current working
    directory. It then refreshes the prompt session state variables from the
    default prompt directory.

    Returns:
    - None on success, otherwise the exception raised by any of the steps
      (it is returned, not re-raised).
    """
    try:
        client.generate_prompts(
            storage_name=storage_name,
            zip_file_name=zip_file_name,
            limit=limit,
        )
        _extract_prompts_from_zip(zip_file_name)
        update_session_state_prompt_vars(initial_setting=True)
    except Exception as error:
        return error
    return None
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
with ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall()
def open_file(file_path: str | Path):
with open(file_path, "r", encoding="utf-8") as file:
text = file.read()
return text
def zip_directory(directory_path: str, zip_path: str):
"""
Zips all contents of a directory into a single zip file.
Parameters:
- directory_path: str, the path of the directory to zip
- zip_path: str, the path where the zip file will be created
"""
root_dir_name = os.path.basename(directory_path.rstrip("/"))
with ZipFile(zip_path, "w") as zipf:
for root, _, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
relpath = os.path.relpath(file_path, start=directory_path)
arcname = os.path.join(root_dir_name, relpath)
zipf.write(file_path, arcname)
def get_prompts(prompt_dir: str = "./prompts") -> tuple[str, str, str]:
    """
    Extract text from generated prompts. Assumes file names comply with pregenerated file name standards.

    Looks for regular ``.txt`` files in ``prompt_dir`` whose names start with
    "entity", "summ" and "community" respectively. If several files share a
    prefix, the first one in sorted path order is used.

    Parameters:
    - prompt_dir: str, directory containing the prompt text files

    Returns:
    - (entity extraction prompt, summarization prompt, community report prompt)

    Raises:
    - IndexError if no matching file exists for one of the prefixes.
    """
    # sorted so the selection is deterministic: iterdir() order is arbitrary
    prompt_paths = sorted(
        path
        for path in Path(prompt_dir).iterdir()
        if path.name.endswith(".txt") and path.is_file()
    )

    def _read_first(prefix: str) -> str:
        # read only the selected file instead of every file with the prefix
        for path in prompt_paths:
            if path.name.startswith(prefix):
                return open_file(path)
        raise IndexError(
            f"no prompt file starting with '{prefix}' found in {prompt_dir}"
        )

    entity_ext_prompt = _read_first("entity")
    summ_prompt = _read_first("summ")
    comm_report_prompt = _read_first("community")
    return entity_ext_prompt, summ_prompt, comm_report_prompt