Mirror of https://github.com/Azure-Samples/graphrag-accelerator.git (synced 2025-10-15 10:48:39 +00:00)

Commit a8bf6733df (parent e85c9c006e): refactor and reorganize indexing code out of api code
@@ -20,19 +20,19 @@ spec:
      serviceAccountName: PLACEHOLDER
      restartPolicy: OnFailure
      containers:
        - name: index-job-manager
          image: PLACEHOLDER
          imagePullPolicy: Always
          resources:
            requests:
              cpu: "0.5"
              memory: "0.5Gi"
            limits:
              cpu: "1"
              memory: "1Gi"
          envFrom:
            - configMapRef:
                name: graphrag
          command:
            - python
            - "manage-indexing-jobs.py"
@@ -21,17 +21,17 @@ spec:
      nodeSelector:
        workload: graphrag-indexing
      containers:
        - name: graphrag
          image: PLACEHOLDER
          imagePullPolicy: Always
          resources:
            requests:
              cpu: "5"
              memory: "36Gi"
            limits:
              cpu: "8"
              memory: "64Gi"
          envFrom:
            - configMapRef:
                name: graphrag
          command: [PLACEHOLDER]
@@ -18,9 +18,9 @@ from kubernetes import (
)

from src.api.azure_clients import AzureClientManager
from src.api.common import sanitize_name
from src.logger.logger_singleton import LoggerSingleton
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name
from src.utils.pipeline import PipelineJob


@@ -48,7 +48,7 @@ def schedule_indexing_job(index_name: str):
)
except Exception:
reporter = LoggerSingleton().get_instance()
reporter.on_error(
reporter.error(
"Index job manager encountered error scheduling indexing job",
)
# In the event of a catastrophic scheduling failure, something in k8s or the job manifest is likely broken.
@@ -68,14 +68,14 @@ def _generate_aks_job_manifest(
The manifest must be valid YAML with certain values replaced by the provided arguments.
"""
# NOTE: this file location is relative to the WORKDIR set in Dockerfile-backend
with open("indexing-job-template.yaml", "r") as f:
with open("index-job.yaml", "r") as f:
manifest = yaml.safe_load(f)
manifest["metadata"]["name"] = f"indexing-job-{sanitize_name(index_name)}"
manifest["spec"]["template"]["spec"]["serviceAccountName"] = service_account_name
manifest["spec"]["template"]["spec"]["containers"][0]["image"] = docker_image_name
manifest["spec"]["template"]["spec"]["containers"][0]["command"] = [
"python",
"run-indexing-job.py",
"src/indexer/indexer.py",
f"-i={index_name}",
]
return manifest
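For orientation, the hunk above touches the manifest-templating helper that schedules each indexing job on AKS. Below is a minimal sketch of how those fragments fit together after this change; the function name, signature, and argument names are assumptions for illustration, while the template filename, the overwritten manifest fields, and the new container command come from the diff itself.

import yaml  # PyYAML, already used by manage-indexing-jobs.py


def generate_aks_job_manifest(
    docker_image_name: str,
    index_name: str,
    service_account_name: str,
) -> dict:
    """Load the AKS job template and fill in the per-index values (sketch)."""
    # NOTE: the template path is relative to the container WORKDIR, as the diff notes.
    with open("index-job.yaml", "r") as f:
        manifest = yaml.safe_load(f)
    # the repository code additionally passes index_name through sanitize_name here
    manifest["metadata"]["name"] = f"indexing-job-{index_name}"
    pod_spec = manifest["spec"]["template"]["spec"]
    pod_spec["serviceAccountName"] = service_account_name
    pod_spec["containers"][0]["image"] = docker_image_name
    # the job now launches the relocated indexer module instead of run-indexing-job.py
    pod_spec["containers"][0]["command"] = [
        "python",
        "src/indexer/indexer.py",
        f"-i={index_name}",
    ]
    return manifest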
backend/poetry.lock (generated, 321 lines changed)
@ -215,17 +215,6 @@ files = [
|
||||
[package.dependencies]
|
||||
six = "*"
|
||||
|
||||
[[package]]
|
||||
name = "applicationinsights"
|
||||
version = "0.11.10"
|
||||
description = "This project extends the Application Insights API surface to support Python."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "applicationinsights-0.11.10-py2.py3-none-any.whl", hash = "sha256:e89a890db1c6906b6a7d0bcfd617dac83974773c64573147c8d6654f9cf2a6ea"},
|
||||
{file = "applicationinsights-0.11.10.tar.gz", hash = "sha256:0b761f3ef0680acf4731906dfc1807faa6f2a57168ae74592db0084a6099f7b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "appnope"
|
||||
version = "0.1.4"
|
||||
@ -313,6 +302,23 @@ types-python-dateutil = ">=2.8.10"
|
||||
doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"]
|
||||
test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "asgiref"
|
||||
version = "3.8.1"
|
||||
description = "ASGI specs, helper code, and adapters"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"},
|
||||
{file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
|
||||
|
||||
[[package]]
|
||||
name = "asttokens"
|
||||
version = "2.4.1"
|
||||
@ -422,6 +428,21 @@ typing-extensions = ">=4.6.0"
|
||||
[package.extras]
|
||||
aio = ["aiohttp (>=3.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "azure-core-tracing-opentelemetry"
|
||||
version = "1.0.0b11"
|
||||
description = "Microsoft Azure Azure Core OpenTelemetry plugin Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "azure-core-tracing-opentelemetry-1.0.0b11.tar.gz", hash = "sha256:a230d1555838b5d07b7594221cd639ea7bc24e29c881e5675e311c6067bad4f5"},
|
||||
{file = "azure_core_tracing_opentelemetry-1.0.0b11-py3-none-any.whl", hash = "sha256:016cefcaff2900fb5cdb7a8a7abd03e9c266622c06e26b3fe6dafa54c4b48bf5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
azure-core = ">=1.24.0,<2.0.0"
|
||||
opentelemetry-api = ">=1.12.0,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "azure-cosmos"
|
||||
version = "4.9.0"
|
||||
@ -471,6 +492,31 @@ msal = ">=1.30.0"
|
||||
msal-extensions = ">=1.2.0"
|
||||
typing-extensions = ">=4.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "azure-monitor-opentelemetry"
|
||||
version = "1.6.4"
|
||||
description = "Microsoft Azure Monitor Opentelemetry Distro Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "azure_monitor_opentelemetry-1.6.4-py3-none-any.whl", hash = "sha256:014142ffa420bc2b287ff3bd30de6c31d64b2846423d011a8280334d7afcb01a"},
|
||||
{file = "azure_monitor_opentelemetry-1.6.4.tar.gz", hash = "sha256:9f5ce4c666caf1f9b536f8ab4ee207dff94777d568517c74f26e3327f75c3fc3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
azure-core = ">=1.28.0,<2.0.0"
|
||||
azure-core-tracing-opentelemetry = ">=1.0.0b11,<1.1.0"
|
||||
azure-monitor-opentelemetry-exporter = ">=1.0.0b31,<1.1.0"
|
||||
opentelemetry-instrumentation-django = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-fastapi = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-flask = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-psycopg2 = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-requests = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-urllib = ">=0.49b0,<1.0"
|
||||
opentelemetry-instrumentation-urllib3 = ">=0.49b0,<1.0"
|
||||
opentelemetry-resource-detector-azure = ">=0.1.4,<0.2.0"
|
||||
opentelemetry-sdk = ">=1.28,<2.0"
|
||||
|
||||
[[package]]
|
||||
name = "azure-monitor-opentelemetry-exporter"
|
||||
version = "1.0.0b33"
|
||||
@ -3519,13 +3565,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.59.7"
|
||||
version = "1.59.8"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "openai-1.59.7-py3-none-any.whl", hash = "sha256:cfa806556226fa96df7380ab2e29814181d56fea44738c2b0e581b462c268692"},
|
||||
{file = "openai-1.59.7.tar.gz", hash = "sha256:043603def78c00befb857df9f0a16ee76a3af5984ba40cb7ee5e2f40db4646bf"},
|
||||
{file = "openai-1.59.8-py3-none-any.whl", hash = "sha256:a8b8ee35c4083b88e6da45406d883cf6bd91a98ab7dd79178b8bc24c8bfb09d9"},
|
||||
{file = "openai-1.59.8.tar.gz", hash = "sha256:ac4bda5fa9819fdc6127e8ea8a63501f425c587244bc653c7c11a8ad84f953e1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -3557,6 +3603,234 @@ files = [
|
||||
deprecated = ">=1.2.6"
|
||||
importlib-metadata = ">=6.0,<=8.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation"
|
||||
version = "0.50b0"
|
||||
description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation-0.50b0-py3-none-any.whl", hash = "sha256:b8f9fc8812de36e1c6dffa5bfc6224df258841fb387b6dfe5df15099daa10630"},
|
||||
{file = "opentelemetry_instrumentation-0.50b0.tar.gz", hash = "sha256:7d98af72de8dec5323e5202e46122e5f908592b22c6d24733aad619f07d82979"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.4,<2.0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
packaging = ">=18.0"
|
||||
wrapt = ">=1.0.0,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-asgi"
|
||||
version = "0.50b0"
|
||||
description = "ASGI instrumentation for OpenTelemetry"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_asgi-0.50b0-py3-none-any.whl", hash = "sha256:2ba1297f746e55dec5a17fe825689da0613662fb25c004c3965a6c54b1d5be22"},
|
||||
{file = "opentelemetry_instrumentation_asgi-0.50b0.tar.gz", hash = "sha256:3ca4cb5616ae6a3e8ce86e7d5c360a8d8cc8ed722cf3dc8a5e44300774e87d49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
asgiref = ">=3.0,<4.0"
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["asgiref (>=3.0,<4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-dbapi"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry Database API instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_dbapi-0.50b0-py3-none-any.whl", hash = "sha256:23a730c3d7372b04b8a9507d2a67c5efbf92ff718eaa002b81ffbaf2b01d270f"},
|
||||
{file = "opentelemetry_instrumentation_dbapi-0.50b0.tar.gz", hash = "sha256:2603ca39e216893026c185ca8c44c326c0a9a763d5afff2309bd6195c50b7c49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
wrapt = ">=1.0.0,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-django"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry Instrumentation for Django"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_django-0.50b0-py3-none-any.whl", hash = "sha256:ab7b4cd52b8f12420d968823f6bbfbc2a6ddb2af7a05fcb0d5b6755d338f1915"},
|
||||
{file = "opentelemetry_instrumentation_django-0.50b0.tar.gz", hash = "sha256:624fd0beb1ac827f2af31709c2da5cb55d8dc899c2449d6e8fcc9fa5538fd56b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-instrumentation-wsgi = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[package.extras]
|
||||
asgi = ["opentelemetry-instrumentation-asgi (==0.50b0)"]
|
||||
instruments = ["django (>=1.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-fastapi"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry FastAPI Instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_fastapi-0.50b0-py3-none-any.whl", hash = "sha256:8f03b738495e4705fbae51a2826389c7369629dace89d0f291c06ffefdff5e52"},
|
||||
{file = "opentelemetry_instrumentation_fastapi-0.50b0.tar.gz", hash = "sha256:16b9181682136da210295def2bb304a32fb9bdee9a935cdc9da43567f7c1149e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-instrumentation-asgi = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["fastapi (>=0.58,<1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-flask"
|
||||
version = "0.50b0"
|
||||
description = "Flask instrumentation for OpenTelemetry"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_flask-0.50b0-py3-none-any.whl", hash = "sha256:db7fb40191145f4356a793922c3fc80a33689e6a7c7c4c6def8aa1eedb0ac42a"},
|
||||
{file = "opentelemetry_instrumentation_flask-0.50b0.tar.gz", hash = "sha256:e56a820b1d43fdd5a57f7b481c4d6365210a48a1312c83af4185bc636977755f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-instrumentation-wsgi = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
packaging = ">=21.0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["flask (>=1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-psycopg2"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry psycopg2 instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_psycopg2-0.50b0-py3-none-any.whl", hash = "sha256:448297e63320711b5571f64bcf5d67ecf4856454c36d3bff6c3d01a4f8a48d18"},
|
||||
{file = "opentelemetry_instrumentation_psycopg2-0.50b0.tar.gz", hash = "sha256:86f8e507e98d8824f51bbc3c62121dbd4b8286063362f10b9dfa035a8da49f0b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-instrumentation-dbapi = "0.50b0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["psycopg2 (>=2.7.3.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-requests"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry requests instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_requests-0.50b0-py3-none-any.whl", hash = "sha256:2c60a890988d6765de9230004d0af9071b3b2e1ddba4ca3b631cfb8a1722208d"},
|
||||
{file = "opentelemetry_instrumentation_requests-0.50b0.tar.gz", hash = "sha256:f8088c76f757985b492aad33331d21aec2f99c197472a57091c2e986a4b7ec8b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["requests (>=2.0,<3.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-urllib"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry urllib instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_urllib-0.50b0-py3-none-any.whl", hash = "sha256:55024940fd41fbdd5a6ab5b6397660900b7a75e23f9ff7f61b4ae1279710a3ec"},
|
||||
{file = "opentelemetry_instrumentation_urllib-0.50b0.tar.gz", hash = "sha256:af3e9710635c3f8a5ec38adc772dfef0c1022d0196007baf4b74504e920b5d31"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-urllib3"
|
||||
version = "0.50b0"
|
||||
description = "OpenTelemetry urllib3 instrumentation"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_urllib3-0.50b0-py3-none-any.whl", hash = "sha256:c679b3908645b7d4d07c36960fe0efef490b403983e314108450146cc89bd675"},
|
||||
{file = "opentelemetry_instrumentation_urllib3-0.50b0.tar.gz", hash = "sha256:2c4a1d9f128eaf753871b1d90659c744691d039a6601ba546081347ae192bd0e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
wrapt = ">=1.0.0,<2.0.0"
|
||||
|
||||
[package.extras]
|
||||
instruments = ["urllib3 (>=1.0.0,<3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-wsgi"
|
||||
version = "0.50b0"
|
||||
description = "WSGI Middleware for OpenTelemetry"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_instrumentation_wsgi-0.50b0-py3-none-any.whl", hash = "sha256:4bc0fdf52b603507d6170a25504f0ceea358d7e90a2c0e8794b7b7eca5ea355c"},
|
||||
{file = "opentelemetry_instrumentation_wsgi-0.50b0.tar.gz", hash = "sha256:c25b5f1b664d984a41546a34cf2f893dcde6cf56922f88c475864e7df37edf4a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-api = ">=1.12,<2.0"
|
||||
opentelemetry-instrumentation = "0.50b0"
|
||||
opentelemetry-semantic-conventions = "0.50b0"
|
||||
opentelemetry-util-http = "0.50b0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-resource-detector-azure"
|
||||
version = "0.1.5"
|
||||
description = "Azure Resource Detector for OpenTelemetry"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_resource_detector_azure-0.1.5-py3-none-any.whl", hash = "sha256:4dcc5d54ab5c3b11226af39509bc98979a8b9e0f8a24c1b888783755d3bf00eb"},
|
||||
{file = "opentelemetry_resource_detector_azure-0.1.5.tar.gz", hash = "sha256:e0ba658a87c69eebc806e75398cd0e9f68a8898ea62de99bc1b7083136403710"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
opentelemetry-sdk = ">=1.21,<2.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-sdk"
|
||||
version = "1.29.0"
|
||||
@ -3588,6 +3862,17 @@ files = [
|
||||
deprecated = ">=1.2.6"
|
||||
opentelemetry-api = "1.29.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-util-http"
|
||||
version = "0.50b0"
|
||||
description = "Web util for OpenTelemetry"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "opentelemetry_util_http-0.50b0-py3-none-any.whl", hash = "sha256:21f8aedac861ffa3b850f8c0a6c373026189eb8630ac6e14a2bf8c55695cc090"},
|
||||
{file = "opentelemetry_util_http-0.50b0.tar.gz", hash = "sha256:dc4606027e1bc02aabb9533cc330dd43f874fca492e4175c31d7154f341754af"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "overrides"
|
||||
version = "7.7.0"
|
||||
@ -5874,13 +6159,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.29.0"
|
||||
version = "20.29.1"
|
||||
description = "Virtual Python Environment builder"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "virtualenv-20.29.0-py3-none-any.whl", hash = "sha256:c12311863497992dc4b8644f8ea82d3b35bb7ef8ee82e6630d76d0197c39baf9"},
|
||||
{file = "virtualenv-20.29.0.tar.gz", hash = "sha256:6345e1ff19d4b1296954cee076baaf58ff2a12a84a338c62b02eda39f20aa982"},
|
||||
{file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
|
||||
{file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -6172,4 +6457,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "~3.10"
|
||||
content-hash = "f7039aa8ec1bb1d1f436703fcdb21f83711fd77c72ead5aa94f3d5fccd47a699"
|
||||
content-hash = "a7740ff5c5d60171f025a676b625678331eed3ac55c763c8057adb49de56f734"
|
||||
|
@@ -40,12 +40,11 @@ wikipedia = ">=1.4.0"

[tool.poetry.group.backend.dependencies]
adlfs = ">=2024.7.0"
applicationinsights = ">=0.11.10"
attrs = ">=23.2.0"
azure-core = ">=1.30.1"
azure-cosmos = ">=4.5.1"
azure-identity = ">=1.15.0"
azure-monitor-opentelemetry-exporter = "*"
azure-monitor-opentelemetry = "^1.6.4"
azure-search-documents = ">=11.4.0"
azure-storage-blob = ">=12.19.0"
environs = ">=9.5.0"
@@ -54,12 +53,10 @@ fastapi-offline = ">=1.7.3"
fastparquet = ">=2023.10.1"
fsspec = ">=2024.2.0"
graphrag = { git = "https://github.com/microsoft/graphrag.git", branch = "main" }
graspologic = ">=3.3.0"
httpx = ">=0.25.2"
kubernetes = ">=29.0.0"
networkx = ">=3.2.1"
nltk = "*"
opentelemetry-sdk = ">=1.27.0"
pandas = ">=2.2.1"
pyaml-env = ">=1.2.1"
pyarrow = ">=15.0.0"
@@ -6,6 +6,6 @@ asyncio_mode=auto
; If executing these pytests locally, users may need to modify the cosmosdb connection string to use http protocol instead of https.
; This depends on how the cosmosdb emulator has been configured (by the user) to run.
env =
COSMOS_CONNECTION_STRING=AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==
COSMOS_CONNECTION_STRING=AccountEndpoint=http://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==
STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;TableEndpoint=http://127.0.0.1:10002/devstoreaccount1;
TESTING=1
@@ -1,18 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import asyncio

from src import main  # noqa: F401
from src.api.index import _start_indexing_pipeline

parser = argparse.ArgumentParser(description="Kickoff indexing job.")
parser.add_argument("-i", "--index-name", required=True)
args = parser.parse_args()

asyncio.run(
    _start_indexing_pipeline(
        index_name=args.index_name,
    )
)
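The deleted file above is the old kickoff script (invoked as "python run-indexing-job.py -i=<index>" in the job-manager diff earlier), which imported _start_indexing_pipeline from the API package. Its replacement is the new backend/src/indexer/indexer.py added later in this diff, whose start_indexing_job wraps the async graphrag build with asyncio.run internally. A hedged sketch of the equivalent entry point under the new layout; the __main__ block itself is an assumption for illustration, not a copy of the repository file:

import argparse

from src.indexer.indexer import start_indexing_job  # new home of the indexing logic

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Kick off an indexing job.")
    parser.add_argument("-i", "--index-name", required=True)
    args = parser.parse_args()
    # start_indexing_job is synchronous; it runs the graphrag build to completion
    start_indexing_job(index_name=args.index_name)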
@@ -14,17 +14,17 @@ from fastapi import (
)

from src.api.azure_clients import AzureClientManager
from src.api.common import (
from src.logger import LoggerSingleton
from src.typing.models import (
BaseResponse,
StorageNameList,
)
from src.utils.common import (
delete_blob_container,
delete_cosmos_container_item,
sanitize_name,
validate_blob_container_name,
)
from src.logger import LoggerSingleton
from src.models import (
BaseResponse,
StorageNameList,
)

data_route = APIRouter(
prefix="/data",
@@ -53,7 +53,7 @@ async def get_all_data_storage_containers():
items.append(item["human_readable_name"])
except Exception:
reporter = LoggerSingleton().get_instance()
reporter.on_error("Error getting list of blob containers.")
reporter.error("Error getting list of blob containers.")
raise HTTPException(
status_code=500, detail="Error getting list of blob containers."
)
@@ -171,7 +171,7 @@ async def upload_files(
return BaseResponse(status="File upload successful.")
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Error uploading files.", details={"files": files})
logger.error("Error uploading files.", details={"files": files})
raise HTTPException(
status_code=500,
detail=f"Error uploading files to container '{storage_name}'.",
@@ -197,7 +197,7 @@ async def delete_files(storage_name: str):
delete_cosmos_container_item("container-store", sanitized_storage_name)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
f"Error deleting container {storage_name}.",
details={"Container": storage_name},
)
@@ -8,11 +8,11 @@ from fastapi import (
from fastapi.responses import StreamingResponse

from src.api.azure_clients import AzureClientManager
from src.api.common import (
from src.logger import LoggerSingleton
from src.utils.common import (
sanitize_name,
validate_index_file_exist,
)
from src.logger import LoggerSingleton

graph_route = APIRouter(
prefix="/graph",
@@ -44,7 +44,7 @@ async def get_graphml_file(index_name: str):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not retrieve graphml file")
logger.error("Could not retrieve graphml file")
raise HTTPException(
status_code=500,
detail=f"Could not retrieve graphml file for index '{index_name}'.",
@ -1,13 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from time import time
|
||||
|
||||
import graphrag.api as api
|
||||
import yaml
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.search.documents.indexes import SearchIndexClient
|
||||
from fastapi import (
|
||||
@ -15,9 +11,6 @@ from fastapi import (
|
||||
HTTPException,
|
||||
UploadFile,
|
||||
)
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.bootstrap import bootstrap
|
||||
from graphrag.index.create_pipeline_config import create_pipeline_config
|
||||
from kubernetes import (
|
||||
client as kubernetes_client,
|
||||
)
|
||||
@ -26,23 +19,18 @@ from kubernetes import (
|
||||
)
|
||||
|
||||
from src.api.azure_clients import AzureClientManager
|
||||
from src.api.common import (
|
||||
delete_blob_container,
|
||||
sanitize_name,
|
||||
validate_blob_container_name,
|
||||
)
|
||||
from src.logger import (
|
||||
LoggerSingleton,
|
||||
PipelineJobWorkflowCallbacks,
|
||||
Reporters,
|
||||
load_pipeline_logger,
|
||||
)
|
||||
from src.models import (
|
||||
from src.logger import LoggerSingleton
|
||||
from src.typing.models import (
|
||||
BaseResponse,
|
||||
IndexNameList,
|
||||
IndexStatusResponse,
|
||||
)
|
||||
from src.typing.pipeline import PipelineJobState
|
||||
from src.utils.common import (
|
||||
delete_blob_container,
|
||||
sanitize_name,
|
||||
validate_blob_container_name,
|
||||
)
|
||||
from src.utils.pipeline import PipelineJob
|
||||
|
||||
index_route = APIRouter(
|
||||
@ -57,7 +45,7 @@ index_route = APIRouter(
|
||||
response_model=BaseResponse,
|
||||
responses={200: {"model": BaseResponse}},
|
||||
)
|
||||
async def setup_indexing_pipeline(
|
||||
async def schedule_indexing_job(
|
||||
storage_name: str,
|
||||
index_name: str,
|
||||
entity_extraction_prompt: UploadFile | None = None,
|
||||
@ -148,173 +136,6 @@ async def setup_indexing_pipeline(
|
||||
return BaseResponse(status="Indexing job scheduled")
|
||||
|
||||
|
||||
async def _start_indexing_pipeline(index_name: str):
|
||||
# get sanitized name
|
||||
sanitized_index_name = sanitize_name(index_name)
|
||||
|
||||
# update or create new item in container-store in cosmosDB
|
||||
azure_client_manager = AzureClientManager()
|
||||
blob_service_client = azure_client_manager.get_blob_service_client()
|
||||
if not blob_service_client.get_container_client(sanitized_index_name).exists():
|
||||
blob_service_client.create_container(sanitized_index_name)
|
||||
|
||||
cosmos_container_client = azure_client_manager.get_cosmos_container_client(
|
||||
database="graphrag", container="container-store"
|
||||
)
|
||||
cosmos_container_client.upsert_item({
|
||||
"id": sanitized_index_name,
|
||||
"human_readable_name": index_name,
|
||||
"type": "index",
|
||||
})
|
||||
|
||||
logger = LoggerSingleton().get_instance()
|
||||
pipelinejob = PipelineJob()
|
||||
pipeline_job = pipelinejob.load_item(sanitized_index_name)
|
||||
sanitized_storage_name = pipeline_job.sanitized_storage_name
|
||||
storage_name = pipeline_job.human_readable_index_name
|
||||
|
||||
# download nltk dependencies
|
||||
bootstrap()
|
||||
|
||||
# load custom pipeline settings
|
||||
this_directory = os.path.dirname(
|
||||
os.path.abspath(inspect.getfile(inspect.currentframe()))
|
||||
)
|
||||
data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
|
||||
# dynamically set some values
|
||||
data["input"]["container_name"] = sanitized_storage_name
|
||||
data["storage"]["container_name"] = sanitized_index_name
|
||||
data["reporting"]["container_name"] = sanitized_index_name
|
||||
data["cache"]["container_name"] = sanitized_index_name
|
||||
if "vector_store" in data["embeddings"]:
|
||||
data["embeddings"]["vector_store"]["collection_name"] = (
|
||||
f"{sanitized_index_name}_description_embedding"
|
||||
)
|
||||
|
||||
# set prompt for entity extraction
|
||||
if pipeline_job.entity_extraction_prompt:
|
||||
fname = "entity-extraction-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.entity_extraction_prompt)
|
||||
data["entity_extraction"]["prompt"] = fname
|
||||
else:
|
||||
data.pop("entity_extraction")
|
||||
|
||||
# set prompt for summarize descriptions
|
||||
if pipeline_job.summarize_descriptions_prompt:
|
||||
fname = "summarize-descriptions-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.summarize_descriptions_prompt)
|
||||
data["summarize_descriptions"]["prompt"] = fname
|
||||
else:
|
||||
data.pop("summarize_descriptions")
|
||||
|
||||
# set prompt for community report
|
||||
if pipeline_job.community_report_prompt:
|
||||
fname = "community-report-prompt.txt"
|
||||
with open(fname, "w") as outfile:
|
||||
outfile.write(pipeline_job.community_report_prompt)
|
||||
data["community_reports"]["prompt"] = fname
|
||||
else:
|
||||
data.pop("community_reports")
|
||||
|
||||
# generate a default GraphRagConfig and override with custom settings
|
||||
parameters = create_graphrag_config(data, ".")
|
||||
|
||||
# reset pipeline job details
|
||||
pipeline_job.status = PipelineJobState.RUNNING
|
||||
pipeline_job.all_workflows = []
|
||||
pipeline_job.completed_workflows = []
|
||||
pipeline_job.failed_workflows = []
|
||||
pipeline_config = create_pipeline_config(parameters)
|
||||
for workflow in pipeline_config.workflows:
|
||||
pipeline_job.all_workflows.append(workflow.name)
|
||||
|
||||
# create new loggers/callbacks just for this job
|
||||
loggers = []
|
||||
logger_names = os.getenv("REPORTERS", Reporters.CONSOLE.name.upper()).split(",")
|
||||
for logger_name in logger_names:
|
||||
try:
|
||||
loggers.append(Reporters[logger_name.upper()])
|
||||
except KeyError:
|
||||
raise ValueError(f"Unknown logger type: {logger_name}")
|
||||
workflow_callbacks = load_pipeline_logger(
|
||||
index_name=index_name,
|
||||
num_workflow_steps=len(pipeline_job.all_workflows),
|
||||
reporting_dir=sanitized_index_name,
|
||||
reporters=loggers,
|
||||
)
|
||||
|
||||
# add pipeline job callback to monitor job progress
|
||||
pipeline_job_callback = PipelineJobWorkflowCallbacks(pipeline_job)
|
||||
|
||||
# run the pipeline
|
||||
try:
|
||||
await api.build_index(
|
||||
config=parameters,
|
||||
callbacks=[workflow_callbacks, pipeline_job_callback],
|
||||
)
|
||||
# if job is done, check if any workflow steps failed
|
||||
if len(pipeline_job.failed_workflows) > 0:
|
||||
pipeline_job.status = PipelineJobState.FAILED
|
||||
workflow_callbacks.on_log(
|
||||
message=f"Indexing pipeline encountered error for index'{index_name}'.",
|
||||
details={
|
||||
"index": index_name,
|
||||
"storage_name": storage_name,
|
||||
"status_message": "indexing pipeline encountered error",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# record the workflow completion
|
||||
pipeline_job.status = PipelineJobState.COMPLETE
|
||||
pipeline_job.percent_complete = 100
|
||||
workflow_callbacks.on_log(
|
||||
message=f"Indexing pipeline complete for index'{index_name}'.",
|
||||
details={
|
||||
"index": index_name,
|
||||
"storage_name": storage_name,
|
||||
"status_message": "indexing pipeline complete",
|
||||
},
|
||||
)
|
||||
|
||||
pipeline_job.progress = (
|
||||
f"{len(pipeline_job.completed_workflows)} out of "
|
||||
f"{len(pipeline_job.all_workflows)} workflows completed successfully."
|
||||
)
|
||||
|
||||
del workflow_callbacks # garbage collect
|
||||
if pipeline_job.status == PipelineJobState.FAILED:
|
||||
exit(1) # signal to AKS that indexing job failed
|
||||
|
||||
except Exception as e:
|
||||
pipeline_job.status = PipelineJobState.FAILED
|
||||
|
||||
# update failed state in cosmos db
|
||||
error_details = {
|
||||
"index": index_name,
|
||||
"storage_name": storage_name,
|
||||
}
|
||||
# log error in local index directory logs
|
||||
workflow_callbacks.on_error(
|
||||
message=f"Indexing pipeline failed for index '{index_name}'.",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
details=error_details,
|
||||
)
|
||||
# log error in global index directory logs
|
||||
logger.on_error(
|
||||
message=f"Indexing pipeline failed for index '{index_name}'.",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
details=error_details,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error encountered during indexing job for index '{index_name}'.",
|
||||
)
|
||||
|
||||
|
||||
@index_route.get(
|
||||
"",
|
||||
summary="Get all indexes",
|
||||
@ -336,7 +157,7 @@ async def get_all_indexes():
|
||||
items.append(item["human_readable_name"])
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Error retrieving index names")
|
||||
logger.error("Error retrieving index names")
|
||||
return IndexNameList(index_name=items)
|
||||
|
||||
|
||||
@ -367,7 +188,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
|
||||
batch_v1 = kubernetes_client.BatchV1Api()
|
||||
batch_v1.delete_namespaced_job(name=job_name, namespace=namespace)
|
||||
except Exception:
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message=f"Error deleting k8s job {job_name}.",
|
||||
details={"container": job_name},
|
||||
)
|
||||
@ -378,7 +199,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
|
||||
if job_pod:
|
||||
core_v1.delete_namespaced_pod(job_pod, namespace=namespace)
|
||||
except Exception:
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message=f"Error deleting k8s pod for job {job_name}.",
|
||||
details={"container": job_name},
|
||||
)
|
||||
@ -441,7 +262,7 @@ async def delete_index(index_name: str):
|
||||
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message=f"Error encountered while deleting all data for index {index_name}.",
|
||||
stack=None,
|
||||
details={"container": index_name},
|
||||
|
@@ -14,17 +14,13 @@ from fastapi import (
from graphrag.config.create_graphrag_config import create_graphrag_config

from src.api.azure_clients import AzureClientManager
from src.api.common import (
sanitize_name,
)
from src.logger import LoggerSingleton
from src.utils.common import sanitize_name

index_configuration_route = APIRouter(
prefix="/index/config", tags=["Index Configuration"]
)
prompt_tuning_route = APIRouter(prefix="/index/config", tags=["Index Configuration"])


@index_configuration_route.get(
@prompt_tuning_route.get(
"/prompts",
summary="Generate prompts from user-provided data.",
description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
@@ -66,7 +62,7 @@ async def generate_prompts(storage_name: str, limit: int = 5):
error_details = {
"storage_name": storage_name,
}
logger.on_error(
logger.error(
message="Auto-prompt generation failed.",
cause=e,
stack=traceback.format_exc(),
@ -26,17 +26,17 @@ from graphrag.vector_stores.base import (
|
||||
)
|
||||
|
||||
from src.api.azure_clients import AzureClientManager
|
||||
from src.api.common import (
|
||||
sanitize_name,
|
||||
validate_index_file_exist,
|
||||
)
|
||||
from src.logger import LoggerSingleton
|
||||
from src.models import (
|
||||
from src.typing.models import (
|
||||
GraphRequest,
|
||||
GraphResponse,
|
||||
)
|
||||
from src.typing.pipeline import PipelineJobState
|
||||
from src.utils import query as query_helper
|
||||
from src.utils.common import (
|
||||
get_df,
|
||||
sanitize_name,
|
||||
validate_index_file_exist,
|
||||
)
|
||||
from src.utils.pipeline import PipelineJob
|
||||
|
||||
query_route = APIRouter(
|
||||
@ -116,7 +116,7 @@ async def global_query(request: GraphRequest):
|
||||
|
||||
# read the parquet files into DataFrames and add provenance information
|
||||
# note that nodes need to be set before communities so that max community id makes sense
|
||||
nodes_df = query_helper.get_df(nodes_table_path)
|
||||
nodes_df = get_df(nodes_table_path)
|
||||
for i in nodes_df["human_readable_id"]:
|
||||
links["nodes"][i + max_vals["nodes"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -134,7 +134,7 @@ async def global_query(request: GraphRequest):
|
||||
max_vals["nodes"] = nodes_df["human_readable_id"].max()
|
||||
nodes_dfs.append(nodes_df)
|
||||
|
||||
community_df = query_helper.get_df(community_report_table_path)
|
||||
community_df = get_df(community_report_table_path)
|
||||
for i in community_df["community"].astype(int):
|
||||
links["community"][i + max_vals["community"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -146,7 +146,7 @@ async def global_query(request: GraphRequest):
|
||||
max_vals["community"] = community_df["community"].astype(int).max()
|
||||
community_dfs.append(community_df)
|
||||
|
||||
entities_df = query_helper.get_df(entities_table_path)
|
||||
entities_df = get_df(entities_table_path)
|
||||
for i in entities_df["human_readable_id"]:
|
||||
links["entities"][i + max_vals["entities"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -197,7 +197,7 @@ async def global_query(request: GraphRequest):
|
||||
return GraphResponse(result=result[0], context_data=context_data)
|
||||
except Exception as e:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message="Could not perform global search.",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
@ -287,7 +287,7 @@ async def local_query(request: GraphRequest):
|
||||
# read the parquet files into DataFrames and add provenance information
|
||||
|
||||
# note that nodes need to set before communities to that max community id makes sense
|
||||
nodes_df = query_helper.get_df(nodes_table_path)
|
||||
nodes_df = get_df(nodes_table_path)
|
||||
for i in nodes_df["human_readable_id"]:
|
||||
links["nodes"][i + max_vals["nodes"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -306,7 +306,7 @@ async def local_query(request: GraphRequest):
|
||||
max_vals["nodes"] = nodes_df["human_readable_id"].max()
|
||||
nodes_dfs.append(nodes_df)
|
||||
|
||||
community_df = query_helper.get_df(community_report_table_path)
|
||||
community_df = get_df(community_report_table_path)
|
||||
for i in community_df["community"].astype(int):
|
||||
links["community"][i + max_vals["community"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -318,7 +318,7 @@ async def local_query(request: GraphRequest):
|
||||
max_vals["community"] = community_df["community"].astype(int).max()
|
||||
community_dfs.append(community_df)
|
||||
|
||||
entities_df = query_helper.get_df(entities_table_path)
|
||||
entities_df = get_df(entities_table_path)
|
||||
for i in entities_df["human_readable_id"]:
|
||||
links["entities"][i + max_vals["entities"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -334,7 +334,7 @@ async def local_query(request: GraphRequest):
|
||||
max_vals["entities"] = entities_df["human_readable_id"].max()
|
||||
entities_dfs.append(entities_df)
|
||||
|
||||
relationships_df = query_helper.get_df(relationships_table_path)
|
||||
relationships_df = get_df(relationships_table_path)
|
||||
for i in relationships_df["human_readable_id"].astype(int):
|
||||
links["relationships"][i + max_vals["relationships"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -361,13 +361,13 @@ async def local_query(request: GraphRequest):
|
||||
)
|
||||
relationships_dfs.append(relationships_df)
|
||||
|
||||
text_units_df = query_helper.get_df(text_units_table_path)
|
||||
text_units_df = get_df(text_units_table_path)
|
||||
text_units_df["id"] = text_units_df["id"].apply(lambda x: f"{x}-{index_name}")
|
||||
text_units_dfs.append(text_units_df)
|
||||
|
||||
index_container_client = blob_service_client.get_container_client(index_name)
|
||||
if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
|
||||
covariates_df = query_helper.get_df(covariates_table_path)
|
||||
covariates_df = get_df(covariates_table_path)
|
||||
if i in covariates_df["human_readable_id"].astype(int):
|
||||
links["covariates"][i + max_vals["covariates"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
|
@ -22,14 +22,14 @@ from graphrag.api.query import (
|
||||
from graphrag.config import create_graphrag_config
|
||||
|
||||
from src.api.azure_clients import AzureClientManager
|
||||
from src.api.common import (
|
||||
from src.api.query import _is_index_complete
|
||||
from src.logger import LoggerSingleton
|
||||
from src.typing.models import GraphRequest
|
||||
from src.utils.common import (
|
||||
get_df,
|
||||
sanitize_name,
|
||||
validate_index_file_exist,
|
||||
)
|
||||
from src.api.query import _is_index_complete
|
||||
from src.logger import LoggerSingleton
|
||||
from src.models import GraphRequest
|
||||
from src.utils import query as query_helper
|
||||
|
||||
from .query import _get_embedding_description_store, _update_context
|
||||
|
||||
@ -107,7 +107,7 @@ async def global_search_streaming(request: GraphRequest):
|
||||
|
||||
# read parquet files into DataFrames and add provenance information
|
||||
# note that nodes need to set before communities to that max community id makes sense
|
||||
nodes_df = query_helper.get_df(nodes_table_path)
|
||||
nodes_df = get_df(nodes_table_path)
|
||||
for i in nodes_df["human_readable_id"]:
|
||||
links["nodes"][i + max_vals["nodes"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -125,7 +125,7 @@ async def global_search_streaming(request: GraphRequest):
|
||||
max_vals["nodes"] = nodes_df["human_readable_id"].max()
|
||||
nodes_dfs.append(nodes_df)
|
||||
|
||||
community_df = query_helper.get_df(community_report_table_path)
|
||||
community_df = get_df(community_report_table_path)
|
||||
for i in community_df["community"].astype(int):
|
||||
links["community"][i + max_vals["community"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -137,7 +137,7 @@ async def global_search_streaming(request: GraphRequest):
|
||||
max_vals["community"] = community_df["community"].astype(int).max()
|
||||
community_dfs.append(community_df)
|
||||
|
||||
entities_df = query_helper.get_df(entities_table_path)
|
||||
entities_df = get_df(entities_table_path)
|
||||
for i in entities_df["human_readable_id"]:
|
||||
links["entities"][i + max_vals["entities"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -188,7 +188,7 @@ async def global_search_streaming(request: GraphRequest):
|
||||
)
|
||||
except Exception as e:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message="Error encountered while streaming global search response",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
@ -277,7 +277,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
# read the parquet files into DataFrames and add provenance information
|
||||
|
||||
# note that nodes need to set before communities to that max community id makes sense
|
||||
nodes_df = query_helper.get_df(nodes_table_path)
|
||||
nodes_df = get_df(nodes_table_path)
|
||||
for i in nodes_df["human_readable_id"]:
|
||||
links["nodes"][i + max_vals["nodes"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -296,7 +296,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
max_vals["nodes"] = nodes_df["human_readable_id"].max()
|
||||
nodes_dfs.append(nodes_df)
|
||||
|
||||
community_df = query_helper.get_df(community_report_table_path)
|
||||
community_df = get_df(community_report_table_path)
|
||||
for i in community_df["community"].astype(int):
|
||||
links["community"][i + max_vals["community"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -308,7 +308,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
max_vals["community"] = community_df["community"].astype(int).max()
|
||||
community_dfs.append(community_df)
|
||||
|
||||
entities_df = query_helper.get_df(entities_table_path)
|
||||
entities_df = get_df(entities_table_path)
|
||||
for i in entities_df["human_readable_id"]:
|
||||
links["entities"][i + max_vals["entities"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -326,7 +326,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
max_vals["entities"] = entities_df["human_readable_id"].max()
|
||||
entities_dfs.append(entities_df)
|
||||
|
||||
relationships_df = query_helper.get_df(relationships_table_path)
|
||||
relationships_df = get_df(relationships_table_path)
|
||||
for i in relationships_df["human_readable_id"].astype(int):
|
||||
links["relationships"][i + max_vals["relationships"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -353,7 +353,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
)
|
||||
relationships_dfs.append(relationships_df)
|
||||
|
||||
text_units_df = query_helper.get_df(text_units_table_path)
|
||||
text_units_df = get_df(text_units_table_path)
|
||||
text_units_df["id"] = text_units_df["id"].apply(
|
||||
lambda x: f"{x}-{index_name}"
|
||||
)
|
||||
@ -363,7 +363,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
index_name
|
||||
)
|
||||
if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
|
||||
covariates_df = query_helper.get_df(covariates_table_path)
|
||||
covariates_df = get_df(covariates_table_path)
|
||||
if i in covariates_df["human_readable_id"].astype(int):
|
||||
links["covariates"][i + max_vals["covariates"] + 1] = {
|
||||
"index_name": sanitized_index_names_link[index_name],
|
||||
@ -431,7 +431,7 @@ async def local_search_streaming(request: GraphRequest):
|
||||
)
|
||||
except Exception as e:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error(
|
||||
logger.error(
|
||||
message="Error encountered while streaming local search response",
|
||||
cause=e,
|
||||
stack=traceback.format_exc(),
|
||||
|
@ -5,19 +5,19 @@
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from src.api.common import (
|
||||
get_pandas_storage_options,
|
||||
sanitize_name,
|
||||
validate_index_file_exist,
|
||||
)
|
||||
from src.logger import LoggerSingleton
|
||||
from src.models import (
|
||||
from src.typing.models import (
|
||||
ClaimResponse,
|
||||
EntityResponse,
|
||||
RelationshipResponse,
|
||||
ReportResponse,
|
||||
TextUnitResponse,
|
||||
)
|
||||
from src.utils.common import (
|
||||
pandas_storage_options,
|
||||
sanitize_name,
|
||||
validate_index_file_exist,
|
||||
)
|
||||
|
||||
source_route = APIRouter(
|
||||
prefix="/source",
|
||||
@ -46,7 +46,7 @@ async def get_report_info(index_name: str, report_id: str):
|
||||
try:
|
||||
report_table = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
# check if report_id exists in the index
|
||||
if not report_table["community"].isin([report_id]).any():
|
||||
@ -62,7 +62,7 @@ async def get_report_info(index_name: str, report_id: str):
|
||||
return ReportResponse(text=report_content)
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Could not get report.")
|
||||
logger.error("Could not get report.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving report '{report_id}' from index '{index_name}'.",
|
||||
@ -83,11 +83,11 @@ async def get_chunk_info(index_name: str, text_unit_id: str):
|
||||
try:
|
||||
text_units = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{TEXT_UNITS_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
docs = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{DOCUMENTS_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
# rename columns for easy joining
|
||||
docs = docs[["id", "title"]].rename(
|
||||
@ -115,7 +115,7 @@ async def get_chunk_info(index_name: str, text_unit_id: str):
|
||||
)
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Could not get text chunk.")
|
||||
logger.error("Could not get text chunk.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving text chunk '{text_unit_id}' from index '{index_name}'.",
|
||||
@ -135,7 +135,7 @@ async def get_entity_info(index_name: str, entity_id: int):
|
||||
try:
|
||||
entity_table = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
# check if entity_id exists in the index
|
||||
if not entity_table["human_readable_id"].isin([entity_id]).any():
|
||||
@ -148,7 +148,7 @@ async def get_entity_info(index_name: str, entity_id: int):
|
||||
)
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Could not get entity")
|
||||
logger.error("Could not get entity")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving entity '{entity_id}' from index '{index_name}'.",
|
||||
@ -175,7 +175,7 @@ async def get_claim_info(index_name: str, claim_id: int):
|
||||
try:
|
||||
claims_table = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{COVARIATES_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
claims_table.human_readable_id = claims_table.human_readable_id.astype(
|
||||
float
|
||||
@ -193,7 +193,7 @@ async def get_claim_info(index_name: str, claim_id: int):
|
||||
)
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Could not get claim.")
|
||||
logger.error("Could not get claim.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving claim '{claim_id}' from index '{index_name}'.",
|
||||
@ -214,11 +214,11 @@ async def get_relationship_info(index_name: str, relationship_id: int):
|
||||
try:
|
||||
relationship_table = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{RELATIONSHIPS_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
entity_table = pd.read_parquet(
|
||||
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
|
||||
storage_options=get_pandas_storage_options(),
|
||||
storage_options=pandas_storage_options(),
|
||||
)
|
||||
row = relationship_table[
|
||||
relationship_table.human_readable_id == str(relationship_id)
|
||||
@ -239,7 +239,7 @@ async def get_relationship_info(index_name: str, relationship_id: int):
|
||||
)
|
||||
except Exception:
|
||||
logger = LoggerSingleton().get_instance()
|
||||
logger.on_error("Could not get relationship.")
|
||||
logger.error("Could not get relationship.")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving relationship '{relationship_id}' from index '{index_name}'.",
|
||||
|
backend/src/indexer/indexer.py (new file, 189 lines)
@ -0,0 +1,189 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import asyncio
import inspect
import os
import traceback

import graphrag.api as api
import yaml
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.create_pipeline_config import create_pipeline_config

from src.api.azure_clients import AzureClientManager
from src.logger import (
    Logger,
    PipelineJobUpdater,
    load_pipeline_logger,
)
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name
from src.utils.pipeline import PipelineJob


def start_indexing_job(index_name: str):
    print("Start indexing job...")
    # get sanitized name
    sanitized_index_name = sanitize_name(index_name)

    # update or create new item in container-store in cosmosDB
    azure_client_manager = AzureClientManager()
    blob_service_client = azure_client_manager.get_blob_service_client()
    if not blob_service_client.get_container_client(sanitized_index_name).exists():
        blob_service_client.create_container(sanitized_index_name)

    cosmos_container_client = azure_client_manager.get_cosmos_container_client(
        database="graphrag", container="container-store"
    )
    cosmos_container_client.upsert_item({
        "id": sanitized_index_name,
        "human_readable_name": index_name,
        "type": "index",
    })

    print("Initialize pipeline job...")
    pipelinejob = PipelineJob()
    pipeline_job = pipelinejob.load_item(sanitized_index_name)
    sanitized_storage_name = pipeline_job.sanitized_storage_name
    storage_name = pipeline_job.human_readable_index_name

    # load custom pipeline settings
    this_directory = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))
    )
    data = yaml.safe_load(open(f"{this_directory}/settings.yaml"))
    # dynamically set some values
    data["input"]["container_name"] = sanitized_storage_name
    data["storage"]["container_name"] = sanitized_index_name
    data["reporting"]["container_name"] = sanitized_index_name
    data["cache"]["container_name"] = sanitized_index_name
    if "vector_store" in data["embeddings"]:
        data["embeddings"]["vector_store"]["collection_name"] = (
            f"{sanitized_index_name}_description_embedding"
        )

    # set prompt for entity extraction
    if pipeline_job.entity_extraction_prompt:
        fname = "entity-extraction-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.entity_extraction_prompt)
        data["entity_extraction"]["prompt"] = fname
    else:
        data.pop("entity_extraction")

    # set prompt for summarize descriptions
    if pipeline_job.summarize_descriptions_prompt:
        fname = "summarize-descriptions-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.summarize_descriptions_prompt)
        data["summarize_descriptions"]["prompt"] = fname
    else:
        data.pop("summarize_descriptions")

    # set prompt for community report
    if pipeline_job.community_report_prompt:
        fname = "community-report-prompt.txt"
        with open(fname, "w") as outfile:
            outfile.write(pipeline_job.community_report_prompt)
        data["community_reports"]["prompt"] = fname
    else:
        data.pop("community_reports")

    # generate default graphrag config parameters and override with custom settings
    parameters = create_graphrag_config(data, ".")

    # reset pipeline job details
    pipeline_job.status = PipelineJobState.RUNNING
    pipeline_job.all_workflows = []
    pipeline_job.completed_workflows = []
    pipeline_job.failed_workflows = []
    pipeline_config = create_pipeline_config(parameters)
    for workflow in pipeline_config.workflows:
        pipeline_job.all_workflows.append(workflow.name)

    # create new loggers/callbacks just for this job
    logger_names = []
    for logger_type in ["BLOB", "CONSOLE", "APP_INSIGHTS"]:
        logger_names.append(Logger[logger_type.upper()])
    print("Creating generic loggers...")
    logger: WorkflowCallbacks = load_pipeline_logger(
        logging_dir=sanitized_index_name,
        index_name=index_name,
        num_workflow_steps=len(pipeline_job.all_workflows),
        loggers=logger_names,
    )

    # create pipeline job updater to monitor job progress
    print("Creating pipeline job updater...")
    pipeline_job_updater = PipelineJobUpdater(pipeline_job)

    # run the pipeline
    try:
        print("Building index...")
        asyncio.run(
            api.build_index(
                config=parameters,
                callbacks=[logger, pipeline_job_updater],
            )
        )
        print("Index building complete")
        # if job is done, check if any pipeline steps failed
        if len(pipeline_job.failed_workflows) > 0:
            print("Indexing pipeline encountered error.")
            pipeline_job.status = PipelineJobState.FAILED
            logger.error(
                message=f"Indexing pipeline encountered error for index'{index_name}'.",
                details={
                    "index": index_name,
                    "storage_name": storage_name,
                    "status_message": "indexing pipeline encountered error",
                },
            )
        else:
            print("Indexing pipeline complete.")
            # record the pipeline completion
            pipeline_job.status = PipelineJobState.COMPLETE
            pipeline_job.percent_complete = 100
            logger.log(
                message=f"Indexing pipeline complete for index'{index_name}'.",
                details={
                    "index": index_name,
                    "storage_name": storage_name,
                    "status_message": "indexing pipeline complete",
                },
            )
        pipeline_job.progress = (
            f"{len(pipeline_job.completed_workflows)} out of "
            f"{len(pipeline_job.all_workflows)} workflows completed successfully."
        )
        if pipeline_job.status == PipelineJobState.FAILED:
            exit(1)  # signal to AKS that indexing job failed
    except Exception as e:
        pipeline_job.status = PipelineJobState.FAILED
        # update failed state in cosmos db
        error_details = {
            "index": index_name,
            "storage_name": storage_name,
        }
        # log error in local index directory logs
        logger.error(
            message=f"Indexing pipeline failed for index '{index_name}'.",
            cause=e,
            stack=traceback.format_exc(),
            details=error_details,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build a graphrag index.")
    parser.add_argument("-i", "--index-name", required=True)
    args = parser.parse_args()

    asyncio.run(
        start_indexing_job(
            index_name=args.index_name,
        )
    )
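The new module exposes the same command-line interface the AKS job invokes. A minimal sketch of exercising it manually (the working directory, which would need to match the backend image WORKDIR, and the index name are assumptions; the required Azure configuration and settings.yaml must already be available):

import subprocess

# hypothetical manual invocation mirroring the module's argparse interface
subprocess.run(
    ["python", "src/indexer/indexer.py", "-i", "my-index"],
    check=True,
)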
@ -17,8 +17,8 @@ llm:
  model_supports_json: True
  tokens_per_minute: 80_000
  requests_per_minute: 480
  concurrent_requests: 25
  max_retries: 25
  concurrent_requests: 50
  max_retries: 250
  max_retry_wait: 60.0
  sleep_on_rate_limit_recommendation: True

@ -45,7 +45,7 @@ embeddings:
    deployment_name: $GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
    cognitive_services_endpoint: $GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT
    tokens_per_minute: 350_000
    requests_per_minute: 2100
    requests_per_minute: 2_100

###################### Input settings ######################
input:
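The underscore-separated numbers rely on the YAML 1.1 integer syntax that yaml.safe_load (PyYAML), which the indexer uses to read this file, accepts; a minimal sketch confirming the values parse as integers rather than strings:

import yaml

# PyYAML's YAML 1.1 resolver treats underscores as digit separators
settings = yaml.safe_load("requests_per_minute: 2_100\ntokens_per_minute: 80_000")
assert settings == {"requests_per_minute": 2100, "tokens_per_minute": 80000}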
@ -7,20 +7,20 @@ from src.logger.application_insights_workflow_callbacks import (
from src.logger.console_workflow_callbacks import ConsoleWorkflowCallbacks
from src.logger.load_logger import load_pipeline_logger
from src.logger.logger_singleton import LoggerSingleton
from src.logger.pipeline_job_workflow_callbacks import PipelineJobWorkflowCallbacks
from src.logger.pipeline_job_updater import PipelineJobUpdater
from src.logger.typing import (
    Logger,
    PipelineAppInsightsReportingConfig,
    PipelineReportingConfigTypes,
    Reporters,
)

__all__ = [
    "Reporters",
    "Logger",
    "ApplicationInsightsWorkflowCallbacks",
    "ConsoleWorkflowCallbacks",
    "LoggerSingleton",
    "PipelineAppInsightsReportingConfig",
    "PipelineJobWorkflowCallbacks",
    "PipelineJobUpdater",
    "PipelineReportingConfigTypes",
    "load_pipeline_logger",
]
@ -1,28 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import hashlib
import logging
import time

# from dataclasses import asdict
from typing import (
    Any,
    Dict,
    Optional,
)

from azure.monitor.opentelemetry.exporter import AzureMonitorLogExporter
from azure.identity import DefaultAzureCredential
from azure.monitor.opentelemetry import configure_azure_monitor
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from opentelemetry._logs import (
    get_logger_provider,
    set_logger_provider,
)
from opentelemetry.sdk._logs import (
    LoggerProvider,
    LoggingHandler,
)
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor


class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
@ -31,7 +19,6 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
    _logger: logging.Logger
    _logger_name: str
    _logger_level: int
    _logger_level_name: str
    _properties: Dict[str, Any]
    _workflow_name: str
    _index_name: str
@ -40,9 +27,7 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):

    def __init__(
        self,
        connection_string: str,
        logger_name: str | None = None,
        logger_level: int = logging.INFO,
        logger_name: str = "graphrag-accelerator",
        index_name: str = "",
        num_workflow_steps: int = 0,
        properties: Dict[str, Any] = {},
@ -51,60 +36,31 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
        Initialize the AppInsightsReporter.

        Args:
            connection_string (str): The connection string for the App Insights instance.
            logger_name (str | None, optional): The name of the logger. Defaults to None.
            logger_level (int, optional): The logging level. Defaults to logging.INFO.
            index_name (str, optional): The name of an index. Defaults to "".
            num_workflow_steps (int): A list of workflow names ordered by their execution. Defaults to [].
            properties (Dict[str, Any], optional): Additional properties to be included in the log. Defaults to {}.
        """
        self._logger: logging.Logger
        self._logger_name = logger_name
        self._logger_level = logger_level
        self._logger_level_name: str = logging.getLevelName(logger_level)
        self._properties = properties
        self._workflow_name = "N/A"
        self._index_name = index_name
        self._num_workflow_steps = num_workflow_steps
        self._processed_workflow_steps = []  # maintain a running list of workflow steps that get processed
        """Create a new logger with an AppInsights handler."""
        self.__init_logger(connection_string=connection_string)
        self._properties = properties
        self._workflow_name = "N/A"
        self._processed_workflow_steps = []  # if logger is used in a pipeline job, maintain a running list of workflows that are processed
        # initialize a new logger with an AppInsights handler
        self.__init_logger()

    def __init_logger(self, connection_string, max_logger_init_retries: int = 10):
        max_retry = max_logger_init_retries
        while not (hasattr(self, "_logger")):
            if max_retry == 0:
                raise Exception(
                    "Failed to create logger. Could not disambiguate logger name."
                )

            # generate a unique logger name
            current_time = str(time.time())
            unique_hash = hashlib.sha256(current_time.encode()).hexdigest()
            self._logger_name = f"{self.__class__.__name__}-{unique_hash}"
            if self._logger_name not in logging.Logger.manager.loggerDict:
                # attach azure monitor log exporter to logger provider
                logger_provider = LoggerProvider()
                set_logger_provider(logger_provider)
                exporter = AzureMonitorLogExporter(connection_string=connection_string)
                get_logger_provider().add_log_record_processor(
                    BatchLogRecordProcessor(
                        exporter=exporter,
                        schedule_delay_millis=60000,
                    )
                )
                # instantiate new logger
                self._logger = logging.getLogger(self._logger_name)
                self._logger.propagate = False
                # remove any existing handlers
                self._logger.handlers.clear()
                # fetch handler from logger provider and attach to class
                self._logger.addHandler(LoggingHandler())
                # set logging level
                self._logger.setLevel(logging.DEBUG)

            # reduce sentinel counter value
            max_retry -= 1
    def __init_logger(self, max_logger_init_retries: int = 10):
        # Configure OpenTelemetry to use Azure Monitor with the
        # APPLICATIONINSIGHTS_CONNECTION_STRING environment variable
        configure_azure_monitor(
            logger_name=self._logger_name,
            disable_offline_storage=True,
            enable_live_metrics=True,
            credential=DefaultAzureCredential(),
        )
        self._logger = logging.getLogger(self._logger_name)

    def _format_details(self, details: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """
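The rewritten callback delegates exporter setup to azure-monitor-opentelemetry instead of wiring the OpenTelemetry log pipeline by hand. A minimal sketch of the same wiring in isolation (the logger name and message are illustrative; it assumes APPLICATIONINSIGHTS_CONNECTION_STRING is set in the environment):

import logging

from azure.identity import DefaultAzureCredential
from azure.monitor.opentelemetry import configure_azure_monitor

# reads APPLICATIONINSIGHTS_CONNECTION_STRING from the environment
configure_azure_monitor(
    logger_name="graphrag-accelerator",
    disable_offline_storage=True,
    credential=DefaultAzureCredential(),
)
logging.getLogger("graphrag-accelerator").info("indexing job log record")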
@ -7,6 +7,7 @@ from typing import List

from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager

from src.api.azure_clients import AzureClientManager
from src.logger.application_insights_workflow_callbacks import (
@ -14,32 +15,32 @@ from src.logger.application_insights_workflow_callbacks import (
)
from src.logger.blob_workflow_callbacks import BlobWorkflowCallbacks
from src.logger.console_workflow_callbacks import ConsoleWorkflowCallbacks
from src.logger.typing import Reporters
from src.logger.typing import Logger


def load_pipeline_logger(
    reporting_dir: str | None,
    reporters: List[Reporters] | None = [],
    logging_dir: str | None,
    index_name: str = "",
    num_workflow_steps: int = 0,
    loggers: List[Logger] = [],
) -> WorkflowCallbacks:
    """Create and load a list of loggers.

    Loggers may be configured as generic loggers or associated with a specified indexing job.
    """
    # always register the console logger as a fallback option
    if Reporters.CONSOLE not in reporters:
        reporters.append(Reporters.CONSOLE)
    if Logger.CONSOLE not in loggers:
        loggers.append(Logger.CONSOLE)

    azure_client_manager = AzureClientManager()
    logger_callbacks = []
    for reporter in reporters:
        match reporter:
            case Reporters.BLOB:
    callback_manager = WorkflowCallbacksManager()
    for logger in loggers:
        match logger:
            case Logger.BLOB:
                # create a dedicated container for logs
                container_name = "logs"
                if reporting_dir is not None:
                    container_name = os.path.join(reporting_dir, container_name)
                if logging_dir is not None:
                    container_name = os.path.join(logging_dir, container_name)
                # ensure the root directory exists; if not, create it
                blob_service_client = azure_client_manager.get_blob_service_client()
                container_root = Path(container_name).parts[0]
@ -47,8 +48,7 @@ def load_pipeline_logger(
                    container_root
                ).exists():
                    blob_service_client.create_container(container_root)
                # register the blob reporter
                logger_callbacks.append(
                callback_manager.register(
                    BlobWorkflowCallbacks(
                        blob_service_client=blob_service_client,
                        container_name=container_name,
@ -56,25 +56,25 @@ def load_pipeline_logger(
                        num_workflow_steps=num_workflow_steps,
                    )
                )
            case Reporters.FILE:
                logger_callbacks.append(FileWorkflowCallbacks(dir=reporting_dir))
            case Reporters.APP_INSIGHTS:
                if os.getenv("APP_INSIGHTS_CONNECTION_STRING"):
                    logger_callbacks.append(
            case Logger.FILE:
                callback_manager.register(FileWorkflowCallbacks(dir=logging_dir))
            case Logger.APP_INSIGHTS:
                if os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING"):
                    callback_manager.register(
                        ApplicationInsightsWorkflowCallbacks(
                            connection_string=os.environ[
                                "APP_INSIGHTS_CONNECTION_STRING"
                                "APPLICATIONINSIGHTS_CONNECTION_STRING"
                            ],
                            index_name=index_name,
                            num_workflow_steps=num_workflow_steps,
                        )
                    )
            case Reporters.CONSOLE:
                logger_callbacks.append(
            case Logger.CONSOLE:
                callback_manager.register(
                    ConsoleWorkflowCallbacks(
                        index_name=index_name, num_workflow_steps=num_workflow_steps
                    )
                )
            case _:
                print(f"WARNING: unknown reporter type: {reporter}. Skipping.")
    return logger_callbacks
                print(f"WARNING: unknown logger type: {logger}. Skipping.")
    return callback_manager
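With the new signature, callers receive a single WorkflowCallbacksManager with every requested logger registered on it, rather than a list of callbacks. A minimal usage sketch (the index name and log message are assumptions):

from src.logger import Logger, load_pipeline_logger

callbacks = load_pipeline_logger(
    logging_dir="logs",
    index_name="my-index",
    loggers=[Logger.CONSOLE],
)
callbacks.log("indexing job started")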
@ -1,13 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
from urllib.parse import urlparse

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

from src.logger.load_logger import load_pipeline_logger
from src.logger.typing import Reporters
from src.logger.typing import Logger


class LoggerSingleton:
@ -15,23 +13,9 @@ class LoggerSingleton:

    @classmethod
    def get_instance(cls) -> WorkflowCallbacks:
        if cls._instance is None:
            # Set up reporters based on environment variable or defaults
        if not cls._instance:
            reporters = []
            for reporter_name in os.getenv(
                "REPORTERS", Reporters.CONSOLE.name.upper()
            ).split(","):
                try:
                    reporters.append(Reporters[reporter_name.upper()])
                except KeyError:
                    raise ValueError(f"Found unknown reporter: {reporter_name}")
            cls._instance = load_pipeline_logger(reporting_dir="", reporters=reporters)
            for logger_type in ["BLOB", "CONSOLE", "APP_INSIGHTS"]:
                reporters.append(Logger[logger_type])
            cls._instance = load_pipeline_logger(logging_dir="", loggers=reporters)
        return cls._instance


def _is_valid_url(url: str) -> bool:
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except ValueError:
        return False
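After this change the singleton no longer consults the REPORTERS environment variable; it always wires up blob, console, and App Insights callbacks. A minimal usage sketch matching how the API routes call it (the error message is illustrative):

from src.logger import LoggerSingleton

logger = LoggerSingleton().get_instance()
logger.error("Could not get relationship.")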
@ -7,12 +7,12 @@ from src.typing.pipeline import PipelineJobState
from src.utils.pipeline import PipelineJob


class PipelineJobWorkflowCallbacks(NoopWorkflowCallbacks):
    """A reporter that writes to a stream (sys.stdout)."""
class PipelineJobUpdater(NoopWorkflowCallbacks):
    """A callback that records pipeline updates."""

    def __init__(self, pipeline_job: "PipelineJob"):
    def __init__(self, pipeline_job: PipelineJob):
        """
        This class defines a set of callback methods that can be used to report the progress and status of a workflow job.
        This class defines a set of callback methods that can be used to log the progress of a pipeline job.
        It inherits from the NoopWorkflowCallbacks class, which provides default implementations for all the callback methods.

        Attributes:
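The updater pattern is simply a NoopWorkflowCallbacks subclass that overrides the hooks it cares about. A minimal sketch of the same idea with a hypothetical in-memory counter in place of the CosmosDB-backed PipelineJob (the hook signature mirrors the tests later in this commit):

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks


class WorkflowCounter(NoopWorkflowCallbacks):
    """Hypothetical callback that only records which workflows finished."""

    def __init__(self):
        self.completed: list[str] = []

    def workflow_end(self, name: str, instance: object) -> None:
        self.completed.append(name)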
@ -12,7 +12,7 @@ from graphrag.index.config.reporting import (
from pydantic import Field as pydantic_Field


class Reporters(Enum):
class Logger(Enum):
    BLOB = (1, "blob")
    CONSOLE = (2, "console")
    FILE = (3, "file")
@ -24,7 +24,7 @@ class PipelineAppInsightsReportingConfig(
):
    """Represents the ApplicationInsights reporting configuration for the pipeline."""

    type: Literal["app_insights"] = Reporters.APP_INSIGHTS.name.lower()
    type: Literal["app_insights"] = Logger.APP_INSIGHTS.name.lower()
    """The type of reporting."""

    connection_string: str = pydantic_Field(
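Since the members carry (id, label) tuples and callers look them up by name, a short sketch of how the renamed enum is used elsewhere in the backend:

from src.logger.typing import Logger

assert Logger["BLOB"] is Logger.BLOB
assert Logger.APP_INSIGHTS.name.lower() == "app_insights"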
@ -23,7 +23,7 @@ from src.api.azure_clients import AzureClientManager
from src.api.data import data_route
from src.api.graph import graph_route
from src.api.index import index_route
from src.api.index_configuration import index_configuration_route
from src.api.prompt_tuning import prompt_tuning_route
from src.api.query import query_route
from src.api.query_streaming import query_streaming_route
from src.api.source import source_route
@ -37,7 +37,7 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
    except Exception as e:
        reporter = LoggerSingleton().get_instance()
        stack = traceback.format_exc()
        reporter.on_error(
        reporter.error(
            message="Unexpected internal server error",
            cause=e,
            stack=stack,
@ -82,7 +82,7 @@ async def lifespan(app: FastAPI):
            name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
        )
        # load the cronjob manifest template and update PLACEHOLDER values with correct values using the pod spec
        with open("indexing-job-manager-template.yaml", "r") as f:
        with open("index-job-manager.yaml", "r") as f:
            manifest = yaml.safe_load(f)
        manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]["containers"][0][
            "image"
@ -104,7 +104,7 @@ async def lifespan(app: FastAPI):
    except Exception as e:
        print("Failed to create graphrag cronjob.")
        logger = LoggerSingleton().get_instance()
        logger.on_error(
        logger.error(
            message="Failed to create graphrag cronjob",
            cause=str(e),
            stack=traceback.format_exc(),
@ -133,7 +133,7 @@ app.include_router(data_route)
app.include_router(index_route)
app.include_router(query_route)
app.include_router(query_streaming_route)
app.include_router(index_configuration_route)
app.include_router(prompt_tuning_route)
app.include_router(source_route)
app.include_router(graph_route)

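At startup the API pod patches the renamed cronjob manifest before submitting it to Kubernetes. A minimal sketch of the template patching shown above (the image reference is a hypothetical placeholder value):

import yaml

with open("index-job-manager.yaml", "r") as f:
    manifest = yaml.safe_load(f)

# patch the PLACEHOLDER image in the cronjob template
manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]["containers"][0]["image"] = (
    "myregistry.azurecr.io/graphrag-backend:latest"  # hypothetical image
)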
@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
@ -5,6 +5,7 @@ import hashlib
import os
import re

import pandas as pd
from azure.cosmos import exceptions
from azure.identity import DefaultAzureCredential
from fastapi import HTTPException
@ -12,7 +13,17 @@ from fastapi import HTTPException
from src.api.azure_clients import AzureClientManager


def get_pandas_storage_options() -> dict:
def get_df(
    table_path: str,
) -> pd.DataFrame:
    df = pd.read_parquet(
        table_path,
        storage_options=pandas_storage_options(),
    )
    return df


def pandas_storage_options() -> dict:
    """Generate the storage options required by pandas to read parquet files from Storage."""
    # For more information on the options available, see: https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
    azure_client_manager = AzureClientManager()
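get_df now centralizes the parquet reads that the deleted query helpers below used to duplicate. A minimal usage sketch (the module path and the abfs table path are assumptions for illustration only):

from src.api.common import get_df  # module path is an assumption

# hypothetical parquet table produced by an indexing job
entities = get_df("abfs://my-index/output/create_final_entities.parquet")
print(entities.shape)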
@ -9,8 +9,8 @@ from typing import (
from azure.cosmos.exceptions import CosmosHttpResponseError

from src.api.azure_clients import AzureClientManager
from src.api.common import sanitize_name
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name


@dataclass
@ -1,81 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import pandas as pd
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)

from src.api.common import get_pandas_storage_options


def get_entities(
    entity_table_path: str,
    entity_embedding_table_path: str,
    community_level: int = 0,
) -> pd.DataFrame:
    storage_options = get_pandas_storage_options()
    entity_df = pd.read_parquet(
        entity_table_path,
        storage_options=storage_options,
    )
    entity_embedding_df = pd.read_parquet(
        entity_embedding_table_path,
        storage_options=storage_options,
    )
    return pd.DataFrame(
        read_indexer_entities(entity_df, entity_embedding_df, community_level)
    )


def get_reports(
    entity_table_path: str, community_report_table_path: str, community_level: int
) -> pd.DataFrame:
    storage_options = get_pandas_storage_options()
    entity_df = pd.read_parquet(
        entity_table_path,
        storage_options=storage_options(),
    )
    report_df = pd.read_parquet(
        community_report_table_path,
        storage_options=storage_options(),
    )
    return pd.DataFrame(read_indexer_reports(report_df, entity_df, community_level))


def get_relationships(relationships_table_path: str) -> pd.DataFrame:
    relationship_df = pd.read_parquet(
        relationships_table_path,
        storage_options=get_pandas_storage_options(),
    )
    return pd.DataFrame(read_indexer_relationships(relationship_df))


def get_covariates(covariate_table_path: str) -> pd.DataFrame:
    covariate_df = pd.read_parquet(
        covariate_table_path,
        storage_options=get_pandas_storage_options(),
    )
    return pd.DataFrame(read_indexer_covariates(covariate_df))


def get_text_units(text_unit_table_path: str) -> pd.DataFrame:
    text_unit_df = pd.read_parquet(
        text_unit_table_path,
        storage_options=get_pandas_storage_options(),
    )
    return pd.DataFrame(read_indexer_text_units(text_unit_df))


def get_df(
    table_path: str,
) -> pd.DataFrame:
    df = pd.read_parquet(
        table_path,
        storage_options=get_pandas_storage_options(),
    )
    return df
@ -10,8 +10,8 @@ from azure.cosmos import CosmosClient, PartitionKey
from azure.storage.blob import BlobServiceClient
from fastapi.testclient import TestClient

from src.api.common import sanitize_name
from src.main import app
from src.utils.common import sanitize_name


@pytest.fixture(scope="session")
@ -3,7 +3,7 @@

import pytest

from src.api.common import (
from src.utils.common import (
    retrieve_original_blob_container_name,
    sanitize_name,
    validate_blob_container_name,
@ -12,7 +12,7 @@ from src.api.common import (


def test_validate_blob_container_name():
    """Test the src.api.common.validate_blob_container_name function."""
    """Test the src.utils.common.validate_blob_container_name function."""
    # test valid container name
    assert validate_blob_container_name("validcontainername") is None
    # test invalid container name
@ -33,7 +33,7 @@ def test_validate_blob_container_name():


def test_retrieve_original_blob_container_name(container_with_graphml_file):
    """Test the src.api.common.retrieve_original_blob_container_name function."""
    """Test the src.utils.common.retrieve_original_blob_container_name function."""
    # test retrieving a valid container name
    original_name = container_with_graphml_file
    sanitized_name = sanitize_name(original_name)
@ -43,7 +43,7 @@ def test_retrieve_original_blob_container_name(container_with_graphml_file):


def test_validate_index_file_exist(container_with_graphml_file):
    """Test the src.api.common.validate_index_file_exist function."""
    """Test the src.utils.common.validate_index_file_exist function."""
    original_name = container_with_graphml_file
    sanitized_name = sanitize_name(original_name)
    # test with a valid index and valid file
@ -46,8 +46,8 @@ def test_load_pipeline_logger_with_console(
):
    """Test load_pipeline_logger."""
    loggers = load_pipeline_logger(
        reporting_dir="logs",
        reporters=["app_insights", "blob", "console", "file"],
        logging_dir="logs",
        loggers=["app_insights", "blob", "console", "file"],
        index_name="test-index",
        num_workflow_steps=4,
    )
@ -37,26 +37,26 @@ def workflow_callbacks(mock_logger):
    yield instance


def test_on_workflow_start(workflow_callbacks, mock_logger):
    workflow_callbacks.on_workflow_start("test_workflow", object())
def test_workflow_start(workflow_callbacks, mock_logger):
    workflow_callbacks.workflow_start("test_workflow", object())
    assert mock_logger.info.called


def test_on_workflow_end(workflow_callbacks, mock_logger):
    workflow_callbacks.on_workflow_end("test_workflow", object())
def test_workflow_end(workflow_callbacks, mock_logger):
    workflow_callbacks.workflow_end("test_workflow", object())
    assert mock_logger.info.called


def test_on_log(workflow_callbacks, mock_logger):
    workflow_callbacks.on_log("test_log_message")
def test_log(workflow_callbacks, mock_logger):
    workflow_callbacks.log("test_log_message")
    assert mock_logger.info.called


def test_on_warning(workflow_callbacks, mock_logger):
    workflow_callbacks.on_warning("test_warning")
def test_warning(workflow_callbacks, mock_logger):
    workflow_callbacks.warning("test_warning")
    assert mock_logger.warning.called


def test_on_error(workflow_callbacks, mock_logger):
    workflow_callbacks.on_error("test_error", Exception("test_exception"))
def test_error(workflow_callbacks, mock_logger):
    workflow_callbacks.error("test_error", Exception("test_exception"))
    assert mock_logger.error.called
@ -34,16 +34,16 @@ def workflow_callbacks(mock_blob_service_client):


def test_on_workflow_start(workflow_callbacks):
    workflow_callbacks.on_workflow_start("test_workflow", object())
    workflow_callbacks.workflow_start("test_workflow", object())
    # check if blob workflow callbacks _write_log() method was called
    assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called


def test_on_workflow_end(workflow_callbacks):
    workflow_callbacks.on_workflow_end("test_workflow", object())
    workflow_callbacks.workflow_end("test_workflow", object())
    assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called


def test_on_error(workflow_callbacks):
    workflow_callbacks.on_error("test_error", Exception("test_exception"))
    workflow_callbacks.error("test_error", Exception("test_exception"))
    assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called
@ -34,26 +34,26 @@ def workflow_callbacks(mock_logger):
    yield instance


def test_on_workflow_start(workflow_callbacks, mock_logger):
    workflow_callbacks.on_workflow_start("test_workflow", object())
def test_workflow_start(workflow_callbacks, mock_logger):
    workflow_callbacks.workflow_start("test_workflow", object())
    assert mock_logger.info.called


def test_on_workflow_end(workflow_callbacks, mock_logger):
    workflow_callbacks.on_workflow_end("test_workflow", object())
def test_workflow_end(workflow_callbacks, mock_logger):
    workflow_callbacks.workflow_end("test_workflow", object())
    assert mock_logger.info.called


def test_on_log(workflow_callbacks, mock_logger):
    workflow_callbacks.on_log("test_log_message")
def test_log(workflow_callbacks, mock_logger):
    workflow_callbacks.log("test_log_message")
    assert mock_logger.info.called


def test_on_warning(workflow_callbacks, mock_logger):
    workflow_callbacks.on_warning("test_warning")
def test_warning(workflow_callbacks, mock_logger):
    workflow_callbacks.warning("test_warning")
    assert mock_logger.warning.called


def test_on_error(workflow_callbacks, mock_logger):
    workflow_callbacks.on_error("test_error", Exception("test_exception"))
def test_error(workflow_callbacks, mock_logger):
    workflow_callbacks.error("test_error", Exception("test_exception"))
    assert mock_logger.error.called
@ -1,8 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# For more information about the base image visit:
# https://mcr.microsoft.com/en-us/artifact/mar/devcontainers/python/about
# For more information about the base image: https://mcr.microsoft.com/en-us/artifact/mar/devcontainers/python/about
FROM mcr.microsoft.com/devcontainers/python:3.10-bookworm

# default graphrag version will be 0.0.0 unless overridden by --build-arg
@ -22,7 +21,6 @@ RUN cd backend \

# download all nltk data that graphrag requires
RUN python -c "import nltk;nltk.download(['punkt','averaged_perceptron_tagger','maxent_ne_chunker','words','wordnet'])"

# download tiktoken model encodings
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt-3.5-turbo'); tiktoken.encoding_for_model('gpt-4'); tiktoken.encoding_for_model('gpt-4o');"

@ -95,7 +95,6 @@ In the `deploy.parameters.json` file, provide values for the following required
`RESOURCE_BASE_NAME` | | No | Suffix to apply to all azure resource names. If not provided a unique suffix will be generated.
`AISEARCH_ENDPOINT_SUFFIX` | | No | Suffix to apply to AI search endpoint. Will default to `search.windows.net` for Azure Commercial cloud but should be overridden for deployments in other Azure clouds.
`AISEARCH_AUDIENCE` | | No | Audience for AAD for AI Search. Will default to `https://search.azure.com/` for Azure Commercial cloud but should be overridden for deployments in other Azure clouds.
`REPORTERS` | blob,console,app_insights | No | The type of logging to enable. A comma separated string containing any of the following values: `[blob,console,file,app_insights]`. Will default to `"blob,console,app_insights"`.

### 5. Deploy solution accelerator to the resource group
```shell
@ -16,7 +16,6 @@ GRAPHRAG_IMAGE=""
PUBLISHER_EMAIL=""
PUBLISHER_NAME=""
RESOURCE_BASE_NAME=""
REPORTERS=""
GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT=""
CONTAINER_REGISTRY_NAME=""

@ -240,10 +239,6 @@ populateOptionalParams () {
    if [ ! -z "$RESOURCE_BASE_NAME" ]; then
        printf "\tsetting RESOURCE_BASE_NAME=$RESOURCE_BASE_NAME\n"
    fi
    if [ -z "$REPORTERS" ]; then
        REPORTERS="blob,console,app_insights"
        printf "\tsetting REPORTERS=blob,console,app_insights\n"
    fi
    if [ -z "$GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT" ]; then
        GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT="https://cognitiveservices.azure.com/.default"
        printf "\tsetting GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT=$GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT\n"
@ -440,7 +435,6 @@ installGraphRAGHelmChart () {
    exitIfValueEmpty "$graphragImageName" "Unable to parse graphrag image name, exiting..."
    exitIfValueEmpty "$graphragImageVersion" "Unable to parse graphrag image version, exiting..."

    local escapedReporters=$(sed "s/,/\\\,/g" <<< "$REPORTERS")
    reset_x=true
    if ! [ -o xtrace ]; then
        set -x
@ -455,7 +449,7 @@ installGraphRAGHelmChart () {
        --set "master.image.repository=$containerRegistryName/$graphragImageName" \
        --set "master.image.tag=$graphragImageVersion" \
        --set "ingress.host=$graphragHostname" \
        --set "graphragConfig.APP_INSIGHTS_CONNECTION_STRING=$appInsightsConnectionString" \
        --set "graphragConfig.APPLICATIONINSIGHTS_CONNECTION_STRING=$appInsightsConnectionString" \
        --set "graphragConfig.AI_SEARCH_URL=https://$aiSearchName.$AISEARCH_ENDPOINT_SUFFIX" \
        --set "graphragConfig.AI_SEARCH_AUDIENCE=$AISEARCH_AUDIENCE" \
        --set "graphragConfig.COSMOS_URI_ENDPOINT=$cosmosEndpoint" \
@ -466,7 +460,6 @@ installGraphRAGHelmChart () {
        --set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$GRAPHRAG_LLM_DEPLOYMENT_NAME" \
        --set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$GRAPHRAG_EMBEDDING_MODEL" \
        --set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME" \
        --set "graphragConfig.REPORTERS=$escapedReporters" \
        --set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl"

    local helmResult=$?
@ -32,7 +32,9 @@ ingress:
graphragConfig:
  AI_SEARCH_AUDIENCE: ""
  AI_SEARCH_URL: ""
  APP_INSIGHTS_CONNECTION_STRING: ""
  APPLICATIONINSIGHTS_CONNECTION_STRING: ""
  # Must set hidden env variable to true to disable statsbeat. For more information: https://github.com/Azure/azure-sdk-for-python/issues/34804
  APPLICATIONINSIGHTS_STATSBEAT_DISABLED_ALL: "True"
  COSMOS_URI_ENDPOINT: ""
  GRAPHRAG_API_BASE: ""
  GRAPHRAG_API_VERSION: ""
@ -41,7 +43,6 @@ graphragConfig:
  GRAPHRAG_LLM_DEPLOYMENT_NAME: ""
  GRAPHRAG_EMBEDDING_MODEL: ""
  GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME: ""
  REPORTERS: "blob,console,app_insights"
  STORAGE_ACCOUNT_BLOB_URL: ""

master:
@ -54,10 +55,12 @@ master:
  tag: ""
  podAnnotations: {}
  podLabels: {}
  podSecurityContext: {}
  podSecurityContext:
    {}
    # fsGroup: 2000

  securityContext: {}
  securityContext:
    {}
    # capabilities:
    # drop:
    #   - ALL
@ -125,8 +128,8 @@ master:
  nodeAffinity:
    requiredDuringSchedulingIgnoredDuringExecution:
      nodeSelectorTerms:
      - matchExpressions:
        - key: workload
          operator: In
          values:
          - graphrag
      - matchExpressions:
        - key: workload
          operator: In
          values:
          - graphrag