176 lines
5.6 KiB
Python
Raw Normal View History

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Optional
from zipfile import ZipFile
import streamlit as st
from dotenv import find_dotenv, load_dotenv
from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI
"""
This module contains functions that are used across the Streamlit app.
"""
def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
    """
    Initialize the Streamlit app with the necessary configurations.

    Sets the page layout, injects the custom CSS, seeds session state, loads
    environment variables from ``env_file`` (when one is found), and builds the
    request headers used for API calls.

    Parameters:
    - env_file: name of the dotenv file to search for
    - css_file: path of the stylesheet injected into the page

    Returns:
    - True when both the APIM subscription key and the deployment URL are
      available (and the ``headers`` / ``headers_upload`` session entries were
      set); False otherwise.
    """
    # set page configuration
    st.set_page_config(initial_sidebar_state="expanded", layout="wide")
    # Read the stylesheet as UTF-8 explicitly: the default text encoding is
    # platform-dependent (e.g. cp1252 on Windows) and can fail on non-ASCII CSS.
    with open(css_file, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    # initialize session state variables
    set_session_state_variables()
    # load environment variables
    _ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)

    key_name = EnvVars.APIM_SUBSCRIPTION_KEY.value
    url_name = EnvVars.DEPLOYMENT_URL.value
    # An environment value (possibly just loaded from the .env file) wins;
    # otherwise keep whatever is already stored in session state.
    st.session_state[key_name] = os.getenv(key_name, st.session_state[key_name])
    st.session_state[url_name] = os.getenv(url_name, st.session_state[url_name])

    if not (st.session_state[key_name] and st.session_state[url_name]):
        return False

    subscription_key = st.session_state[key_name]
    st.session_state["headers"] = {
        "Ocp-Apim-Subscription-Key": subscription_key,
        "Content-Type": "application/json",
    }
    # Uploads omit Content-Type so the HTTP client can set the multipart boundary.
    st.session_state["headers_upload"] = {
        "Ocp-Apim-Subscription-Key": subscription_key
    }
    return True
def set_session_state_variables() -> None:
    """
    Initializes most session state variables for the app.

    Every member value of ``PromptKeys``, ``StorageIndexVars`` and ``EnvVars``
    defaults to an empty string. The ``saved_prompts``, ``initialized`` and
    ``new_upload`` flags default to False. Keys already present in session
    state are never overwritten.
    """
    # Build one ordered map of defaults; the first occurrence of a name wins,
    # matching the original sequential "only if missing" checks.
    defaults = {}
    for enum_cls in (PromptKeys, StorageIndexVars, EnvVars):
        for member in enum_cls:
            defaults.setdefault(member.value, "")
    for flag_name in ("saved_prompts", "initialized", "new_upload"):
        defaults.setdefault(flag_name, False)

    for name, default_value in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = default_value
def update_session_state_prompt_vars(
    entity_extract: Optional[str] = None,
    summarize: Optional[str] = None,
    community: Optional[str] = None,
    initial_setting: bool = False,
    prompt_dir: str = "./prompts",
) -> None:
    """
    Updates the session state variables for the LLM prompts.

    When ``initial_setting`` is True, the three prompt texts are loaded from
    ``prompt_dir`` via ``get_prompts``, replacing any texts passed in. Only
    truthy texts are written. None or an empty string leaves the
    corresponding session entry unchanged.
    """
    if initial_setting:
        entity_extract, summarize, community = get_prompts(prompt_dir)

    pending = (
        (PromptKeys.ENTITY, entity_extract),
        (PromptKeys.SUMMARY, summarize),
        (PromptKeys.COMMUNITY, community),
    )
    for prompt_key, prompt_text in pending:
        if prompt_text:
            st.session_state[prompt_key.value] = prompt_text
def generate_and_extract_prompts(
    client: GraphragAPI,
    storage_name: str,
    zip_file_name: str = "prompts.zip",
    limit: int = 5,
) -> None | Exception:
    """
    Generate LLM prompts through the API, unpack them, and load them into state.

    Steps:
    1. ``client.generate_prompts`` writes ``zip_file_name``.
    2. The archive is extracted into the current working directory.
    3. The prompt session-state variables are refreshed from the default
       prompt directory.

    Returns None on success. Any exception raised along the way is caught and
    returned rather than raised, so callers must inspect the return value.
    """
    try:
        client.generate_prompts(
            storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
        )
        _extract_prompts_from_zip(zip_file_name)
        update_session_state_prompt_vars(initial_setting=True)
    except Exception as err:
        return err
    return None
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
with ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall()
def open_file(file_path: str | Path):
with open(file_path, "r", encoding="utf-8") as file:
text = file.read()
return text
def zip_directory(directory_path: str, zip_path: str):
"""
Zips all contents of a directory into a single zip file.
Parameters:
- directory_path: str, the path of the directory to zip
- zip_path: str, the path where the zip file will be created
"""
root_dir_name = os.path.basename(directory_path.rstrip("/"))
with ZipFile(zip_path, "w") as zipf:
for root, _, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
relpath = os.path.relpath(file_path, start=directory_path)
arcname = os.path.join(root_dir_name, relpath)
zipf.write(file_path, arcname)
def get_prompts(prompt_dir: str = "./prompts"):
    """
    Extract text from generated prompts. Assumes file names comply with pregenerated file name standards.

    Among the ``*.txt`` files directly inside ``prompt_dir``, the first one
    (in sorted order) whose name starts with ``entity``, ``summ`` and
    ``community`` respectively is read as UTF-8.

    Returns:
    - a tuple ``(entity_extraction_prompt, summarize_prompt, community_report_prompt)``

    Raises:
    - FileNotFoundError: if ``prompt_dir`` does not exist, or if no file
      matches one of the required prefixes.
    """
    # Sorting makes the choice deterministic when several files share a
    # prefix; iterdir() order is filesystem-dependent.
    prompt_paths = sorted(
        path
        for path in Path(prompt_dir).iterdir()
        if path.name.endswith(".txt") and path.is_file()
    )

    def _read_first(prefix: str) -> str:
        # Read only the selected file instead of every file with the prefix.
        for path in prompt_paths:
            if path.name.startswith(prefix):
                return path.read_text(encoding="utf-8")
        raise FileNotFoundError(
            f"No '{prefix}*.txt' prompt file found in {str(prompt_dir)!r}"
        )

    entity_ext_prompt = _read_first("entity")
    summ_prompt = _read_first("summ")
    comm_report_prompt = _read_first("community")
    return entity_ext_prompt, summ_prompt, comm_report_prompt