mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-12-25 06:06:24 +00:00
Update frontend UI app (#174)
Co-authored-by: dorbaker <dorbaker@microsoft.com>
This commit is contained in:
parent
38096c8e86
commit
680cfc055e
@ -12,15 +12,16 @@
|
||||
],
|
||||
"remoteUser": "vscode",
|
||||
"remoteEnv": {
|
||||
// We add the .venv to the beginning of the path env in the Dockerfile
|
||||
// so that we use the proper python, however vscode rewrites/overwrites
|
||||
// the PATH in the image and puts /usr/local/bin in front of our .venv
|
||||
// path. This fixes that issue.
|
||||
"PATH": "${containerEnv:PATH}",
|
||||
// Add src folder to PYTHONPATH so that we can import modules that
|
||||
// are in the source dir
|
||||
"PYTHONPATH": "/graphrag-accelerator/backend/:$PATH"
|
||||
},
|
||||
// We add the .venv to the beginning of the path env in the Dockerfile
|
||||
// so that we use the proper python, however vscode rewrites/overwrites
|
||||
// the PATH in the image and puts /usr/local/bin in front of our .venv
|
||||
// path. This fixes that issue.
|
||||
"PATH": "${containerEnv:PATH}",
|
||||
// Add src folder to PYTHONPATH so that we can import modules that
|
||||
// are in the source dir
|
||||
"PYTHONPATH": "/graphrag-accelerator/backend/:$PATH",
|
||||
"AZURE_CLI_DISABLE_CONNECTION_VERIFICATION": "1"
|
||||
},
|
||||
"mounts": [
|
||||
// NOTE: we reference both HOME and USERPROFILE environment variables to simultaneously support both Windows and Unix environments
|
||||
// in most default situations, only one variable will exist (Windows has USERPROFILE and unix has HOME) and a reference to the other variable will result in an empty string
|
||||
|
||||
@ -16,6 +16,7 @@ from kubernetes import (
|
||||
client,
|
||||
config,
|
||||
)
|
||||
|
||||
from src.api.azure_clients import AzureStorageClientManager
|
||||
from src.api.common import sanitize_name
|
||||
from src.models import PipelineJob
|
||||
|
||||
951
poetry.lock → backend/poetry.lock
generated
951
poetry.lock → backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -9,7 +9,6 @@ authors = [
|
||||
"Gabriel Nieves <gnievesponce@microsoft.com>",
|
||||
"Douglas Orbaker <dorbaker@microsoft.com>",
|
||||
"Chris Sanchez <chrissanchez@microsoft.com>",
|
||||
"Katy Smith <katysmith@microsoft.com>",
|
||||
"Shane Solomon <shane.solomon@microsoft.com>",
|
||||
]
|
||||
readme = "README.md"
|
||||
@ -33,7 +32,7 @@ pytest = ">=8.2.1"
|
||||
wikipedia = ">=1.4.0"
|
||||
|
||||
[tool.poetry.group.backend.dependencies]
|
||||
adlfs = ">=2023.10.0"
|
||||
adlfs = ">=2024.7.0"
|
||||
applicationinsights = ">=0.11.10"
|
||||
attrs = ">=23.2.0"
|
||||
azure-core = ">=1.30.1"
|
||||
@ -66,12 +65,6 @@ tiktoken = ">=0.6.0"
|
||||
uvicorn = ">=0.23.2"
|
||||
urllib3 = ">=2.2.2"
|
||||
|
||||
[tool.poetry.group.frontend.dependencies]
|
||||
python-dotenv = ">=0.19.1"
|
||||
requests = "*"
|
||||
streamlit = ">=0.88.0"
|
||||
streamlit-nested-layout = "==0.1.3"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 88
|
||||
@ -127,7 +127,9 @@ async def setup_indexing_pipeline(
|
||||
)
|
||||
# if indexing job is in a failed state, delete the associated K8s job and pod to allow for a new job to be scheduled
|
||||
if PipelineJobState(existing_job.status) == PipelineJobState.FAILED:
|
||||
_delete_k8s_job(f"indexing-job-{sanitized_index_name}", os.environ["AKS_NAMESPACE"])
|
||||
_delete_k8s_job(
|
||||
f"indexing-job-{sanitized_index_name}", os.environ["AKS_NAMESPACE"]
|
||||
)
|
||||
# reset the pipeline job details
|
||||
existing_job._status = PipelineJobState.SCHEDULED
|
||||
existing_job._percent_complete = 0
|
||||
|
||||
@ -68,11 +68,15 @@ async def lifespan(app: FastAPI):
|
||||
] = pod.spec.service_account_name
|
||||
# retrieve list of existing cronjobs
|
||||
batch_v1 = client.BatchV1Api()
|
||||
namespace_cronjobs = batch_v1.list_namespaced_cron_job(namespace=os.environ["AKS_NAMESPACE"])
|
||||
namespace_cronjobs = batch_v1.list_namespaced_cron_job(
|
||||
namespace=os.environ["AKS_NAMESPACE"]
|
||||
)
|
||||
cronjob_names = [cronjob.metadata.name for cronjob in namespace_cronjobs.items]
|
||||
# create cronjob if it does not exist
|
||||
if manifest["metadata"]["name"] not in cronjob_names:
|
||||
batch_v1.create_namespaced_cron_job(namespace=os.environ["AKS_NAMESPACE"], body=manifest)
|
||||
batch_v1.create_namespaced_cron_job(
|
||||
namespace=os.environ["AKS_NAMESPACE"], body=manifest
|
||||
)
|
||||
except Exception as e:
|
||||
print("Failed to create graphrag cronjob.")
|
||||
reporter = ReporterSingleton().get_instance()
|
||||
|
||||
@ -10,6 +10,7 @@ from azure.cosmos import (
|
||||
)
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.storage.blob.aio import BlobServiceClient as BlobServiceClientAsync
|
||||
|
||||
from src.api.azure_clients import AzureStorageClientManager
|
||||
|
||||
|
||||
|
||||
@ -11,11 +11,11 @@ ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
|
||||
ENV PYTHONPATH=/backend
|
||||
|
||||
COPY poetry.lock pyproject.toml /
|
||||
COPY backend /backend
|
||||
RUN pip install poetry \
|
||||
RUN cd backend \
|
||||
&& pip install poetry \
|
||||
&& poetry config virtualenvs.create false \
|
||||
&& poetry install --without frontend
|
||||
&& poetry install
|
||||
|
||||
# download all nltk data that graphrag requires
|
||||
RUN python -m nltk.downloader punkt averaged_perceptron_tagger maxent_ne_chunker words wordnet
|
||||
|
||||
@ -7,11 +7,11 @@ ENV PIP_ROOT_USER_ACTION=ignore
|
||||
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
|
||||
|
||||
COPY poetry.lock pyproject.toml /
|
||||
COPY frontend /frontend
|
||||
RUN pip install poetry \
|
||||
RUN cd frontend \
|
||||
&& pip install poetry \
|
||||
&& poetry config virtualenvs.create false \
|
||||
&& poetry install --without backend
|
||||
&& poetry install
|
||||
|
||||
WORKDIR /frontend
|
||||
EXPOSE 8080
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import os
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from src.components import tabs
|
||||
from src.components.index_pipeline import IndexPipeline
|
||||
from src.enums import EnvVars
|
||||
@ -16,29 +17,28 @@ st.session_state["initialized"] = True if initialized else False
|
||||
|
||||
|
||||
def graphrag_app(initialized: bool):
|
||||
# main entry point for app interface
|
||||
st.title("Microsoft GraphRAG Copilot")
|
||||
main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs(
|
||||
[
|
||||
"**Intro**",
|
||||
"**1. Prompt Generation**",
|
||||
"**2. Prompt Configuration**",
|
||||
"**3. Index**",
|
||||
"**4. Query**",
|
||||
]
|
||||
)
|
||||
main_tab, prompt_gen_tab, prompt_edit_tab, index_tab, query_tab = st.tabs([
|
||||
"**Intro**",
|
||||
"**1. Prompt Generation**",
|
||||
"**2. Prompt Configuration**",
|
||||
"**3. Index**",
|
||||
"**4. Query**",
|
||||
])
|
||||
# display only the main tab if a connection to an existing APIM has not been initialized
|
||||
with main_tab:
|
||||
tabs.get_main_tab(initialized)
|
||||
|
||||
# if not initialized, only main tab is displayed
|
||||
if initialized:
|
||||
# assign API request information
|
||||
# setup API request information
|
||||
COLUMN_WIDTHS = [0.275, 0.45, 0.275]
|
||||
api_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
|
||||
apim_url = st.session_state[EnvVars.DEPLOYMENT_URL.value]
|
||||
apim_key = st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value]
|
||||
client = GraphragAPI(api_url, apim_key)
|
||||
# perform health check to verify connectivity
|
||||
client = GraphragAPI(apim_url, apim_key)
|
||||
if not client.health_check_passed():
|
||||
st.error("APIM Connection Error")
|
||||
st.stop()
|
||||
indexPipe = IndexPipeline(client, COLUMN_WIDTHS)
|
||||
|
||||
# display tabs
|
||||
with prompt_gen_tab:
|
||||
tabs.get_prompt_generation_tab(client, COLUMN_WIDTHS)
|
||||
@ -48,9 +48,7 @@ def graphrag_app(initialized: bool):
|
||||
tabs.get_index_tab(indexPipe)
|
||||
with query_tab:
|
||||
tabs.get_query_tab(client)
|
||||
|
||||
deployer_email = os.getenv("DEPLOYER_EMAIL", "deployer@email.com")
|
||||
|
||||
footer = f"""
|
||||
<div class="footer">
|
||||
<p> Responses may be inaccurate; please review all responses for accuracy. Learn more about Azure OpenAI code of conduct <a href="https://learn.microsoft.com/en-us/legal/cognitive-services/openai/code-of-conduct"> here</a>. </br> For feedback, email us at <a href="mailto:{deployer_email}">{deployer_email}</a>.</p>
|
||||
|
||||
3109
frontend/poetry.lock
generated
Normal file
3109
frontend/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
55
frontend/pyproject.toml
Normal file
55
frontend/pyproject.toml
Normal file
@ -0,0 +1,55 @@
|
||||
[tool.poetry]
|
||||
name = "graphrag-solution-accelerator"
|
||||
version = "0.1.1"
|
||||
description = ""
|
||||
authors = [
|
||||
"Josh Bradley <joshbradley@microsoft.com>",
|
||||
"Newman Cheng <newmancheng@microsoft.com>",
|
||||
"Christine DiFonzo <cdifonzo@microsoft.com>",
|
||||
"Gabriel Nieves <gnievesponce@microsoft.com>",
|
||||
"Douglas Orbaker <dorbaker@microsoft.com>",
|
||||
"Chris Sanchez <chrissanchez@microsoft.com>",
|
||||
"Shane Solomon <shane.solomon@microsoft.com>",
|
||||
]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
package-mode = false
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "~3.10"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
detect-secrets = ">=1.5.0"
|
||||
devtools = ">=0.12.2"
|
||||
flake8 = ">=6.1.0"
|
||||
ipython = "*"
|
||||
jupyter = "*"
|
||||
pre-commit = ">=3.6.0"
|
||||
ruff = ">=0.1.13"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = ">=8.2.1"
|
||||
wikipedia = ">=1.4.0"
|
||||
|
||||
[tool.poetry.group.frontend.dependencies]
|
||||
requests = "*"
|
||||
streamlit = ">=1.38.0"
|
||||
streamlit-nested-layout = "*"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 88
|
||||
indent-width = 4
|
||||
|
||||
[tool.ruff.format]
|
||||
preview = true
|
||||
quote-style = "double"
|
||||
|
||||
[tool.ruff.lint]
|
||||
preview = true
|
||||
select = ["E", "F", "I"]
|
||||
ignore = ["E402", "E501", "F821"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@ -50,7 +50,6 @@ class IndexPipeline:
|
||||
key_prefix="index",
|
||||
disable_other_input=disable_other_input,
|
||||
)
|
||||
|
||||
if select_storage_name != "":
|
||||
disable_other_input = True
|
||||
|
||||
@ -138,7 +137,6 @@ class IndexPipeline:
|
||||
divider=True,
|
||||
help="Select an index to check the status of what stage indexing is in. Indexing must be complete in order to be able to execute queries.",
|
||||
)
|
||||
|
||||
options_indexes = self.client.get_index_names()
|
||||
# create logic for defaulting to running job index if one exists
|
||||
new_index_name = st.session_state["index-name-input"]
|
||||
|
||||
@ -26,8 +26,7 @@ def login():
|
||||
form_submit = st.form_submit_button("Login")
|
||||
if form_submit:
|
||||
client = GraphragAPI(apim_url, apim_sub_key)
|
||||
status_code = client.health_check()
|
||||
if status_code == 200:
|
||||
if client.health_check_passed():
|
||||
st.success("Login Successful")
|
||||
st.session_state[EnvVars.DEPLOYMENT_URL.value] = apim_url
|
||||
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = apim_sub_key
|
||||
@ -36,4 +35,3 @@ def login():
|
||||
else:
|
||||
st.error("Login Failed")
|
||||
st.error("Please check the APIM Gateway URL and Subscription Key")
|
||||
return status_code
|
||||
|
||||
@ -20,7 +20,7 @@ class GraphQuery:
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_type: Literal["Global Streaming", "Global", "Local"],
|
||||
query_type: Literal["Global Streaming", "Local Streaming", "Global", "Local"],
|
||||
search_index: str | list[str],
|
||||
query: str,
|
||||
) -> None:
|
||||
@ -51,6 +51,8 @@ class GraphQuery:
|
||||
match query_type:
|
||||
case "Global Streaming":
|
||||
_ = self.global_streaming_search(search_index, query)
|
||||
case "Local Streaming":
|
||||
_ = self.local_streaming_search(search_index, query)
|
||||
case "Global":
|
||||
_ = self.global_search(search_index, query)
|
||||
case "Local":
|
||||
@ -109,7 +111,73 @@ class GraphQuery:
|
||||
"Double-click on content to expand text", "red", False
|
||||
)
|
||||
)
|
||||
self._build_st_dataframe(context_list)
|
||||
self._build_st_dataframe(
|
||||
context_list[0]["reports"], drop_columns=[]
|
||||
)
|
||||
else:
|
||||
print(query_response.reason, query_response.content)
|
||||
raise Exception("Received unexpected response from server")
|
||||
|
||||
def local_streaming_search(self, search_index: str | list[str], query: str) -> None:
|
||||
"""
|
||||
Executes a local streaming query on the specified index.
|
||||
Handles the response and displays the results in the Streamlit app.
|
||||
"""
|
||||
query_response = self.client.local_streaming_query(search_index, query)
|
||||
assistant_response = ""
|
||||
context_list = []
|
||||
|
||||
if query_response.status_code == 200:
|
||||
text_placeholder = st.empty()
|
||||
for chunk in query_response.iter_lines(
|
||||
# allow up to 256KB to avoid excessive many reads
|
||||
chunk_size=256 * GraphQuery.KILOBYTE,
|
||||
decode_unicode=True,
|
||||
):
|
||||
try:
|
||||
payload = json.loads(chunk)
|
||||
except json.JSONDecodeError as e:
|
||||
# In the event that a chunk is not a complete JSON object,
|
||||
# document it for further analysis.
|
||||
print(chunk)
|
||||
raise e
|
||||
|
||||
token = payload["token"]
|
||||
context = payload["context"]
|
||||
if (token != "<EOM>") and (context is None):
|
||||
assistant_response += token
|
||||
text_placeholder.write(assistant_response)
|
||||
elif (token == "<EOM>") and (context is not None):
|
||||
context_list.append(context)
|
||||
|
||||
if not assistant_response:
|
||||
st.write(
|
||||
self.format_md_text(
|
||||
"Not enough contextual data to support your query: No results found.\tTry another query.",
|
||||
"red",
|
||||
True,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
with self._create_section_expander("Query Context"):
|
||||
st.write(
|
||||
self.format_md_text(
|
||||
"Double-click on content to expand text", "red", False
|
||||
)
|
||||
)
|
||||
self._build_st_dataframe(
|
||||
context_list[0]["reports"], drop_columns=[]
|
||||
)
|
||||
self._build_st_dataframe(
|
||||
context_list[0]["entities"], drop_columns=[]
|
||||
)
|
||||
self._build_st_dataframe(
|
||||
context_list[0]["relationships"], drop_columns=[]
|
||||
)
|
||||
self._build_st_dataframe(
|
||||
context_list[0]["sources"], drop_columns=[]
|
||||
)
|
||||
else:
|
||||
print(query_response.reason, query_response.content)
|
||||
raise Exception("Received unexpected response from server")
|
||||
@ -216,14 +284,12 @@ class GraphQuery:
|
||||
drop_columns: list[str] = ["id", "index_id", "index_name", "in_context"],
|
||||
entity_df: bool = False,
|
||||
rel_df: bool = False,
|
||||
) -> st.dataframe:
|
||||
) -> st.dataframe: # type: ignore
|
||||
df_context = (
|
||||
data if isinstance(data, pd.DataFrame) else pd.DataFrame.from_records(data)
|
||||
)
|
||||
if any(drop_columns):
|
||||
for column in drop_columns:
|
||||
if column in df_context.columns:
|
||||
df_context = df_context.drop(column, axis=1)
|
||||
df_context.drop(columns=drop_columns, inplace=True, axis=1, errors="ignore")
|
||||
if entity_df:
|
||||
return st.dataframe(
|
||||
df_context,
|
||||
@ -267,7 +333,7 @@ class GraphQuery:
|
||||
|
||||
def _create_section_expander(
|
||||
self, title: str, color: str = "blue", bold: bool = True, expanded: bool = False
|
||||
) -> st.expander:
|
||||
) -> st.expander: # type: ignore
|
||||
"""
|
||||
Creates an expander in the Streamlit app with the specified title and content.
|
||||
"""
|
||||
|
||||
@ -63,8 +63,6 @@ def get_prompt_generation_tab(
|
||||
"""
|
||||
Displays content of Prompt Generation Tab
|
||||
"""
|
||||
# hard set limit to 5 files to reduce overly long processing times and to reduce over sampling errors.
|
||||
num_chunks = num_chunks if num_chunks <= 5 else 5
|
||||
_, col2, _ = st.columns(column_widths)
|
||||
with col2:
|
||||
st.header(
|
||||
@ -133,24 +131,10 @@ def get_prompt_generation_tab(
|
||||
"Prompts generated successfully! Move on to the next tab to configure the prompts."
|
||||
)
|
||||
else:
|
||||
# assume limit parameter is too high
|
||||
# limit parameter was too high
|
||||
st.warning(
|
||||
"You do not have enough data to generate prompts. Retrying with a smaller sample size."
|
||||
)
|
||||
while num_chunks > 1:
|
||||
num_chunks -= 1
|
||||
generated = generate_and_extract_prompts(
|
||||
client=client,
|
||||
storage_name=select_prompt_storage,
|
||||
limit=num_chunks,
|
||||
)
|
||||
if not isinstance(generated, Exception):
|
||||
st.success(
|
||||
"Prompts generated successfully! Move on to the next tab to configure the prompts."
|
||||
)
|
||||
break
|
||||
else:
|
||||
st.warning(f"Retrying with sample size: {num_chunks}")
|
||||
|
||||
|
||||
def get_prompt_configuration_tab(
|
||||
@ -237,7 +221,7 @@ def get_query_tab(client: GraphragAPI) -> None:
|
||||
with col1:
|
||||
query_type = st.selectbox(
|
||||
"Query Type",
|
||||
["Global Streaming", "Global", "Local"],
|
||||
["Global Streaming", "Local Streaming", "Global", "Local"],
|
||||
help="Select the query type - Each yeilds different results of specificity. Global queries focus on the entire graph structure. Local queries focus on a set of communities (subgraphs) in the graph that are more connected to each other than they are to the rest of the graph structure and can focus on very specific entities in the graph. Global streaming is a global query that displays results as they appear live.",
|
||||
)
|
||||
with col2:
|
||||
@ -253,7 +237,6 @@ def get_query_tab(client: GraphragAPI) -> None:
|
||||
|
||||
disabled = True if not any(select_index_search) else False
|
||||
col3, col4 = st.columns([0.8, 0.2])
|
||||
|
||||
with col3:
|
||||
search_bar = st.text_input("Query", key="search-query", disabled=disabled)
|
||||
with col4:
|
||||
|
||||
@ -4,30 +4,30 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PromptKeys(Enum):
|
||||
class PromptKeys(str, Enum):
|
||||
ENTITY = "entity_extraction"
|
||||
SUMMARY = "summarize_descriptions"
|
||||
COMMUNITY = "community_report"
|
||||
|
||||
|
||||
class PromptFileNames(Enum):
|
||||
class PromptFileNames(str, Enum):
|
||||
ENTITY = "entity_extraction_prompt.txt"
|
||||
SUMMARY = "summarize_descriptions_prompt.txt"
|
||||
COMMUNITY = "community_report_prompt.txt"
|
||||
|
||||
|
||||
class PromptTextAreas(Enum):
|
||||
class PromptTextAreas(str, Enum):
|
||||
ENTITY = "entity_text_area"
|
||||
SUMMARY = "summary_text_area"
|
||||
COMMUNITY = "community_text_area"
|
||||
|
||||
|
||||
class StorageIndexVars(Enum):
|
||||
class StorageIndexVars(str, Enum):
|
||||
SELECTED_STORAGE = "selected_storage"
|
||||
INPUT_STORAGE = "input_storage"
|
||||
SELECTED_INDEX = "selected_index"
|
||||
|
||||
|
||||
class EnvVars(Enum):
|
||||
class EnvVars(str, Enum):
|
||||
APIM_SUBSCRIPTION_KEY = "APIM_SUBSCRIPTION_KEY"
|
||||
DEPLOYMENT_URL = "DEPLOYMENT_URL"
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
import streamlit as st
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
from src.enums import EnvVars, PromptKeys, StorageIndexVars
|
||||
from src.graphrag_api import GraphragAPI
|
||||
@ -17,7 +16,7 @@ This module contains functions that are used across the Streamlit app.
|
||||
"""
|
||||
|
||||
|
||||
def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
|
||||
def initialize_app(css_file: str = "style.css") -> bool:
|
||||
"""
|
||||
Initialize the Streamlit app with the necessary configurations.
|
||||
"""
|
||||
@ -31,10 +30,7 @@ def initialize_app(env_file: str = ".env", css_file: str = "style.css") -> bool:
|
||||
# initialize session state variables
|
||||
set_session_state_variables()
|
||||
|
||||
# load environment variables
|
||||
_ = load_dotenv(find_dotenv(filename=env_file) or None, override=True)
|
||||
|
||||
# either load from .env file or from session state
|
||||
# load settings from environment variables
|
||||
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value] = os.getenv(
|
||||
EnvVars.APIM_SUBSCRIPTION_KEY.value,
|
||||
st.session_state[EnvVars.APIM_SUBSCRIPTION_KEY.value],
|
||||
|
||||
@ -131,17 +131,16 @@ class GraphragAPI:
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
|
||||
def health_check(self) -> int | Response:
|
||||
def health_check_passed(self) -> bool:
|
||||
"""
|
||||
Check the health of the APIM endpoint.
|
||||
"""
|
||||
url = self.api_url + "/health"
|
||||
try:
|
||||
response = requests.get(url, headers=self.headers)
|
||||
return response.status_code
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
return e
|
||||
return response.ok
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def query_index(self, index_name: str | list[str], query_type: str, query: str):
|
||||
"""
|
||||
@ -174,7 +173,25 @@ class GraphragAPI:
|
||||
"""
|
||||
Returns a streaming response object for a global query.
|
||||
"""
|
||||
url = f"{self.api_url}/experimental/query/global/streaming"
|
||||
url = f"{self.api_url}/query/streaming/global"
|
||||
try:
|
||||
query_response = requests.post(
|
||||
url,
|
||||
json={"index_name": index_name, "query": query},
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
)
|
||||
return query_response
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
|
||||
def local_streaming_query(
|
||||
self, index_name: str | list[str], query: str
|
||||
) -> Response | None:
|
||||
"""
|
||||
Returns a streaming response object for a global query.
|
||||
"""
|
||||
url = f"{self.api_url}/query/streaming/local"
|
||||
try:
|
||||
query_response = requests.post(
|
||||
url,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user