mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-06-27 04:39:57 +00:00
This commit is contained in:
parent
ced825c40f
commit
a0d987c4eb
@ -110,8 +110,8 @@ class IndexPipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
response = self.client.build_index(
|
response = self.client.build_index(
|
||||||
storage_name=storage_selection,
|
storage_container_name=storage_selection,
|
||||||
index_name=index_name,
|
index_container_name=index_name,
|
||||||
entity_extraction_prompt_filepath=entity_prompt,
|
entity_extraction_prompt_filepath=entity_prompt,
|
||||||
summarize_description_prompt_filepath=summarize_prompt,
|
summarize_description_prompt_filepath=summarize_prompt,
|
||||||
community_prompt_filepath=community_prompt,
|
community_prompt_filepath=community_prompt,
|
||||||
|
@ -221,8 +221,11 @@ def get_query_tab(client: GraphragAPI) -> None:
|
|||||||
with col1:
|
with col1:
|
||||||
query_type = st.selectbox(
|
query_type = st.selectbox(
|
||||||
"Query Type",
|
"Query Type",
|
||||||
["Global Streaming", "Local Streaming", "Global", "Local"],
|
# ["Global Streaming", "Local Streaming", "Global", "Local"],
|
||||||
help="Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
|
["Global", "Local"],
|
||||||
|
help=(
|
||||||
|
"Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
with col2:
|
with col2:
|
||||||
search_indexes = client.get_index_names()
|
search_indexes = client.get_index_names()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -116,13 +117,31 @@ def generate_and_extract_prompts(
|
|||||||
client.generate_prompts(
|
client.generate_prompts(
|
||||||
storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
|
storage_name=storage_name, zip_file_name=zip_file_name, limit=limit
|
||||||
)
|
)
|
||||||
_extract_prompts_from_zip(zip_file_name)
|
_extract_prompts_from_json(zip_file_name)
|
||||||
update_session_state_prompt_vars(initial_setting=True)
|
update_session_state_prompt_vars(initial_setting=True, prompt_dir=".")
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_prompts_from_json(json_file_name: str = "prompts.zip"):
|
||||||
|
with open(json_file_name, "r", encoding="utf-8") as file:
|
||||||
|
json_data = file.read()
|
||||||
|
|
||||||
|
json_data = json.loads(json_data)
|
||||||
|
|
||||||
|
with open("entity_extraction_prompt.txt", "w", encoding="utf-8") as file:
|
||||||
|
file.write(json_data["entity_extraction_prompt"])
|
||||||
|
|
||||||
|
with open("summarization_prompt.txt", "w", encoding="utf-8") as file:
|
||||||
|
file.write(json_data["entity_summarization_prompt"])
|
||||||
|
|
||||||
|
with open("community_summarization_prompt.txt", "w", encoding="utf-8") as file:
|
||||||
|
file.write(json_data["community_summarization_prompt"])
|
||||||
|
|
||||||
|
return json_data
|
||||||
|
|
||||||
|
|
||||||
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
|
def _extract_prompts_from_zip(zip_file_name: str = "prompts.zip"):
|
||||||
with ZipFile(zip_file_name, "r") as zip_ref:
|
with ZipFile(zip_file_name, "r") as zip_ref:
|
||||||
zip_ref.extractall()
|
zip_ref.extractall()
|
||||||
|
@ -44,7 +44,9 @@ class GraphragAPI:
|
|||||||
print(f"Error: {str(e)}")
|
print(f"Error: {str(e)}")
|
||||||
return e
|
return e
|
||||||
|
|
||||||
def upload_files(self, file_payloads: dict, input_storage_name: str):
|
def upload_files(
|
||||||
|
self, file_payloads: dict, container_name: str, overwrite: bool = True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Upload files to Azure Blob Storage Container.
|
Upload files to Azure Blob Storage Container.
|
||||||
"""
|
"""
|
||||||
@ -53,7 +55,7 @@ class GraphragAPI:
|
|||||||
self.api_url + "/data",
|
self.api_url + "/data",
|
||||||
headers=self.upload_headers,
|
headers=self.upload_headers,
|
||||||
files=file_payloads,
|
files=file_payloads,
|
||||||
params={"storage_name": input_storage_name},
|
params={"container_name": container_name, "overwrite": overwrite},
|
||||||
)
|
)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response
|
return response
|
||||||
@ -78,8 +80,8 @@ class GraphragAPI:
|
|||||||
|
|
||||||
def build_index(
|
def build_index(
|
||||||
self,
|
self,
|
||||||
storage_name: str,
|
storage_container_name: str,
|
||||||
index_name: str,
|
index_container_name: str,
|
||||||
entity_extraction_prompt_filepath: str | StringIO = None,
|
entity_extraction_prompt_filepath: str | StringIO = None,
|
||||||
community_prompt_filepath: str | StringIO = None,
|
community_prompt_filepath: str | StringIO = None,
|
||||||
summarize_description_prompt_filepath: str | StringIO = None,
|
summarize_description_prompt_filepath: str | StringIO = None,
|
||||||
@ -112,7 +114,10 @@ class GraphragAPI:
|
|||||||
return requests.post(
|
return requests.post(
|
||||||
url,
|
url,
|
||||||
files=prompt_files if len(prompt_files) > 0 else None,
|
files=prompt_files if len(prompt_files) > 0 else None,
|
||||||
params={"index_name": index_name, "storage_name": storage_name},
|
params={
|
||||||
|
"storage_container_name": storage_container_name,
|
||||||
|
"index_container_name": index_container_name,
|
||||||
|
},
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,11 +151,20 @@ class GraphragAPI:
|
|||||||
"""
|
"""
|
||||||
Submite query to GraphRAG API using specific index and query type.
|
Submite query to GraphRAG API using specific index and query type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(index_name, list) and len(index_name) > 1:
|
||||||
|
st.error(
|
||||||
|
"Multiple index names are currently not supported via the UI. This functionality is being moved into the graphrag library and will be available in a coming release."
|
||||||
|
)
|
||||||
|
return {"result": ""}
|
||||||
|
|
||||||
|
index_name = index_name if isinstance(index_name, str) else index_name[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request = {
|
request = {
|
||||||
"index_name": index_name,
|
"index_name": index_name,
|
||||||
"query": query,
|
"query": query,
|
||||||
"reformat_context_data": True,
|
# "reformat_context_data": True,
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.api_url}/query/{query_type.lower()}",
|
f"{self.api_url}/query/{query_type.lower()}",
|
||||||
@ -223,7 +237,7 @@ class GraphragAPI:
|
|||||||
Generate graphrag prompts using data provided in a specific storage container.
|
Generate graphrag prompts using data provided in a specific storage container.
|
||||||
"""
|
"""
|
||||||
url = self.api_url + "/index/config/prompts"
|
url = self.api_url + "/index/config/prompts"
|
||||||
params = {"storage_name": storage_name, "limit": limit}
|
params = {"container_name": storage_name, "limit": limit}
|
||||||
with requests.get(url, params=params, headers=self.headers, stream=True) as r:
|
with requests.get(url, params=params, headers=self.headers, stream=True) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
with open(zip_file_name, "wb") as f:
|
with open(zip_file_name, "wb") as f:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user