Update Frontend app to accommodate recent changes in big merge - Fixes #242 (#243)

This commit is contained in:
Tim 2025-02-12 18:51:18 -05:00 committed by GitHub
parent ced825c40f
commit a0d987c4eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 49 additions and 13 deletions

View File

@ -110,8 +110,8 @@ class IndexPipeline:
)
response = self.client.build_index(
storage_name=storage_selection,
index_name=index_name,
storage_container_name=storage_selection,
index_container_name=index_name,
entity_extraction_prompt_filepath=entity_prompt,
summarize_description_prompt_filepath=summarize_prompt,
community_prompt_filepath=community_prompt,

View File

@ -221,8 +221,11 @@ def get_query_tab(client: GraphragAPI) -> None:
with col1:
query_type = st.selectbox(
"Query Type",
["Global Streaming", "Local Streaming", "Global", "Local"],
help="Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
# ["Global Streaming", "Local Streaming", "Global", "Local"],
["Global", "Local"],
help=(
"Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph."
),
)
with col2:
search_indexes = client.get_index_names()

View File

@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import os
from pathlib import Path
from typing import Optional
@ -116,13 +117,31 @@ def generate_and_extract_prompts(
client.generate_prompts(
storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
)
_extract_prompts_from_zip(zip_file_name)
update_session_state_prompt_vars(initial_setting=True)
_extract_prompts_from_json(zip_file_name)
update_session_state_prompt_vars(initial_setting=True, prompt_dir=".")
return
except Exception as e:
return e
def _extract_prompts_from_json(json_file_name: str = "prompts.zip"):
    """
    Split a JSON prompt bundle into three prompt text files.

    The file at ``json_file_name`` is parsed as UTF-8 JSON and the three
    prompts it contains are written to the current working directory as
    ``entity_extraction_prompt.txt``, ``summarization_prompt.txt`` and
    ``community_summarization_prompt.txt``.

    Args:
        json_file_name: Path of the JSON file to read. The default keeps the
            historical ``prompts.zip`` name for backward compatibility, even
            though the content is parsed as JSON.

    Returns:
        The parsed JSON document.

    Raises:
        KeyError: If any of the three expected prompt keys is missing. No
            prompt file is written in that case.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(json_file_name, "r", encoding="utf-8") as file:
        json_data = json.load(file)

    # Output file name -> key holding that prompt in the JSON document.
    prompt_keys_by_file = {
        "entity_extraction_prompt.txt": "entity_extraction_prompt",
        "summarization_prompt.txt": "entity_summarization_prompt",
        "community_summarization_prompt.txt": "community_summarization_prompt",
    }

    # Resolve every prompt before touching the disk, so a missing key raises
    # KeyError without leaving a partially written set of prompt files.
    prompt_text_by_file = {
        file_name: json_data[key] for file_name, key in prompt_keys_by_file.items()
    }

    for file_name, prompt_text in prompt_text_by_file.items():
        with open(file_name, "w", encoding="utf-8") as file:
            file.write(prompt_text)
    return json_data
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
with ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall()

View File

@ -44,7 +44,9 @@ class GraphragAPI:
print(f"Error: {str(e)}")
return e
def upload_files(self, file_payloads: dict, input_storage_name: str):
def upload_files(
self, file_payloads: dict, container_name: str, overwrite: bool = True
):
"""
Upload files to Azure Blob Storage Container.
"""
@ -53,7 +55,7 @@ class GraphragAPI:
self.api_url + "/data",
headers=self.upload_headers,
files=file_payloads,
params={"storage_name": input_storage_name},
params={"container_name": container_name, "overwrite": overwrite},
)
if response.status_code == 200:
return response
@ -78,8 +80,8 @@ class GraphragAPI:
def build_index(
self,
storage_name: str,
index_name: str,
storage_container_name: str,
index_container_name: str,
entity_extraction_prompt_filepath: str | StringIO = None,
community_prompt_filepath: str | StringIO = None,
summarize_description_prompt_filepath: str | StringIO = None,
@ -112,7 +114,10 @@ class GraphragAPI:
return requests.post(
url,
files=prompt_files if len(prompt_files) > 0 else None,
params={"index_name": index_name, "storage_name": storage_name},
params={
"storage_container_name": storage_container_name,
"index_container_name": index_container_name,
},
headers=self.headers,
)
@ -146,11 +151,20 @@ class GraphragAPI:
"""
Submit a query to the GraphRAG API using a specific index and query type.
"""
if isinstance(index_name, list) and len(index_name) > 1:
st.error(
"Multiple index names are currently not supported via the UI. This functionality is being moved into the graphrag library and will be available in a coming release."
)
return {"result": ""}
index_name = index_name if isinstance(index_name, str) else index_name[0]
try:
request = {
"index_name": index_name,
"query": query,
"reformat_context_data": True,
# "reformat_context_data": True,
}
response = requests.post(
f"{self.api_url}/query/{query_type.lower()}",
@ -223,7 +237,7 @@ class GraphragAPI:
Generate graphrag prompts using data provided in a specific storage container.
"""
url = self.api_url + "/index/config/prompts"
params = {"storage_name": storage_name, "limit": limit}
params = {"container_name": storage_name, "limit": limit}
with requests.get(url, params=params, headers=self.headers, stream=True) as r:
r.raise_for_status()
with open(zip_file_name, "wb") as f: