# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
This module contains functions that are used across the Streamlit app.
"""

import os
from pathlib import Path
from typing import Optional
from zipfile import ZipFile

import streamlit as st
from dotenv import find_dotenv, load_dotenv

from src.enums import EnvVars, PromptKeys, StorageIndexVars
from src.graphrag_api import GraphragAPI


def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
    """
    Initialize the Streamlit app with the necessary configurations.

    Sets the page configuration, injects the custom CSS, seeds the session
    state, loads the dotenv file, and resolves the APIM subscription key and
    deployment URL. For both settings an environment variable takes precedence
    over the value already held in the session state.

    Parameters:
    - env_file: str, name of the dotenv file to search for
    - css_file: str, path of the stylesheet injected into the page

    Returns:
    - True when both the subscription key and the deployment URL are set (the
      request headers are then stored in the session state), False otherwise.
    """
    # set page configuration
    st.set_page_config(initial_sidebar_state="expanded", layout="wide")

    # set custom CSS: the stylesheet content must be wrapped in a <style> tag,
    # otherwise nothing from the file ever reaches the page
    with open(css_file, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # initialize session state variables
    set_session_state_variables()

    # load environment variables; find_dotenv returns "" when nothing is found,
    # in which case None lets load_dotenv fall back to its own lookup
    _ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)

    # either load from .env file or from session state
    subscription_key_name = EnvVars.APIM_SUBSCRIPTION_KEY.value
    deployment_url_name = EnvVars.DEPLOYMENT_URL.value
    st.session_state[subscription_key_name] = os.getenv(
        subscription_key_name, st.session_state[subscription_key_name]
    )
    st.session_state[deployment_url_name] = os.getenv(
        deployment_url_name, st.session_state[deployment_url_name]
    )

    subscription_key = st.session_state[subscription_key_name]
    if not (subscription_key and st.session_state[deployment_url_name]):
        return False

    st.session_state["headers"] = {
        "Ocp-Apim-Subscription-Key": subscription_key,
        "Content-Type": "application/json",
    }
    # upload requests carry no JSON content type
    st.session_state["headers_upload"] = {
        "Ocp-Apim-Subscription-Key": subscription_key,
    }
    return True


def set_session_state_variables() -> None:
    """
    Initializes most session state variables for the app.

    Every value of PromptKeys, StorageIndexVars and EnvVars is registered with
    an empty string, and the "saved_prompts", "initialized" and "new_upload"
    flags with False. Keys that already exist are left untouched.
    """
    for enum_cls in (PromptKeys, StorageIndexVars, EnvVars):
        for key in enum_cls:
            if key.value not in st.session_state:
                st.session_state[key.value] = ""

    for flag in ("saved_prompts", "initialized", "new_upload"):
        if flag not in st.session_state:
            st.session_state[flag] = False


def update_session_state_prompt_vars(
    entity_extract: Optional[str] = None,
    summarize: Optional[str] = None,
    community: Optional[str] = None,
    initial_setting: bool = False,
    prompt_dir: str = "./prompts",
) -> None:
    """
    Updates the session state variables for the LLM prompts.

    Parameters:
    - entity_extract: Optional[str], new entity extraction prompt
    - summarize: Optional[str], new summarization prompt
    - community: Optional[str], new community report prompt
    - initial_setting: bool, when True the three prompts are read from
      prompt_dir and replace whatever was passed in
    - prompt_dir: str, directory holding the prompt text files

    Empty or None prompts leave the corresponding session value unchanged.
    """
    if initial_setting:
        entity_extract, summarize, community = get_prompts(prompt_dir)
    if entity_extract:
        st.session_state[PromptKeys.ENTITY.value] = entity_extract
    if summarize:
        st.session_state[PromptKeys.SUMMARY.value] = summarize
    if community:
        st.session_state[PromptKeys.COMMUNITY.value] = community


def generate_and_extract_prompts(
    client: GraphragAPI,
    storage_name: str,
    zip_file_name: str = "prompts.zip",
    limit: int = 5,
) -> None | Exception:
    """
    Makes API call to generate LLM prompts, extracts prompts from zip file,
    and updates the prompt session state variables.

    Parameters:
    - client: GraphragAPI, client used to request prompt generation
    - storage_name: str, forwarded to client.generate_prompts
    - zip_file_name: str, name of the zip file holding the generated prompts
    - limit: int, forwarded to client.generate_prompts

    Returns:
    - None on success; on failure the caught exception is returned (not
      raised) so the caller can decide how to report it.
    """
    try:
        client.generate_prompts(
            storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
        )
        _extract_prompts_from_zip(zip_file_name)
        update_session_state_prompt_vars(initial_setting=True)
        return None
    except Exception as e:
        # deliberate: the error is handed back as a value, see Returns above
        return e


def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip") -> None:
    """
    Extracts every member of the zip file into the current working directory.
    """
    with ZipFile(zip_file_name, "r") as zip_ref:
        zip_ref.extractall()


def open_file(file_path: str | Path) -> str:
    """
    Returns the full content of a UTF-8 encoded text file.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    return text


def zip_directory(directory_path: str, zip_path: str) -> None:
    """
    Zips all contents of a directory into a single zip file.

    Archive members are stored under a top-level folder named after the
    directory itself.

    Parameters:
    - directory_path: str, the path of the directory to zip
    - zip_path: str, the path where the zip file will be created
    """
    # strip every trailing separator the platform accepts, so that a path
    # ending in "\" on Windows does not produce an empty root folder name
    separators = os.sep + (os.altsep or "")
    root_dir_name = os.path.basename(directory_path.rstrip(separators))
    with ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(directory_path):
            for file in files:
                file_path = os.path.join(root, file)
                relpath = os.path.relpath(file_path, start=directory_path)
                arcname = os.path.join(root_dir_name, relpath)
                zipf.write(file_path, arcname)


def _read_first_prompt(prompt_paths: list[Path], prefix: str) -> str:
    """
    Returns the text of the first prompt file whose name starts with prefix.

    Raises:
    - IndexError: when no file name starts with prefix
    """
    match = next(
        (path for path in prompt_paths if path.name.startswith(prefix)), None
    )
    if match is None:
        raise IndexError(f"no prompt file starting with '{prefix}' was found")
    return open_file(match)


def get_prompts(prompt_dir: str = "./prompts") -> tuple[str, str, str]:
    """
    Extract text from generated prompts. Assumes file names comply with
    pregenerated file name standards.

    Parameters:
    - prompt_dir: str, directory containing the ".txt" prompt files

    Returns:
    - tuple of (entity extraction, summarization, community report) prompt
      texts, taken from the files whose names start with "entity", "summ"
      and "community" respectively.

    Raises:
    - IndexError: when one of the three prompt files is missing
    """
    # sorted so the choice is deterministic when several files share a prefix
    prompt_paths = sorted(
        prompt
        for prompt in Path(prompt_dir).iterdir()
        if prompt.name.endswith(".txt")
    )
    entity_ext_prompt = _read_first_prompt(prompt_paths, "entity")
    summ_prompt = _read_first_prompt(prompt_paths, "summ")
    comm_report_prompt = _read_first_prompt(prompt_paths, "community")
    return entity_ext_prompt, summ_prompt, comm_report_prompt