mirror of
				https://github.com/Azure-Samples/graphrag-accelerator.git
				synced 2025-11-03 20:19:53 +00:00 
			
		
		
		
	Add frontend application (#68)
Co-authored-by: americanthinker <americanthinker@gmail.com> Co-authored-by: Tim <timothymeyers@users.noreply.github.com> Co-authored-by: Christine Caggiano <cdifonzo@microsoft.com>
This commit is contained in:
		
							parent
							
								
									5dd5060d32
								
							
						
					
					
						commit
						0abbfb2a5f
					
				
							
								
								
									
										13
									
								
								.github/workflows/dev.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										13
									
								
								.github/workflows/dev.yaml
									
									
									
									
										vendored
									
									
								
							@ -52,3 +52,16 @@ jobs:
 | 
			
		||||
          context: .
 | 
			
		||||
          file: docker/Dockerfile-backend
 | 
			
		||||
          push: false
 | 
			
		||||
  build-frontend:
 | 
			
		||||
    needs: [lint-check]
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    if: ${{ !github.event.pull_request.draft }}
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Checkout repository
 | 
			
		||||
        uses: actions/checkout@v4
 | 
			
		||||
      - name: Build docker image
 | 
			
		||||
        uses: docker/build-push-action@v2
 | 
			
		||||
        with:
 | 
			
		||||
          context: .
 | 
			
		||||
          file: docker/Dockerfile-frontend
 | 
			
		||||
          push: false
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -166,4 +166,4 @@ main.parameters.json
 | 
			
		||||
**/charts/*.tgz
 | 
			
		||||
**/Chart.lock
 | 
			
		||||
 | 
			
		||||
.history
 | 
			
		||||
.history
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										18
									
								
								docker/Dockerfile-frontend
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								docker/Dockerfile-frontend
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,18 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
FROM python:3.10
 | 
			
		||||
 | 
			
		||||
ENV PIP_ROOT_USER_ACTION=ignore
 | 
			
		||||
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
 | 
			
		||||
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
 | 
			
		||||
 | 
			
		||||
COPY poetry.lock pyproject.toml /
 | 
			
		||||
COPY frontend /frontend
 | 
			
		||||
RUN pip install poetry \
 | 
			
		||||
    && poetry config virtualenvs.create false \
 | 
			
		||||
    && poetry install --without backend
 | 
			
		||||
 | 
			
		||||
WORKDIR /frontend
 | 
			
		||||
EXPOSE 8080
 | 
			
		||||
CMD ["streamlit", "run", "app.py", "--server.port", "8080"]
 | 
			
		||||
							
								
								
									
										5
									
								
								frontend/.streamlit/config.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								frontend/.streamlit/config.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
[server]
 | 
			
		||||
enableXsrfProtection = false
 | 
			
		||||
							
								
								
									
										26
									
								
								frontend/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								frontend/README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
			
		||||
# Frontend Application Launch Instructions
 | 
			
		||||
A small frontend application, a streamlit app, is provided to demonstrate how to build a UI on top of the solution accelerator API.
 | 
			
		||||
 | 
			
		||||
### 1. Deploy the GraphRAG solution accelerator
 | 
			
		||||
Follow instructions from the [deployment guide](../docs/DEPLOYMENT-GUIDE.md) to deploy a full instance of the solution accelerator.
 | 
			
		||||
 | 
			
		||||
### 2. (optional) Create a `.env` file:
 | 
			
		||||
 | 
			
		||||
| Variable Name | Required | Example | Description |
 | 
			
		||||
| :--- | --- | :--- | ---: |
 | 
			
		||||
DEPLOYMENT_URL        | No | https://<my_apim>.azure-api.net | Base url of the deployed graphrag API. Also referred to as the APIM Gateway URL.
 | 
			
		||||
APIM_SUBSCRIPTION_KEY | No | <subscription_key> | A [subscription key](https://learn.microsoft.com/en-us/azure/api-management/api-management-subscriptions) generated by APIM.
 | 
			
		||||
DEPLOYER_EMAIL        | No | deployer@email.com | Email address of the person/organization that deployed the solution accelerator.
 | 
			
		||||
 | 
			
		||||
### 3. Start UI
 | 
			
		||||
 | 
			
		||||
The frontend application can be run locally as a docker container. If a `.env` file is not provided, the UI will prompt the user for additional information.
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
# cd to the root directory of the repo
 | 
			
		||||
> docker build -t graphrag:frontend -f docker/Dockerfile-frontend .
 | 
			
		||||
> docker run --env-file <env_file> -p 8080:8080 graphrag:frontend
 | 
			
		||||
```
 | 
			
		||||
To access the app , visit `localhost:8080` in your browser.
 | 
			
		||||
 | 
			
		||||
This UI application can also be hosted in Azure as a [Web App](https://azure.microsoft.com/en-us/products/app-service/web).
 | 
			
		||||
							
								
								
									
										63
									
								
								frontend/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								frontend/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,63 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
from src.components import tabs
 | 
			
		||||
from src.components.index_pipeline import IndexPipeline
 | 
			
		||||
from src.enums import EnvVars
 | 
			
		||||
from src.functions import initialize_app
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
# Load environment variables
 | 
			
		||||
initialized = initialize_app()
 | 
			
		||||
st.session_state["initialized"] = True if initialized else False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def graphrag_app(initialized: bool):
 | 
			
		||||
    # main entry point for app interface
 | 
			
		||||
    st.title("Microsoft GraphRAG Copilot")
 | 
			
		||||
    main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs(
 | 
			
		||||
        [
 | 
			
		||||
            "**Intro**",
 | 
			
		||||
            "**1. Prompt Generation**",
 | 
			
		||||
            "**2. Prompt Configuration**",
 | 
			
		||||
            "**3. Index**",
 | 
			
		||||
            "**4. Query**",
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
    with main_tab:
 | 
			
		||||
        tabs.get_main_tab(initialized)
 | 
			
		||||
 | 
			
		||||
    # if not initialized, only main tab is displayed
 | 
			
		||||
    if initialized:
 | 
			
		||||
        # assign API request information
 | 
			
		||||
        COLUMN_WIDTHS = [0.275, 0.45, 0.275]
 | 
			
		||||
        api_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
 | 
			
		||||
        apim_key = st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
 | 
			
		||||
        client = GraphragAPI(api_url, apim_key)
 | 
			
		||||
        indexPipe = IndexPipeline(client, COLUMN_WIDTHS)
 | 
			
		||||
 | 
			
		||||
        # display tabs
 | 
			
		||||
        with prompt_gen_tab:
 | 
			
		||||
            tabs.get_prompt_generation_tab(client, COLUMN_WIDTHS)
 | 
			
		||||
        with prompt_edit_tab:
 | 
			
		||||
            tabs.get_prompt_configuration_tab()
 | 
			
		||||
        with index_tab:
 | 
			
		||||
            tabs.get_index_tab(indexPipe)
 | 
			
		||||
        with query_tab:
 | 
			
		||||
            tabs.get_query_tab(client)
 | 
			
		||||
 | 
			
		||||
    deployer_email = os.getenv("DEPLOYER_EMAIL", "deployer@email.com")
 | 
			
		||||
 | 
			
		||||
    footer = f"""
 | 
			
		||||
        <div class="footer">
 | 
			
		||||
            <p> Responses may be inaccurate; please review all responses for accuracy. Learn more about Azure OpenAI code of conduct <a href="https://learn.microsoft.com/en-us/legal/cognitive-services/openai/code-of-conduct"> here</a>. </br> For feedback, email us at <a href="mailto:{deployer_email}">{deployer_email}</a>.</p>
 | 
			
		||||
        </div>
 | 
			
		||||
    """
 | 
			
		||||
    st.markdown(footer, unsafe_allow_html=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    graphrag_app(st.session_state["initialized"])
 | 
			
		||||
							
								
								
									
										0
									
								
								frontend/src/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								frontend/src/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								frontend/src/components/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								frontend/src/components/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										206
									
								
								frontend/src/components/index_pipeline.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								frontend/src/components/index_pipeline.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,206 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
from io import StringIO
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.components.upload_files_component import upload_files
 | 
			
		||||
from src.enums import PromptKeys
 | 
			
		||||
from src.functions import GraphragAPI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IndexPipeline:
 | 
			
		||||
    def __init__(self, client: GraphragAPI, column_widths: list[float]) -> None:
 | 
			
		||||
        self.client = client
 | 
			
		||||
        self.containers = client.get_storage_container_names()
 | 
			
		||||
        self.column_widths = column_widths
 | 
			
		||||
 | 
			
		||||
    def storage_data_step(self):
 | 
			
		||||
        """
 | 
			
		||||
        Builds the Storage Data Step for the Indexing Pipeline.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        disable_other_input = False
 | 
			
		||||
        _, col2, _ = st.columns(self.column_widths)
 | 
			
		||||
 | 
			
		||||
        with col2:
 | 
			
		||||
            st.header(
 | 
			
		||||
                "1. Data Storage",
 | 
			
		||||
                divider=True,
 | 
			
		||||
                help="Select a Data Storage Container to upload data to or select an existing container to use for indexing. The data will be processed by the LLM to create a Knowledge Graph.",
 | 
			
		||||
            )
 | 
			
		||||
            select_storage_name = st.selectbox(
 | 
			
		||||
                label="Select an existing Storage Container.",
 | 
			
		||||
                options=[""] + self.containers
 | 
			
		||||
                if isinstance(self.containers, list)
 | 
			
		||||
                else [],
 | 
			
		||||
                key="index-storage",
 | 
			
		||||
                index=0,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if select_storage_name != "":
 | 
			
		||||
                disable_other_input = True
 | 
			
		||||
            st.write("Or...")
 | 
			
		||||
            with st.expander("Upload data to a storage container."):
 | 
			
		||||
                # TODO: validate storage container name before uploading
 | 
			
		||||
                # TODO: add user message that option not available while existing storage container is selected
 | 
			
		||||
                upload_files(
 | 
			
		||||
                    self.client,
 | 
			
		||||
                    key_prefix="index",
 | 
			
		||||
                    disable_other_input=disable_other_input,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if select_storage_name != "":
 | 
			
		||||
                    disable_other_input = True
 | 
			
		||||
 | 
			
		||||
    def build_index_step(self):
 | 
			
		||||
        """
 | 
			
		||||
        Creates the Build Index Step for the Indexing Pipeline.
 | 
			
		||||
        """
 | 
			
		||||
        _, col2, _ = st.columns(self.column_widths)
 | 
			
		||||
        with col2:
 | 
			
		||||
            st.header(
 | 
			
		||||
                "2. Build Index",
 | 
			
		||||
                divider=True,
 | 
			
		||||
                help="Building an index will process the data from step 1 and create a Knowledge Graph suitable for querying. The LLM will use either the default prompt configuration or the prompts that you generated previously. To track the status of an indexing job, use the check index status below.",
 | 
			
		||||
            )
 | 
			
		||||
            # use data from either the selected storage container or the uploaded data
 | 
			
		||||
            select_storage_name = st.session_state["index-storage"]
 | 
			
		||||
            input_storage_name = (
 | 
			
		||||
                st.session_state["index-storage-name-input"]
 | 
			
		||||
                if st.session_state["index-upload-button"]
 | 
			
		||||
                else ""
 | 
			
		||||
            )
 | 
			
		||||
            storage_selection = select_storage_name or input_storage_name
 | 
			
		||||
 | 
			
		||||
            # Allow user to choose either default or custom prompts
 | 
			
		||||
            custom_prompts = any([st.session_state[k.value] for k in PromptKeys])
 | 
			
		||||
            prompt_options = ["Default", "Custom"] if custom_prompts else ["Default"]
 | 
			
		||||
            prompt_choice = st.radio(
 | 
			
		||||
                "Choose LLM Prompt Configuration",
 | 
			
		||||
                options=prompt_options,
 | 
			
		||||
                index=1 if custom_prompts else 0,
 | 
			
		||||
                key="prompt-config-choice",
 | 
			
		||||
                horizontal=True,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Create new index name
 | 
			
		||||
            index_name = st.text_input("Enter Index Name", key="index-name-input")
 | 
			
		||||
 | 
			
		||||
            st.write(f"Selected Storage Container: **:blue[{storage_selection}]**")
 | 
			
		||||
            if st.button(
 | 
			
		||||
                "Build Index",
 | 
			
		||||
                help="You must enter both an Index Name and Select a Storage Container to enable this button",
 | 
			
		||||
                disabled=not index_name or not storage_selection,
 | 
			
		||||
            ):
 | 
			
		||||
                entity_prompt = (
 | 
			
		||||
                    StringIO(st.session_state[PromptKeys.ENTITY.value])
 | 
			
		||||
                    if prompt_choice == "Custom"
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
                summarize_prompt = (
 | 
			
		||||
                    StringIO(st.session_state[PromptKeys.SUMMARY.value])
 | 
			
		||||
                    if prompt_choice == "Custom"
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
                community_prompt = (
 | 
			
		||||
                    StringIO(st.session_state[PromptKeys.COMMUNITY.value])
 | 
			
		||||
                    if prompt_choice == "Custom"
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                response = self.client.build_index(
 | 
			
		||||
                    storage_name=storage_selection,
 | 
			
		||||
                    index_name=index_name,
 | 
			
		||||
                    entity_extraction_prompt_filepath=entity_prompt,
 | 
			
		||||
                    summarize_description_prompt_filepath=summarize_prompt,
 | 
			
		||||
                    community_prompt_filepath=community_prompt,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if response.status_code == 200:
 | 
			
		||||
                    st.success(
 | 
			
		||||
                        f"Job submitted successfully, using {prompt_choice} prompts!"
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    st.error(
 | 
			
		||||
                        f"Failed to submit job.\nStatus: {response.json()['detail']}"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    def check_status_step(self):
 | 
			
		||||
        """
 | 
			
		||||
        Checks the progress of a running indexing job.
 | 
			
		||||
        """
 | 
			
		||||
        _, col2, _ = st.columns(self.column_widths)
 | 
			
		||||
        with col2:
 | 
			
		||||
            st.header(
 | 
			
		||||
                "3. Check Index Status",
 | 
			
		||||
                divider=True,
 | 
			
		||||
                help="Select an index to check the status of what stage indexing is in. Indexing must be complete in order to be able to execute queries.",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            options_indexes = self.client.get_index_names()
 | 
			
		||||
            # create logic for defaulting to running job index if one exists
 | 
			
		||||
            new_index_name = st.session_state["index-name-input"]
 | 
			
		||||
            default_index = (
 | 
			
		||||
                options_indexes.index(new_index_name)
 | 
			
		||||
                if new_index_name in options_indexes
 | 
			
		||||
                else 0
 | 
			
		||||
            )
 | 
			
		||||
            index_name_select = st.selectbox(
 | 
			
		||||
                label="Select an index to check its status.",
 | 
			
		||||
                options=options_indexes if any(options_indexes) else [],
 | 
			
		||||
                index=default_index,
 | 
			
		||||
            )
 | 
			
		||||
            progress_bar = st.progress(0, text="Index Job Progress")
 | 
			
		||||
            if st.button("Check Status"):
 | 
			
		||||
                status_response = self.client.check_index_status(index_name_select)
 | 
			
		||||
                if status_response.status_code == 200:
 | 
			
		||||
                    status_response_text = status_response.json()
 | 
			
		||||
                    if status_response_text["status"] != "":
 | 
			
		||||
                        try:
 | 
			
		||||
                            # build status message
 | 
			
		||||
                            job_status = status_response_text["status"]
 | 
			
		||||
                            status_message = f"Status: {status_response_text['status']}"
 | 
			
		||||
                            st.success(status_message) if job_status in [
 | 
			
		||||
                                "running",
 | 
			
		||||
                                "complete",
 | 
			
		||||
                            ] else st.warning(status_message)
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(e)
 | 
			
		||||
                        try:
 | 
			
		||||
                            # build percent complete message
 | 
			
		||||
                            percent_complete = status_response_text["percent_complete"]
 | 
			
		||||
                            progress_bar.progress(float(percent_complete) / 100)
 | 
			
		||||
                            completion_message = (
 | 
			
		||||
                                f"Percent Complete: {percent_complete}% "
 | 
			
		||||
                            )
 | 
			
		||||
                            st.warning(
 | 
			
		||||
                                completion_message
 | 
			
		||||
                            ) if percent_complete < 100 else st.success(
 | 
			
		||||
                                completion_message
 | 
			
		||||
                            )
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(e)
 | 
			
		||||
                        try:
 | 
			
		||||
                            # build progress message
 | 
			
		||||
                            progress_status = status_response_text["progress"]
 | 
			
		||||
                            progress_status = (
 | 
			
		||||
                                progress_status if progress_status else "N/A"
 | 
			
		||||
                            )
 | 
			
		||||
                            progress_message = f"Progress: {progress_status}"
 | 
			
		||||
                            st.success(
 | 
			
		||||
                                progress_message
 | 
			
		||||
                            ) if progress_status != "N/A" else st.warning(
 | 
			
		||||
                                progress_message
 | 
			
		||||
                            )
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            print(e)
 | 
			
		||||
                    else:
 | 
			
		||||
                        st.warning(
 | 
			
		||||
                            f"No status information available for this index: {index_name_select}"
 | 
			
		||||
                        )
 | 
			
		||||
                else:
 | 
			
		||||
                    st.warning(
 | 
			
		||||
                        f"No workflow information available for this index: {index_name_select}"
 | 
			
		||||
                    )
 | 
			
		||||
							
								
								
									
										39
									
								
								frontend/src/components/login_sidebar.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								frontend/src/components/login_sidebar.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,39 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.enums import EnvVars
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def login():
 | 
			
		||||
    """
 | 
			
		||||
    Login component that displays in the sidebar.  Requires the user to enter
 | 
			
		||||
    the APIM Gateway URL and Subscription Key to login.  After entering user
 | 
			
		||||
    credentials, a simple health check call is made to the GraphRAG API.
 | 
			
		||||
    """
 | 
			
		||||
    with st.sidebar:
 | 
			
		||||
        st.title(
 | 
			
		||||
            "Login",
 | 
			
		||||
            help="Enter your APIM credentials to get started.  Refreshing the browser will require you to login again.",
 | 
			
		||||
        )
 | 
			
		||||
        with st.form(key="login-form", clear_on_submit=True):
 | 
			
		||||
            apim_url = st.text_input("APIM Gateway URL", key="apim-url")
 | 
			
		||||
            apim_sub_key = st.text_input(
 | 
			
		||||
                "APIM Subscription Key", key="subscription-key"
 | 
			
		||||
            )
 | 
			
		||||
            form_submit = st.form_submit_button("Login")
 | 
			
		||||
            if form_submit:
 | 
			
		||||
                client = GraphragAPI(apim_url, apim_sub_key)
 | 
			
		||||
                status_code = client.health_check()
 | 
			
		||||
                if status_code == 200:
 | 
			
		||||
                    st.success("Login Successful")
 | 
			
		||||
                    st.session_state[EnvVars.DEPLOYMENT_URL.value] = apim_url
 | 
			
		||||
                    st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = apim_sub_key
 | 
			
		||||
                    st.session_state["initialized"] = True
 | 
			
		||||
                    st.rerun()
 | 
			
		||||
                else:
 | 
			
		||||
                    st.error("Login Failed")
 | 
			
		||||
                    st.error("Please check the APIM Gateway URL and Subscription Key")
 | 
			
		||||
                    return status_code
 | 
			
		||||
							
								
								
									
										89
									
								
								frontend/src/components/prompt_configuration.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								frontend/src/components/prompt_configuration.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,89 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.enums import PromptFileNames, PromptKeys, PromptTextAreas
 | 
			
		||||
from src.functions import zip_directory
 | 
			
		||||
 | 
			
		||||
SAVED_PROMPT_VAR = "saved_prompts"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_prompts(
 | 
			
		||||
    local_dir: str = "./edited_prompts/", zip_file_path: str = "edited_prompts.zip"
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Save prompts in memory and on disk as a zip file
 | 
			
		||||
    """
 | 
			
		||||
    st.session_state[SAVED_PROMPT_VAR] = True
 | 
			
		||||
    st.session_state[PromptKeys.ENTITY.value] = st.session_state[
 | 
			
		||||
        PromptTextAreas.ENTITY.value
 | 
			
		||||
    ]
 | 
			
		||||
    st.session_state[PromptKeys.SUMMARY.value] = st.session_state[
 | 
			
		||||
        PromptTextAreas.SUMMARY.value
 | 
			
		||||
    ]
 | 
			
		||||
    st.session_state[PromptKeys.COMMUNITY.value] = st.session_state[
 | 
			
		||||
        PromptTextAreas.COMMUNITY.value
 | 
			
		||||
    ]
 | 
			
		||||
    os.makedirs(local_dir, exist_ok=True)
 | 
			
		||||
    for key, filename in zip(PromptKeys, PromptFileNames):
 | 
			
		||||
        outpath = os.path.join(local_dir, filename.value)
 | 
			
		||||
        with open(outpath, "w", encoding="utf-8") as f:
 | 
			
		||||
            f.write(st.session_state[key.value])
 | 
			
		||||
    zip_directory(local_dir, zip_file_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def edit_prompts():
 | 
			
		||||
    """
 | 
			
		||||
    Re-edit the prompts
 | 
			
		||||
    """
 | 
			
		||||
    st.session_state[SAVED_PROMPT_VAR] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prompt_editor(prompt_values: list[str]):
 | 
			
		||||
    """
 | 
			
		||||
    Container for prompt configurations
 | 
			
		||||
    """
 | 
			
		||||
    saved_prompts = st.session_state[SAVED_PROMPT_VAR]
 | 
			
		||||
 | 
			
		||||
    entity_ext_prompt, summ_prompt, comm_report_prompt = prompt_values
 | 
			
		||||
 | 
			
		||||
    with st.container(border=True):
 | 
			
		||||
        tab_labels = [
 | 
			
		||||
            "**Entity Extraction**",
 | 
			
		||||
            "**Summarize Descriptions**",
 | 
			
		||||
            "**Community Reports**",
 | 
			
		||||
        ]
 | 
			
		||||
        # subheaders = [f"{tab_label} Prompt" for tab_label in tab_labels]
 | 
			
		||||
        tab1, tab2, tab3 = st.tabs(tabs=tab_labels)
 | 
			
		||||
        with tab1:
 | 
			
		||||
            st.text_area(
 | 
			
		||||
                label="Entity Prompt",
 | 
			
		||||
                value=entity_ext_prompt,
 | 
			
		||||
                max_chars=20000,
 | 
			
		||||
                key="entity_text_area",
 | 
			
		||||
                label_visibility="hidden",
 | 
			
		||||
                disabled=saved_prompts,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with tab2:
 | 
			
		||||
            st.text_area(
 | 
			
		||||
                label="Summarize Prompt",
 | 
			
		||||
                value=summ_prompt,
 | 
			
		||||
                max_chars=20000,
 | 
			
		||||
                key="summary_text_area",
 | 
			
		||||
                label_visibility="hidden",
 | 
			
		||||
                disabled=saved_prompts,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with tab3:
 | 
			
		||||
            st.text_area(
 | 
			
		||||
                label="Community Reports Prompt",
 | 
			
		||||
                value=comm_report_prompt,
 | 
			
		||||
                max_chars=20000,
 | 
			
		||||
                key="community_text_area",
 | 
			
		||||
                label_visibility="hidden",
 | 
			
		||||
                disabled=saved_prompts,
 | 
			
		||||
            )
 | 
			
		||||
							
								
								
									
										274
									
								
								frontend/src/components/query.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										274
									
								
								frontend/src/components/query.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,274 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
from typing import Literal
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pandas as pd
 | 
			
		||||
import requests
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GraphQuery:
 | 
			
		||||
    KILOBYTE = 1024
 | 
			
		||||
 | 
			
		||||
    def __init__(self, client: GraphragAPI):
 | 
			
		||||
        self.client = client
 | 
			
		||||
 | 
			
		||||
    def search(
 | 
			
		||||
        self,
 | 
			
		||||
        query_type: Literal["Global Streaming", "Global", "Local"],
 | 
			
		||||
        search_index: str | list[str],
 | 
			
		||||
        query: str,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        idler_message_list = [
 | 
			
		||||
            "Querying the graph...",
 | 
			
		||||
            "Processing the query...",
 | 
			
		||||
            "The graph is working hard...",
 | 
			
		||||
            "Fetching the results...",
 | 
			
		||||
            "Reticulating splines...",
 | 
			
		||||
            "Almost there...",
 | 
			
		||||
            "The report format is customizable, for this demo we report back in executive summary format. It's prompt driven to change as you like!",
 | 
			
		||||
            "Just a few more seconds...",
 | 
			
		||||
            "You probably know these messages are just for fun...",
 | 
			
		||||
            "In the meantime, here's a fun fact: Did you know that the Microsoft GraphRAG Copilot is built on top of the Microsoft GraphRAG Solution Accelerator?",
 | 
			
		||||
            "The average graph query processes several textbooks worth of information to get you your answer.  I hope it was a good question!",
 | 
			
		||||
            "Shamelessly buying time...",
 | 
			
		||||
            "When the answer comes, make sure to check the context reports, the detail there is incredible!",
 | 
			
		||||
            "When we ingest data into the graph, the structure of language itself is used to create the graph structure. It's like a language-based neural network, using neural networks to understand language to network. It's a network-ception!",
 | 
			
		||||
            "The answers will come eventually, I promise.  In the meantime, I recommend a doppio espresso, or a nice cup of tea.  Or both!  The GraphRAG team runs on caffeine.",
 | 
			
		||||
            "The graph is a complex structure, but it's working hard to get you the answer you need.",
 | 
			
		||||
            "GraphRAG is step one in a long journey of understanding the world through language.  It's a big step, but there's so much more to come.",
 | 
			
		||||
            "The results are on their way...",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        message = np.random.choice(idler_message_list)
 | 
			
		||||
        with st.spinner(text=message):
 | 
			
		||||
            try:
 | 
			
		||||
                match query_type:
 | 
			
		||||
                    case "Global Streaming":
 | 
			
		||||
                        _ = self.global_streaming_search(search_index, query)
 | 
			
		||||
                    case "Global":
 | 
			
		||||
                        _ = self.global_search(search_index, query)
 | 
			
		||||
                    case "Local":
 | 
			
		||||
                        _ = self.local_search(search_index, query)
 | 
			
		||||
 | 
			
		||||
            except requests.exceptions.RequestException as e:
 | 
			
		||||
                st.error(f"Error with query {query_type}: {str(e)}")
 | 
			
		||||
 | 
			
		||||
    def global_streaming_search(
 | 
			
		||||
        self, search_index: str | list[str], query: str
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Executes a global streaming query on the specified index.
 | 
			
		||||
        Handles the response and displays the results in the Streamlit app.
 | 
			
		||||
        """
 | 
			
		||||
        query_response = self.client.global_streaming_query(search_index, query)
 | 
			
		||||
        assistant_response = ""
 | 
			
		||||
        context_list = []
 | 
			
		||||
 | 
			
		||||
        if query_response.status_code == 200:
 | 
			
		||||
            text_placeholder = st.empty()
 | 
			
		||||
            for chunk in query_response.iter_lines(
 | 
			
		||||
                # allow up to 256KB to avoid excessive many reads
 | 
			
		||||
                chunk_size=256 * GraphQuery.KILOBYTE,
 | 
			
		||||
                decode_unicode=True,
 | 
			
		||||
            ):
 | 
			
		||||
                try:
 | 
			
		||||
                    payload = json.loads(chunk)
 | 
			
		||||
                except json.JSONDecodeError as e:
 | 
			
		||||
                    # In the event that a chunk is not a complete JSON object,
 | 
			
		||||
                    # document it for further analysis.
 | 
			
		||||
                    print(chunk)
 | 
			
		||||
                    raise e
 | 
			
		||||
 | 
			
		||||
                token = payload["token"]
 | 
			
		||||
                context = payload["context"]
 | 
			
		||||
                if (token != "<EOM>") and (context is None):
 | 
			
		||||
                    assistant_response += token
 | 
			
		||||
                    text_placeholder.write(assistant_response)
 | 
			
		||||
                elif (token == "<EOM>") and (context is not None):
 | 
			
		||||
                    context_list.append(context)
 | 
			
		||||
 | 
			
		||||
            if not assistant_response:
 | 
			
		||||
                st.write(
 | 
			
		||||
                    self.format_md_text(
 | 
			
		||||
                        "Not enough contextual data to support your query: No results found.\tTry another query.",
 | 
			
		||||
                        "red",
 | 
			
		||||
                        True,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
            else:
 | 
			
		||||
                with self._create_section_expander("Query Context"):
 | 
			
		||||
                    st.write(
 | 
			
		||||
                        self.format_md_text(
 | 
			
		||||
                            "Double-click on content to expand text", "red", False
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                    self._build_st_dataframe(context_list)
 | 
			
		||||
        else:
 | 
			
		||||
            print(query_response.reason, query_response.content)
 | 
			
		||||
            raise Exception("Received unexpected response from server")
 | 
			
		||||
 | 
			
		||||
    def global_search(self, search_index: str | list[str], query: str) -> None:
 | 
			
		||||
        query_response = self.client.query_index(
 | 
			
		||||
            index_name=search_index, query_type="Global", query=query
 | 
			
		||||
        )
 | 
			
		||||
        if query_response["result"] != "":
 | 
			
		||||
            with self._create_section_expander("Query Response", "black", True, True):
 | 
			
		||||
                st.write(query_response["result"])
 | 
			
		||||
            with self._create_section_expander("Query Context"):
 | 
			
		||||
                st.write(
 | 
			
		||||
                    self.format_md_text(
 | 
			
		||||
                        "Double-click on content to expand text", "red", False
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                self._build_st_dataframe(query_response["context_data"]["reports"])
 | 
			
		||||
 | 
			
		||||
    def local_search(self, search_index: str | list[str], query: str) -> None:
 | 
			
		||||
        query_response = self.client.query_index(
 | 
			
		||||
            index_name=search_index, query_type="Local", query=query
 | 
			
		||||
        )
 | 
			
		||||
        results = query_response["result"]
 | 
			
		||||
        if results != "":
 | 
			
		||||
            with self._create_section_expander("Query Response", "black", True, True):
 | 
			
		||||
                st.write(results)
 | 
			
		||||
 | 
			
		||||
        context_data = query_response["context_data"]
 | 
			
		||||
        reports = context_data["reports"]
 | 
			
		||||
        entities = context_data["entities"]
 | 
			
		||||
        relationships = context_data["relationships"]
 | 
			
		||||
        # sources = context_data["sources"]
 | 
			
		||||
 | 
			
		||||
        if any(reports):
 | 
			
		||||
            with self._create_section_expander("Query Context"):
 | 
			
		||||
                st.write(
 | 
			
		||||
                    self.format_md_text(
 | 
			
		||||
                        "Double-click on content to expand text", "red", False
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                self._build_st_dataframe(reports)
 | 
			
		||||
 | 
			
		||||
        if any(entities):
 | 
			
		||||
            with st.spinner("Loading context entities..."):
 | 
			
		||||
                with self._create_section_expander("Context Entities"):
 | 
			
		||||
                    df_entities = pd.DataFrame(entities)
 | 
			
		||||
                    self._build_st_dataframe(df_entities, entity_df=True)
 | 
			
		||||
 | 
			
		||||
                # TODO: Fix the next portion of code to provide a more granular entity view
 | 
			
		||||
                # for report in entities:
 | 
			
		||||
                #     entity_response = get_source_entity(
 | 
			
		||||
                #         report["index_name"], report["id"], self.api_url, self.headers
 | 
			
		||||
                #     )
 | 
			
		||||
                #     for unit in entity_response["text_units"]:
 | 
			
		||||
                #         response = requests.get(
 | 
			
		||||
                #             f"{self.api_url}/source/text/{report['index_name']}/{unit}",
 | 
			
		||||
                #             headers=self.headers,
 | 
			
		||||
                #         )
 | 
			
		||||
                #         text_info = response.json()
 | 
			
		||||
                #         if text_info is not None:
 | 
			
		||||
                #             with st.expander(
 | 
			
		||||
                #                 f" Entity: {report['entity']} - Source Document: {text_info['source_document']} "
 | 
			
		||||
                #             ):
 | 
			
		||||
                #                 st.write(text_info["text"])
 | 
			
		||||
 | 
			
		||||
        if any(relationships):
 | 
			
		||||
            with st.spinner("Loading context relationships..."):
 | 
			
		||||
                with self._create_section_expander("Context Relationships"):
 | 
			
		||||
                    df_relationships = pd.DataFrame(relationships)
 | 
			
		||||
                    self._build_st_dataframe(df_relationships, rel_df=True)
 | 
			
		||||
 | 
			
		||||
                    # TODO: Fix the next portion of code to provide a more granular relationship view
 | 
			
		||||
                    # for report in query_response["context_data"][
 | 
			
		||||
                    #     "relationships"
 | 
			
		||||
                    # ][:15]:
 | 
			
		||||
                    #     # with st.expander(
 | 
			
		||||
                    #     #     f"Source: {report['source']} Target: {report['target']} Rank: {report['rank']}"
 | 
			
		||||
                    #     # ):
 | 
			
		||||
                    #     # st.write(report["description"])
 | 
			
		||||
                    #     relationship_data = requests.get(
 | 
			
		||||
                    #         f"{self.api_url}/source/relationship/{report['index_name']}/{report['id']}",
 | 
			
		||||
                    #         headers=self.headers,
 | 
			
		||||
                    #     )
 | 
			
		||||
                    #     relationship_data = relationship_data.json()
 | 
			
		||||
                    #     for unit in relationship_data["text_units"]:
 | 
			
		||||
                    #         response = requests.get(
 | 
			
		||||
                    #             f"{self.api_url}/source/text/{report['index_name']}/{unit}",
 | 
			
		||||
                    #             headers=self.headers,
 | 
			
		||||
                    #         )
 | 
			
		||||
                    #         text_info_rel = response.json()
 | 
			
		||||
                    #         df_textinfo_rel = pd.DataFrame([text_info_rel])
 | 
			
		||||
                    #         with st.expander(
 | 
			
		||||
                    #             f"Source: {report['source']} Target: {report['target']} - Source Document: {sources['source_document']} "
 | 
			
		||||
                    #         ):
 | 
			
		||||
                    #             st.write(sources["text"])
 | 
			
		||||
                    #             st.dataframe(
 | 
			
		||||
                    #                 df_textinfo_rel, use_container_width=True
 | 
			
		||||
                    #             )
 | 
			
		||||
 | 
			
		||||
    def _build_st_dataframe(
 | 
			
		||||
        self,
 | 
			
		||||
        data: dict | pd.DataFrame,
 | 
			
		||||
        drop_columns: list[str] = ["id", "index_id", "index_name", "in_context"],
 | 
			
		||||
        entity_df: bool = False,
 | 
			
		||||
        rel_df: bool = False,
 | 
			
		||||
    ) -> st.dataframe:
 | 
			
		||||
        df_context = (
 | 
			
		||||
            data if isinstance(data, pd.DataFrame) else pd.DataFrame.from_records(data)
 | 
			
		||||
        )
 | 
			
		||||
        if any(drop_columns):
 | 
			
		||||
            for column in drop_columns:
 | 
			
		||||
                if column in df_context.columns:
 | 
			
		||||
                    df_context = df_context.drop(column, axis=1)
 | 
			
		||||
        if entity_df:
 | 
			
		||||
            return st.dataframe(
 | 
			
		||||
                df_context,
 | 
			
		||||
                use_container_width=True,
 | 
			
		||||
                column_config={
 | 
			
		||||
                    "entity": "Entity",
 | 
			
		||||
                    "description": "Description",
 | 
			
		||||
                    "number of relationships": "Number of Relationships",
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
        if rel_df:
 | 
			
		||||
            return st.dataframe(
 | 
			
		||||
                df_context,
 | 
			
		||||
                use_container_width=True,
 | 
			
		||||
                column_config={
 | 
			
		||||
                    "source": "Source",
 | 
			
		||||
                    "target": "Target",
 | 
			
		||||
                    "description": "Description",
 | 
			
		||||
                    "weight": "Weight",
 | 
			
		||||
                    "rank": "Rank",
 | 
			
		||||
                    "links": "Links",
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
        return st.dataframe(
 | 
			
		||||
            df_context,
 | 
			
		||||
            use_container_width=True,
 | 
			
		||||
            column_config={
 | 
			
		||||
                "title": "Report Title",
 | 
			
		||||
                "content": "Report Content",
 | 
			
		||||
                "rank": "Rank",
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def format_md_text(self, text: str, color: str, bold: bool) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Formats text for display in Streamlit app using Markdown syntax.
 | 
			
		||||
        """
 | 
			
		||||
        if bold:
 | 
			
		||||
            return f":{color}[**{text}**]"
 | 
			
		||||
        return f":{color}[{text}]"
 | 
			
		||||
 | 
			
		||||
    def _create_section_expander(
 | 
			
		||||
        self, title: str, color: str = "blue", bold: bool = True, expanded: bool = False
 | 
			
		||||
    ) -> st.expander:
 | 
			
		||||
        """
 | 
			
		||||
        Creates an expander in the Streamlit app with the specified title and content.
 | 
			
		||||
        """
 | 
			
		||||
        return st.expander(self.format_md_text(title, color, bold), expanded=expanded)
 | 
			
		||||
							
								
								
									
										275
									
								
								frontend/src/components/tabs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										275
									
								
								frontend/src/components/tabs.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,275 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from time import sleep
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.components.index_pipeline import IndexPipeline
 | 
			
		||||
from src.components.login_sidebar import login
 | 
			
		||||
from src.components.prompt_configuration import (
 | 
			
		||||
    edit_prompts,
 | 
			
		||||
    prompt_editor,
 | 
			
		||||
    save_prompts,
 | 
			
		||||
)
 | 
			
		||||
from src.components.query import GraphQuery
 | 
			
		||||
from src.components.upload_files_component import upload_files
 | 
			
		||||
from src.enums import PromptKeys
 | 
			
		||||
from src.functions import generate_and_extract_prompts
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_main_tab(initialized: bool) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Displays content of Main Tab
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    url = "https://github.com/Azure-Samples/graphrag-accelerator/blob/main/TRANSPARENCY.md"
 | 
			
		||||
    content = f"""
 | 
			
		||||
    ##  Welcome to GraphRAG!
 | 
			
		||||
    Diving into complex information and uncovering semantic relationships utilizing generative AI has never been easier.
 | 
			
		||||
    Here's how you can get started with just a few clicks:
 | 
			
		||||
    - **PROMPT GENERATION:** (*Optional Step*)
 | 
			
		||||
        1. Generate fine-tuned prompts for graphrag customized to your data and domain.
 | 
			
		||||
        2. Select an existing Storage Container and click "Generate Prompts".
 | 
			
		||||
    - **PROMPT CONFIGURATION:** (*Optional Step*)
 | 
			
		||||
        1. Edit the generated prompts to best suit your needs.
 | 
			
		||||
        2. Once you are finished editing, click the "Save Prompts" button.
 | 
			
		||||
        3. Saving the prompts will store them for use in the follow-on Indexing step.
 | 
			
		||||
        4. You can also download the edited prompts for future reference.
 | 
			
		||||
    - **INDEXING:**
 | 
			
		||||
        1. Select an existing data storage container or upload data, to Index
 | 
			
		||||
        2. Name your index and select "Build Index" to begin building a GraphRAG Index.
 | 
			
		||||
        3. Check the status of the index as the job progresses.
 | 
			
		||||
    - **QUERYING:**
 | 
			
		||||
        1. Choose an existing index
 | 
			
		||||
        2. Specify a query type
 | 
			
		||||
        3. Click "Query" button to search and view insights.
 | 
			
		||||
 | 
			
		||||
    [GraphRAG]({url}) combines the power of RAG with a Graph structure, giving you insights at your fingertips.
 | 
			
		||||
    """
 | 
			
		||||
    # Display text in the gray box
 | 
			
		||||
    st.markdown(content, unsafe_allow_html=False)
 | 
			
		||||
    if not initialized:
 | 
			
		||||
        login()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_prompt_generation_tab(
 | 
			
		||||
    client: GraphragAPI,
 | 
			
		||||
    column_widths: list[float],
 | 
			
		||||
    num_chunks: int = 5,
 | 
			
		||||
) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Displays content of Prompt Generation Tab
 | 
			
		||||
    """
 | 
			
		||||
    # hard set limit to 5 files to reduce overly long processing times and to reduce over sampling errors.
 | 
			
		||||
    num_chunks = num_chunks if num_chunks <= 5 else 5
 | 
			
		||||
    _, col2, _ = st.columns(column_widths)
 | 
			
		||||
    with col2:
 | 
			
		||||
        st.header(
 | 
			
		||||
            "Generate Prompts (optional)",
 | 
			
		||||
            divider=True,
 | 
			
		||||
            help="Generate fine-tuned prompts for graphrag tailored to your data and domain.",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        st.write(
 | 
			
		||||
            "Select a storage container that contains your data. GraphRAG will use this data to generate domain-specific prompts for follow-on indexing."
 | 
			
		||||
        )
 | 
			
		||||
        storage_containers = client.get_storage_container_names()
 | 
			
		||||
 | 
			
		||||
        # if no storage containers, allow user to upload files
 | 
			
		||||
        if isinstance(storage_containers, list) and not (any(storage_containers)):
 | 
			
		||||
            st.warning(
 | 
			
		||||
                "No existing Storage Containers found. Please upload data to continue."
 | 
			
		||||
            )
 | 
			
		||||
            uploaded = upload_files(client, key_prefix="prompts-upload-1")
 | 
			
		||||
            if uploaded:
 | 
			
		||||
                # brief pause to allow success message to display
 | 
			
		||||
                sleep(1.5)
 | 
			
		||||
                st.rerun()
 | 
			
		||||
        else:
 | 
			
		||||
            select_prompt_storage = st.selectbox(
 | 
			
		||||
                "Select an existing Storage Container.",
 | 
			
		||||
                options=[""] + storage_containers
 | 
			
		||||
                if isinstance(storage_containers, list)
 | 
			
		||||
                else [],
 | 
			
		||||
                key="prompt-storage",
 | 
			
		||||
                index=0,
 | 
			
		||||
            )
 | 
			
		||||
            disable_other_input = True if select_prompt_storage != "" else False
 | 
			
		||||
            with st.expander("I want to upload new data...", expanded=False):
 | 
			
		||||
                new_upload = upload_files(
 | 
			
		||||
                    client,
 | 
			
		||||
                    key_prefix="prompts-upload-2",
 | 
			
		||||
                    disable_other_input=disable_other_input,
 | 
			
		||||
                )
 | 
			
		||||
                if new_upload:
 | 
			
		||||
                    # brief pause to allow success message to display
 | 
			
		||||
                    st.session_state["new_upload"] = True
 | 
			
		||||
                    sleep(1.5)
 | 
			
		||||
                    st.rerun()
 | 
			
		||||
            if st.session_state["new_upload"] and not select_prompt_storage:
 | 
			
		||||
                st.warning(
 | 
			
		||||
                    "Please select the newly uploaded Storage Container to continue."
 | 
			
		||||
                )
 | 
			
		||||
            st.write(f"**Selected Storage Container:** :blue[{select_prompt_storage}]")
 | 
			
		||||
            triggered = st.button(
 | 
			
		||||
                label="Generate Prompts",
 | 
			
		||||
                key="prompt-generation",
 | 
			
		||||
                help="Select either an existing Storage Container or upload new data to enable this button.\n\
 | 
			
		||||
                Then, click to generate custom prompts for the LLM.",
 | 
			
		||||
                disabled=not select_prompt_storage,
 | 
			
		||||
            )
 | 
			
		||||
            if triggered:
 | 
			
		||||
                with st.spinner("Generating LLM prompts for GraphRAG..."):
 | 
			
		||||
                    generated = generate_and_extract_prompts(
 | 
			
		||||
                        client=client,
 | 
			
		||||
                        storage_name=select_prompt_storage,
 | 
			
		||||
                        limit=num_chunks,
 | 
			
		||||
                    )
 | 
			
		||||
                    if not isinstance(generated, Exception):
 | 
			
		||||
                        st.success(
 | 
			
		||||
                            "Prompts generated successfully! Move on to the next tab to configure the prompts."
 | 
			
		||||
                        )
 | 
			
		||||
                    else:
 | 
			
		||||
                        # assume limit parameter is too high
 | 
			
		||||
                        st.warning(
 | 
			
		||||
                            "You do not have enough data to generate prompts. Retrying with a smaller sample size."
 | 
			
		||||
                        )
 | 
			
		||||
                        while num_chunks > 1:
 | 
			
		||||
                            num_chunks -= 1
 | 
			
		||||
                            generated = generate_and_extract_prompts(
 | 
			
		||||
                                client=client,
 | 
			
		||||
                                storage_name=select_prompt_storage,
 | 
			
		||||
                                limit=num_chunks,
 | 
			
		||||
                            )
 | 
			
		||||
                            if not isinstance(generated, Exception):
 | 
			
		||||
                                st.success(
 | 
			
		||||
                                    "Prompts generated successfully! Move on to the next tab to configure the prompts."
 | 
			
		||||
                                )
 | 
			
		||||
                                break
 | 
			
		||||
                            else:
 | 
			
		||||
                                st.warning(f"Retrying with sample size: {num_chunks}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_prompt_configuration_tab(
 | 
			
		||||
    download_file_name: str = "edited_prompts.zip",
 | 
			
		||||
) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Displays content of Prompt Configuration Tab
 | 
			
		||||
    """
 | 
			
		||||
    st.header(
 | 
			
		||||
        "Configure Prompts (optional)",
 | 
			
		||||
        divider=True,
 | 
			
		||||
        help="Generate fine tuned prompts for the LLM specific to your data and domain.",
 | 
			
		||||
    )
 | 
			
		||||
    prompt_values = [st.session_state[k.value] for k in PromptKeys]
 | 
			
		||||
 | 
			
		||||
    if any(prompt_values):
 | 
			
		||||
        prompt_editor([prompt_values[0], prompt_values[1], prompt_values[2]])
 | 
			
		||||
        col1, col2, col3 = st.columns(3, gap="large")
 | 
			
		||||
        with col1:
 | 
			
		||||
            clicked = st.button(
 | 
			
		||||
                "Save Prompts",
 | 
			
		||||
                help="Save the edited prompts for use with the follow-on indexing step. This button must be clicked to enable downloading the prompts.",
 | 
			
		||||
                type="primary",
 | 
			
		||||
                key="save-prompt-button",
 | 
			
		||||
                on_click=save_prompts,
 | 
			
		||||
                kwargs={"zip_file_path": download_file_name},
 | 
			
		||||
            )
 | 
			
		||||
        with col2:
 | 
			
		||||
            st.button(
 | 
			
		||||
                "Edit Prompts",
 | 
			
		||||
                help="Allows user to re-edit the prompts after saving.",
 | 
			
		||||
                type="primary",
 | 
			
		||||
                key="edit-prompt-button",
 | 
			
		||||
                on_click=edit_prompts,
 | 
			
		||||
            )
 | 
			
		||||
        with col3:
 | 
			
		||||
            if os.path.exists(download_file_name):
 | 
			
		||||
                with open(download_file_name, "rb") as fp:
 | 
			
		||||
                    st.download_button(
 | 
			
		||||
                        "Download Prompts",
 | 
			
		||||
                        data=fp,
 | 
			
		||||
                        file_name=download_file_name,
 | 
			
		||||
                        help="Downloads the saved prompts as a zip file containing three LLM prompts in .txt format.",
 | 
			
		||||
                        mime="application/zip",
 | 
			
		||||
                        type="primary",
 | 
			
		||||
                        disabled=not st.session_state["saved_prompts"],
 | 
			
		||||
                        key="download-prompt-button",
 | 
			
		||||
                    )
 | 
			
		||||
        if clicked:
 | 
			
		||||
            st.success(
 | 
			
		||||
                "Prompts saved successfully! Downloading prompts is now enabled."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_index_tab(indexPipe: IndexPipeline) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Displays content of Index tab
 | 
			
		||||
    """
 | 
			
		||||
    indexPipe.storage_data_step()
 | 
			
		||||
    indexPipe.build_index_step()
 | 
			
		||||
    indexPipe.check_status_step()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def execute_query(
 | 
			
		||||
    query_engine: GraphQuery, query_type: str, search_index: str | list[str], query: str
 | 
			
		||||
) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Executes the query on the selected index
 | 
			
		||||
    """
 | 
			
		||||
    if query:
 | 
			
		||||
        query_engine.search(
 | 
			
		||||
            query_type=query_type, search_index=search_index, query=query
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        return st.warning("Please enter a query to search.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_query_tab(client: GraphragAPI) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Displays content of Query Tab
 | 
			
		||||
    """
 | 
			
		||||
    gquery = GraphQuery(client)
 | 
			
		||||
    col1, col2 = st.columns(2)
 | 
			
		||||
    with col1:
 | 
			
		||||
        query_type = st.selectbox(
 | 
			
		||||
            "Query Type",
 | 
			
		||||
            ["Global Streaming", "Global", "Local"],
 | 
			
		||||
            help="Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
 | 
			
		||||
        )
 | 
			
		||||
    with col2:
 | 
			
		||||
        search_indexes = client.get_index_names()
 | 
			
		||||
        if not any(search_indexes):
 | 
			
		||||
            st.warning("No indexes found. Please build an index to continue.")
 | 
			
		||||
        select_index_search = st.multiselect(
 | 
			
		||||
            label="Index",
 | 
			
		||||
            options=search_indexes if any(search_indexes) else [],
 | 
			
		||||
            key="multiselect-index-search",
 | 
			
		||||
            help="Select the index(es) to query. The selected index(es) must have a complete status in order to yield query results without error. Use Check Index Status to confirm status.",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    disabled = True if not any(select_index_search) else False
 | 
			
		||||
    col3, col4 = st.columns([0.8, 0.2])
 | 
			
		||||
 | 
			
		||||
    with col3:
 | 
			
		||||
        search_bar = st.text_input("Query", key="search-query", disabled=disabled)
 | 
			
		||||
    with col4:
 | 
			
		||||
        search_button = st.button("QUERY", type="primary", disabled=disabled)
 | 
			
		||||
 | 
			
		||||
    # defining a query variable enables the use of either the search bar or the search button to trigger the query
 | 
			
		||||
    query = st.session_state["search-query"]
 | 
			
		||||
    if len(query) > 5:
 | 
			
		||||
        if (search_bar and search_button) and any(select_index_search):
 | 
			
		||||
            execute_query(
 | 
			
		||||
                query_engine=gquery,
 | 
			
		||||
                query_type=query_type,
 | 
			
		||||
                search_index=select_index_search,
 | 
			
		||||
                query=query,
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        col1, col2 = st.columns([0.3, 0.7])
 | 
			
		||||
        with col1:
 | 
			
		||||
            st.warning("Cannot submit queries less than 6 characters in length.")
 | 
			
		||||
							
								
								
									
										56
									
								
								frontend/src/components/upload_files_component.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								frontend/src/components/upload_files_component.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,56 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
UPLOAD_HELP_MESSAGE = """
 | 
			
		||||
This functionality is disabled while an existing Storage Container is selected.
 | 
			
		||||
Please deselect the existing Storage Container to upload new data.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upload_files(
    client: GraphragAPI, key_prefix: str, disable_other_input: bool = False
):
    """
    Reusable component to upload .txt files to a Blob Storage Container
    through the GraphRAG API.

    Parameters
    ----------
    client : GraphragAPI
        API wrapper used to perform the upload.
    key_prefix : str
        Prefix for all Streamlit widget keys so this component can appear
        more than once on a page without key collisions.
    disable_other_input : bool
        When True, disables every widget in this component (e.g. while an
        existing Storage Container is selected).

    Returns
    -------
    bool
        True when the upload button was pressed on this rerun, else False.
    """
    input_storage_name = st.text_input(
        label="Enter Storage Name",
        key=f"{key_prefix}-storage-name-input",
        disabled=disable_other_input,
        help=UPLOAD_HELP_MESSAGE,
    )
    file_upload = st.file_uploader(
        "Upload Data",
        type=["txt"],
        key=f"{key_prefix}-file-uploader",
        accept_multiple_files=True,
        disabled=disable_other_input,
    )

    uploaded = st.button(
        "Upload Files",
        key=f"{key_prefix}-upload-button",
        disabled=disable_other_input or input_storage_name == "",
    )
    if uploaded:
        if file_upload and input_storage_name != "":
            # Build multipart payloads in the ("files", (name, bytes, mime))
            # shape expected by requests' `files` parameter.
            file_payloads = []
            for file in file_upload:
                file_payload = (
                    "files",
                    (file.name, file.read(), file.type),
                )
                file_payloads.append(file_payload)

            response = client.upload_files(file_payloads, input_storage_name)
            # BUG FIX: client.upload_files may return None when the request
            # itself failed, so guard before dereferencing status_code.
            if response is not None and response.status_code == 200:
                st.success("Files uploaded successfully!")
            elif response is not None:
                st.error(f"Error: {json.loads(response.text)}")
            else:
                st.error("Error: file upload request failed.")
    return uploaded
 | 
			
		||||
							
								
								
									
										33
									
								
								frontend/src/enums.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								frontend/src/enums.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,33 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
from enum import Enum
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptKeys(Enum):
    """Session-state keys under which each GraphRAG LLM prompt text is stored."""

    ENTITY = "entity_extraction"
    SUMMARY = "summarize_descriptions"
    COMMUNITY = "community_report"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptFileNames(Enum):
    """Expected on-disk file names of the three generated prompt files."""

    ENTITY = "entity_extraction_prompt.txt"
    SUMMARY = "summarize_descriptions_prompt.txt"
    COMMUNITY = "community_report_prompt.txt"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PromptTextAreas(Enum):
    """Widget keys for the prompt editors.

    NOTE(review): presumably Streamlit ``st.text_area`` keys used by pages
    outside this module — confirm against the callers.
    """

    ENTITY = "entity_text_area"
    SUMMARY = "summary_text_area"
    COMMUNITY = "community_text_area"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StorageIndexVars(Enum):
    """Session-state keys tracking the user's storage-container and index selections."""

    SELECTED_STORAGE = "selected_storage"
    INPUT_STORAGE = "input_storage"
    SELECTED_INDEX = "selected_index"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EnvVars(Enum):
    """Environment-variable names (doubling as session-state keys) needed to reach the GraphRAG API."""

    APIM_SUBSCRIPTION_KEY = "APIM_SUBSCRIPTION_KEY"
    DEPLOYMENT_URL = "DEPLOYMENT_URL"
 | 
			
		||||
							
								
								
									
										175
									
								
								frontend/src/functions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								frontend/src/functions.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,175 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from zipfile import ZipFile
 | 
			
		||||
 | 
			
		||||
import streamlit as st
 | 
			
		||||
from dotenv import find_dotenv, load_dotenv
 | 
			
		||||
 | 
			
		||||
from src.enums import EnvVars, PromptKeys, StorageIndexVars
 | 
			
		||||
from src.graphrag_api import GraphragAPI
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This module contains functions that are used across the Streamlit app.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
    """
    Initialize the Streamlit app with the necessary configurations.

    Sets the page layout, injects custom CSS, seeds session-state keys, and
    resolves the APIM subscription key and deployment URL from environment
    variables (falling back to values already present in session state).

    Returns True when both credentials are available — in which case the
    request-header dicts are stored in session state — otherwise False.
    """
    # set page configuration
    st.set_page_config(initial_sidebar_state="expanded", layout="wide")

    # set custom CSS
    with open(css_file) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    # initialize session state variables
    # (must run before the os.getenv fallbacks below, which read these keys)
    set_session_state_variables()

    # load environment variables
    _ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)

    # either load from .env file or from session state
    st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = os.getenv(
        EnvVars.APIM_SUBSCRIPTION_KEY.value,
        st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value],
    )
    st.session_state[EnvVars.DEPLOYMENT_URL.value] = os.getenv(
        EnvVars.DEPLOYMENT_URL.value, st.session_state[EnvVars.DEPLOYMENT_URL.value]
    )
    if (
        st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
        and st.session_state[EnvVars.DEPLOYMENT_URL.value]
    ):
        # "headers" is for JSON API calls; "headers_upload" has no
        # Content-Type — presumably so requests can set the multipart
        # boundary itself on uploads (TODO confirm).
        st.session_state["headers"] = {
            "Ocp-Apim-Subscription-Key": st.session_state[
                EnvVars.APIM_SUBSCRIPTION_KEY.value
            ],
            "Content-Type": "application/json",
        }
        st.session_state["headers_upload"] = {
            "Ocp-Apim-Subscription-Key": st.session_state[
                EnvVars.APIM_SUBSCRIPTION_KEY.value
            ]
        }
        return True
    else:
        return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_session_state_variables() -> None:
    """
    Initializes most session state variables for the app.

    Seeds an empty-string default for every key declared in PromptKeys,
    StorageIndexVars, and EnvVars, plus three boolean lifecycle flags,
    without clobbering values that already exist in session state.
    """
    # The three enums previously got one copy-pasted loop each; they all
    # share the same empty-string default, so seed them in a single pass.
    for key in (*PromptKeys, *StorageIndexVars, *EnvVars):
        value = key.value
        if value not in st.session_state:
            st.session_state[value] = ""
    # Boolean lifecycle flags, defaulted to False on first run.
    for flag in ("saved_prompts", "initialized", "new_upload"):
        if flag not in st.session_state:
            st.session_state[flag] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_session_state_prompt_vars(
    entity_extract: Optional[str] = None,
    summarize: Optional[str] = None,
    community: Optional[str] = None,
    initial_setting: bool = False,
    prompt_dir: str = "./prompts",
) -> None:
    """
    Updates the session state variables for the LLM prompts.

    When ``initial_setting`` is True the three prompts are (re)loaded from
    ``prompt_dir`` and override any values passed in. Only truthy prompt
    strings are written into session state.
    """
    if initial_setting:
        entity_extract, summarize, community = get_prompts(prompt_dir)
    prompt_updates = {
        PromptKeys.ENTITY.value: entity_extract,
        PromptKeys.SUMMARY.value: summarize,
        PromptKeys.COMMUNITY.value: community,
    }
    for state_key, prompt_text in prompt_updates.items():
        if prompt_text:
            st.session_state[state_key] = prompt_text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_and_extract_prompts(
    client: GraphragAPI,
    storage_name: str,
    zip_file_name: str = "prompts.zip",
    limit: int = 5,
) -> None | Exception:
    """
    Makes API call to generate LLM prompts, extracts prompts from zip file,
    and updates the prompt session state variables.

    Returns None on success, or the raised Exception on failure — errors
    are returned rather than propagated so callers can surface them in the
    UI.
    """
    try:
        client.generate_prompts(
            storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
        )
    except Exception as e:
        return e
    try:
        _extract_prompts_from_zip(zip_file_name)
        update_session_state_prompt_vars(initial_setting=True)
    except Exception as e:
        return e
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
 | 
			
		||||
    with ZipFile(zip_file_name, "r") as zip_ref:
 | 
			
		||||
        zip_ref.extractall()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def open_file(file_path: str | Path):
 | 
			
		||||
    with open(file_path, "r", encoding="utf-8") as file:
 | 
			
		||||
        text = file.read()
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def zip_directory(directory_path: str, zip_path: str):
    """
    Zips all contents of a directory into a single zip file.

    Parameters:
    - directory_path: str, the path of the directory to zip
    - zip_path: str, the path where the zip file will be created

    Every file under ``directory_path`` (recursively) is stored in the
    archive under the directory's own base name, e.g. ``data/sub/a.txt``.
    """
    base_name = os.path.basename(directory_path.rstrip("/"))
    with ZipFile(zip_path, "w") as archive:
        for current_dir, _, filenames in os.walk(directory_path):
            for filename in filenames:
                absolute_path = os.path.join(current_dir, filename)
                relative_path = os.path.relpath(absolute_path, start=directory_path)
                archive.write(absolute_path, os.path.join(base_name, relative_path))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_prompts(prompt_dir: str = "./prompts"):
    """
    Extract text from generated prompts.  Assumes file names comply with
    pregenerated file name standards: the entity-extraction, summarization,
    and community-report prompt files start with "entity", "summ", and
    "community" respectively.

    Returns a (entity, summarize, community) tuple of prompt texts.
    """

    def _read_first(prefix: str) -> str:
        # First .txt file in prompt_dir whose name starts with `prefix`
        # (raises IndexError when no such file exists, as before).
        matches = [
            candidate
            for candidate in Path(prompt_dir).iterdir()
            if candidate.name.endswith(".txt") and candidate.name.startswith(prefix)
        ]
        with open(matches[0], "r", encoding="utf-8") as fh:
            return fh.read()

    return _read_first("entity"), _read_first("summ"), _read_first("community")
 | 
			
		||||
							
								
								
									
										214
									
								
								frontend/src/graphrag_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								frontend/src/graphrag_api.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,214 @@
 | 
			
		||||
# Copyright (c) Microsoft Corporation.
 | 
			
		||||
# Licensed under the MIT License.
 | 
			
		||||
 | 
			
		||||
from io import StringIO
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import streamlit as st
 | 
			
		||||
from requests import Response
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This module contains the GraphRAG API class for making all external API calls
 | 
			
		||||
presumably to a GraphRAG instance deployed on Azure.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GraphragAPI:
    """
    Primary interface for making REST API calls to GraphRAG API.

    All requests carry the APIM subscription key. Methods generally return
    the parsed payload on success and either the raw Response or the raised
    Exception on failure, logging errors to stdout.
    """

    def __init__(self, api_url: str, apim_key: str):
        # Base URL of the deployed GraphRAG API.
        self.api_url = api_url
        # APIM subscription key sent with every request.
        self.apim_key = apim_key
        self.headers = {
            "Ocp-Apim-Subscription-Key": self.apim_key,
            "Content-Type": "application/json",
        }
        # Upload headers omit Content-Type — presumably so requests can set
        # the multipart boundary itself (TODO confirm).
        self.upload_headers = {"Ocp-Apim-Subscription-Key": self.apim_key}

    def get_storage_container_names(
        self, storage_name_key: str = "storage_name"
    ) -> list[str] | Response | Exception:
        """
        GET request to GraphRAG API for Azure Blob Storage Container names.

        Returns the list of names on success, the raw Response on a non-200
        status, or the Exception when the request itself failed.
        """
        try:
            response = requests.get(f"{self.api_url}/data", headers=self.headers)
            if response.status_code == 200:
                return response.json()[storage_name_key]
            print(f"Error: {response.status_code}")
            return response
        except Exception as e:
            print(f"Error: {str(e)}")
            return e

    def upload_files(self, file_payloads: dict, input_storage_name: str):
        """
        Upload files to Azure Blob Storage Container.

        Returns the Response regardless of status code so callers can
        inspect both success and error replies, or None when the request
        itself raised an exception.
        """
        try:
            # BUG FIX: previously only 200 responses were returned; error
            # responses fell through to an implicit None, which crashed
            # callers dereferencing `response.status_code`.
            return requests.post(
                self.api_url + "/data",
                headers=self.upload_headers,
                files=file_payloads,
                params={"storage_name": input_storage_name},
            )
        except Exception as e:
            print(f"Error: {str(e)}")

    def get_index_names(
        self, index_name_key: str = "index_name"
    ) -> list | Response | None:
        """
        GET request to GraphRAG API for existing indexes.

        Returns the list of index names on success, the raw Response on a
        non-200 status, or None when the request itself raised.
        """
        try:
            response = requests.get(f"{self.api_url}/index", headers=self.headers)
            if response.status_code == 200:
                return response.json()[index_name_key]
            print(f"Error: {response.status_code}")
            return response
        except Exception as e:
            print(f"Error: {str(e)}")

    def build_index(
        self,
        storage_name: str,
        index_name: str,
        entity_extraction_prompt_filepath: str | StringIO = None,
        community_prompt_filepath: str | StringIO = None,
        summarize_description_prompt_filepath: str | StringIO = None,
    ) -> requests.Response:
        """
        Create a search index.

        This function kicks off a job that builds a knowledge graph (KG)
        index from files located in a blob storage container.

        Each prompt may be a filesystem path (opened here and closed after
        the request) or an in-memory StringIO buffer.
        """
        url = self.api_url + "/index"
        prompt_files = dict()
        opened_files = []  # handles opened here must be closed after the POST

        def _as_file(source):
            # Open a path; pass a StringIO (file-like) straight through.
            if isinstance(source, str):
                handle = open(source, "r", encoding="utf-8")
                opened_files.append(handle)
                return handle
            return source

        if entity_extraction_prompt_filepath:
            prompt_files["entity_extraction_prompt"] = _as_file(
                entity_extraction_prompt_filepath
            )
        if community_prompt_filepath:
            prompt_files["community_report_prompt"] = _as_file(
                community_prompt_filepath
            )
        if summarize_description_prompt_filepath:
            prompt_files["summarize_descriptions_prompt"] = _as_file(
                summarize_description_prompt_filepath
            )
        try:
            return requests.post(
                url,
                files=prompt_files if len(prompt_files) > 0 else None,
                params={"index_name": index_name, "storage_name": storage_name},
                headers=self.headers,
            )
        finally:
            # BUG FIX: file handles opened from paths were previously leaked.
            for handle in opened_files:
                handle.close()

    def check_index_status(self, index_name: str) -> Response | None:
        """
        Check the status of a running index job.

        Returns the Response (non-200 statuses are also logged), or None
        when the request itself raised.
        """
        url = self.api_url + f"/index/status/{index_name}"
        try:
            response = requests.get(url, headers=self.headers)
            if response.status_code != 200:
                print(f"Error: {response.status_code}")
            return response
        except Exception as e:
            print(f"Error: {str(e)}")

    def health_check(self) -> int | Exception:
        """
        Check the health of the APIM endpoint.

        Returns the HTTP status code, or the Exception when the request
        itself raised.
        """
        url = self.api_url + "/health"
        try:
            response = requests.get(url, headers=self.headers)
            return response.status_code
        except Exception as e:
            print(f"Error: {str(e)}")
            return e

    def query_index(self, index_name: str | list[str], query_type: str, query: str):
        """
        Submit query to GraphRAG API using specific index and query type.

        Returns the parsed JSON payload on success; on failure the error is
        surfaced in the Streamlit UI and None is returned.
        """
        try:
            request = {
                "index_name": index_name,
                "query": query,
                "reformat_context_data": True,
            }
            response = requests.post(
                f"{self.api_url}/query/{query_type.lower()}",
                headers=self.headers,
                json=request,
            )

            if response.status_code == 200:
                return response.json()
            st.error(
                f"Error with {query_type} search: {response.status_code} {response.json()}"
            )
        except Exception as e:
            st.error(f"Error with {query_type} search: {str(e)}")

    def global_streaming_query(
        self, index_name: str | list[str], query: str
    ) -> Response | None:
        """
        Returns a streaming response object for a global query, or None
        when the request itself raised.
        """
        url = f"{self.api_url}/experimental/query/global/streaming"
        try:
            return requests.post(
                url,
                json={"index_name": index_name, "query": query},
                headers=self.headers,
                stream=True,
            )
        except Exception as e:
            print(f"Error: {str(e)}")

    def get_source_entity(
        self, index_name: str, entity_id: str
    ) -> dict | Response | None:
        """
        Fetch a source entity record from an index.

        Returns the parsed JSON payload on success, the raw Response on a
        non-200 status, or None when the request itself raised.
        """
        try:
            response = requests.get(
                f"{self.api_url}/source/entity/{index_name}/{entity_id}",
                headers=self.headers,
            )
            if response.status_code == 200:
                return response.json()
            return response
        except Exception as e:
            print(f"Error: {str(e)}")

    def generate_prompts(
        self, storage_name: str, zip_file_name: str = "prompts.zip", limit: int = 1
    ) -> None:
        """
        Generate graphrag prompts using data provided in a specific storage
        container, streaming the resulting zip archive to `zip_file_name`.

        Raises requests.HTTPError (via raise_for_status) on an error status.
        """
        url = self.api_url + "/index/config/prompts"
        params = {"storage_name": storage_name, "limit": limit}
        with requests.get(url, params=params, headers=self.headers, stream=True) as r:
            r.raise_for_status()
            with open(zip_file_name, "wb") as f:
                # Stream in 8 KiB chunks; the bare iter_content() default
                # previously wrote the archive one byte at a time.
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
 | 
			
		||||
							
								
								
									
										142
									
								
								frontend/style.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								frontend/style.css
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,142 @@
 | 
			
		||||
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/* Font Awesome icons used by the app's markup. */
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

/* Nudge the fourth column's contents down to align with its siblings.
   NOTE(review): targets Streamlit's generated emotion-cache class names —
   fragile across Streamlit versions; verify after upgrades. */
#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi8 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi5 > div > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn5 > div:nth-child(4) > div > div > div > div > div{
    margin-top: 1.6em;
}

/* Thin blue rule rendered by st.header dividers. */
[data-testid="stHeadingDivider"] {
    background-color: #3d9df3;  /* Set your desired color */
    height: 1px;
}

/* Center the app title heading. */
#microsoft-graphrag-copilot > div > span {
    text-align: center;
    margin-top: -1em;
}

/* Tooltip container */
.tooltip {
    position: relative;
    display: inline-block;
    border-bottom: 1px dotted black; /* If you want dots under the hoverable text */
  }

  /* Tooltip text */
  .tooltip .tooltiptext {
    visibility: hidden;
    width: 120px;
    background-color: #555;
    color: #fff;
    text-align: center;
    border-radius: 6px;
    padding: 5px;
    position: absolute;
    z-index: 1;
    bottom: 125%;
    left: 50%;
    margin-left: -60px;
    opacity: 0;
    transition: opacity 0.3s;
  }

  /* Show the tooltip text when you hover over the tooltip container */
  .tooltip:hover .tooltiptext {
    visibility: visible;
    opacity: 1;
  }

  /* Neutral content box (currently white despite the name). */
  .gray-box {

    background-color: #ffffff;
    padding: 10px;
    width: 80%;
}

/* Vertically and horizontally center page content. */
.center-container {
  margin-top: -10em;
  display: flex;
  align-items: center;
  justify-content: center;
  height: 100vh;
}

/* Fixed footer bar pinned to the bottom of the viewport. */
.footer {
  display: flex;
  justify-content: center;
  align-items: center;
  position: fixed;
  left: 0;
  bottom: 0;
  width: 100%;
  background-color: #f1f1f1;
  text-align: center;
  padding: 5px;
  z-index: 1000;
}
.footer p{
  font-size: 12px;
}

/* Primary (green pill) button styling for st.button(type="primary"). */
button[kind="primary"] {
  background-color: #1d9445;
  border: 0;
  border-radius: 56px;
  color: #fff;
  cursor: pointer;
  display: inline-block;
  font-family: system-ui,-apple-system,system-ui,"Segoe UI",Roboto,Ubuntu,"Helvetica Neue",sans-serif;
  font-size: 58px;
  font-weight: 600;
  outline: 0;
  padding: 16px 21px;
  position: relative;
  text-align: center;
  text-decoration: none;
  transition: all .3s;
  user-select: none;
  -webkit-user-select: none;
  touch-action: manipulation;
}

/* Glossy highlight overlay on the primary button. */
button[kind="primary"]:before {
  background-color: initial;
  background-image: linear-gradient(#fff 0, rgba(255, 255, 255, 0) 100%);
  border-radius: 125px;
  content: "";
  height: 50%;
  left: 4%;
  opacity: .5;
  position: absolute;
  top: 0;
  transition: all .3s;
  width: 62%;
}

/* Lift and glow on hover. */
button[kind="primary"]:hover {
  box-shadow: rgba(255, 255, 255, .2) 0 3px 15px inset, rgba(0, 0, 0, .1) 0 3px 5px, rgba(0, 0, 0, .1) 0 10px 13px;
  transform: scale(1.05);
}

/* Roomier padding on wider screens. */
@media (min-width: 768px) {
button[kind="primary"] {
    padding: 15px 34px;
    margin: 20px auto;
  }
}

/* Center text areas (e.g. prompt editors) within their column. */
.element-container:has(>.stTextArea), .stTextArea {
  display: block;
  margin-left: auto;
  margin-right: auto;
}
.stTextArea textarea {
  height: 500px;
  /*background-color: #a7b0a4;*/
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user