refactor and reorganize indexing code out of api code

Josh Bradley 2025-01-21 00:29:48 -05:00
parent e85c9c006e
commit a8bf6733df
38 changed files with 729 additions and 594 deletions


@ -20,19 +20,19 @@ spec:
serviceAccountName: PLACEHOLDER
restartPolicy: OnFailure
containers:
- name: index-job-manager
image: PLACEHOLDER
imagePullPolicy: Always
resources:
requests:
cpu: "0.5"
memory: "0.5Gi"
limits:
cpu: "1"
memory: "1Gi"
envFrom:
- configMapRef:
name: graphrag
command:
- python
- "manage-indexing-jobs.py"
- name: index-job-manager
image: PLACEHOLDER
imagePullPolicy: Always
resources:
requests:
cpu: "0.5"
memory: "0.5Gi"
limits:
cpu: "1"
memory: "1Gi"
envFrom:
- configMapRef:
name: graphrag
command:
- python
- "manage-indexing-jobs.py"


@ -21,17 +21,17 @@ spec:
nodeSelector:
workload: graphrag-indexing
containers:
- name: graphrag
image: PLACEHOLDER
imagePullPolicy: Always
resources:
requests:
cpu: "5"
memory: "36Gi"
limits:
cpu: "8"
memory: "64Gi"
envFrom:
- configMapRef:
name: graphrag
command: [PLACEHOLDER]
- name: graphrag
image: PLACEHOLDER
imagePullPolicy: Always
resources:
requests:
cpu: "5"
memory: "36Gi"
limits:
cpu: "8"
memory: "64Gi"
envFrom:
- configMapRef:
name: graphrag
command: [PLACEHOLDER]


@ -18,9 +18,9 @@ from kubernetes import (
)
from src.api.azure_clients import AzureClientManager
from src.api.common import sanitize_name
from src.logger.logger_singleton import LoggerSingleton
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name
from src.utils.pipeline import PipelineJob
@ -48,7 +48,7 @@ def schedule_indexing_job(index_name: str):
)
except Exception:
reporter = LoggerSingleton().get_instance()
reporter.on_error(
reporter.error(
"Index job manager encountered error scheduling indexing job",
)
# In the event of a catastrophic scheduling failure, something in k8s or the job manifest is likely broken.
@ -68,14 +68,14 @@ def _generate_aks_job_manifest(
The manifest must be valid YAML with certain values replaced by the provided arguments.
"""
# NOTE: this file location is relative to the WORKDIR set in Dockerfile-backend
with open("indexing-job-template.yaml", "r") as f:
with open("index-job.yaml", "r") as f:
manifest = yaml.safe_load(f)
manifest["metadata"]["name"] = f"indexing-job-{sanitize_name(index_name)}"
manifest["spec"]["template"]["spec"]["serviceAccountName"] = service_account_name
manifest["spec"]["template"]["spec"]["containers"][0]["image"] = docker_image_name
manifest["spec"]["template"]["spec"]["containers"][0]["command"] = [
"python",
"run-indexing-job.py",
"src/indexer/indexer.py",
f"-i={index_name}",
]
return manifest
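For context, a minimal sketch (not part of this commit) of how a manifest produced by _generate_aks_job_manifest could be submitted to the cluster; the namespace value and in-cluster config loading are illustrative assumptions:

# Illustrative sketch only: submit the patched manifest as a Kubernetes Job.
from kubernetes import client as kubernetes_client
from kubernetes import config as kubernetes_config

def submit_index_job(manifest: dict, namespace: str = "graphrag") -> None:
    kubernetes_config.load_incluster_config()  # assumes the job manager runs inside AKS
    batch_v1 = kubernetes_client.BatchV1Api()
    batch_v1.create_namespaced_job(namespace=namespace, body=manifest)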

backend/poetry.lock (generated)

@ -215,17 +215,6 @@ files = [
[package.dependencies]
six = "*"
[[package]]
name = "applicationinsights"
version = "0.11.10"
description = "This project extends the Application Insights API surface to support Python."
optional = false
python-versions = "*"
files = [
{file = "applicationinsights-0.11.10-py2.py3-none-any.whl", hash = "sha256:e89a890db1c6906b6a7d0bcfd617dac83974773c64573147c8d6654f9cf2a6ea"},
{file = "applicationinsights-0.11.10.tar.gz", hash = "sha256:0b761f3ef0680acf4731906dfc1807faa6f2a57168ae74592db0084a6099f7b3"},
]
[[package]]
name = "appnope"
version = "0.1.4"
@ -313,6 +302,23 @@ types-python-dateutil = ">=2.8.10"
doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"]
test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"]
[[package]]
name = "asgiref"
version = "3.8.1"
description = "ASGI specs, helper code, and adapters"
optional = false
python-versions = ">=3.8"
files = [
{file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"},
{file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"},
]
[package.dependencies]
typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""}
[package.extras]
tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
[[package]]
name = "asttokens"
version = "2.4.1"
@ -422,6 +428,21 @@ typing-extensions = ">=4.6.0"
[package.extras]
aio = ["aiohttp (>=3.0)"]
[[package]]
name = "azure-core-tracing-opentelemetry"
version = "1.0.0b11"
description = "Microsoft Azure Azure Core OpenTelemetry plugin Library for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "azure-core-tracing-opentelemetry-1.0.0b11.tar.gz", hash = "sha256:a230d1555838b5d07b7594221cd639ea7bc24e29c881e5675e311c6067bad4f5"},
{file = "azure_core_tracing_opentelemetry-1.0.0b11-py3-none-any.whl", hash = "sha256:016cefcaff2900fb5cdb7a8a7abd03e9c266622c06e26b3fe6dafa54c4b48bf5"},
]
[package.dependencies]
azure-core = ">=1.24.0,<2.0.0"
opentelemetry-api = ">=1.12.0,<2.0.0"
[[package]]
name = "azure-cosmos"
version = "4.9.0"
@ -471,6 +492,31 @@ msal = ">=1.30.0"
msal-extensions = ">=1.2.0"
typing-extensions = ">=4.0.0"
[[package]]
name = "azure-monitor-opentelemetry"
version = "1.6.4"
description = "Microsoft Azure Monitor Opentelemetry Distro Client Library for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "azure_monitor_opentelemetry-1.6.4-py3-none-any.whl", hash = "sha256:014142ffa420bc2b287ff3bd30de6c31d64b2846423d011a8280334d7afcb01a"},
{file = "azure_monitor_opentelemetry-1.6.4.tar.gz", hash = "sha256:9f5ce4c666caf1f9b536f8ab4ee207dff94777d568517c74f26e3327f75c3fc3"},
]
[package.dependencies]
azure-core = ">=1.28.0,<2.0.0"
azure-core-tracing-opentelemetry = ">=1.0.0b11,<1.1.0"
azure-monitor-opentelemetry-exporter = ">=1.0.0b31,<1.1.0"
opentelemetry-instrumentation-django = ">=0.49b0,<1.0"
opentelemetry-instrumentation-fastapi = ">=0.49b0,<1.0"
opentelemetry-instrumentation-flask = ">=0.49b0,<1.0"
opentelemetry-instrumentation-psycopg2 = ">=0.49b0,<1.0"
opentelemetry-instrumentation-requests = ">=0.49b0,<1.0"
opentelemetry-instrumentation-urllib = ">=0.49b0,<1.0"
opentelemetry-instrumentation-urllib3 = ">=0.49b0,<1.0"
opentelemetry-resource-detector-azure = ">=0.1.4,<0.2.0"
opentelemetry-sdk = ">=1.28,<2.0"
[[package]]
name = "azure-monitor-opentelemetry-exporter"
version = "1.0.0b33"
@ -3519,13 +3565,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "openai"
version = "1.59.7"
version = "1.59.8"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
files = [
{file = "openai-1.59.7-py3-none-any.whl", hash = "sha256:cfa806556226fa96df7380ab2e29814181d56fea44738c2b0e581b462c268692"},
{file = "openai-1.59.7.tar.gz", hash = "sha256:043603def78c00befb857df9f0a16ee76a3af5984ba40cb7ee5e2f40db4646bf"},
{file = "openai-1.59.8-py3-none-any.whl", hash = "sha256:a8b8ee35c4083b88e6da45406d883cf6bd91a98ab7dd79178b8bc24c8bfb09d9"},
{file = "openai-1.59.8.tar.gz", hash = "sha256:ac4bda5fa9819fdc6127e8ea8a63501f425c587244bc653c7c11a8ad84f953e1"},
]
[package.dependencies]
@ -3557,6 +3603,234 @@ files = [
deprecated = ">=1.2.6"
importlib-metadata = ">=6.0,<=8.5.0"
[[package]]
name = "opentelemetry-instrumentation"
version = "0.50b0"
description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation-0.50b0-py3-none-any.whl", hash = "sha256:b8f9fc8812de36e1c6dffa5bfc6224df258841fb387b6dfe5df15099daa10630"},
{file = "opentelemetry_instrumentation-0.50b0.tar.gz", hash = "sha256:7d98af72de8dec5323e5202e46122e5f908592b22c6d24733aad619f07d82979"},
]
[package.dependencies]
opentelemetry-api = ">=1.4,<2.0"
opentelemetry-semantic-conventions = "0.50b0"
packaging = ">=18.0"
wrapt = ">=1.0.0,<2.0.0"
[[package]]
name = "opentelemetry-instrumentation-asgi"
version = "0.50b0"
description = "ASGI instrumentation for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_asgi-0.50b0-py3-none-any.whl", hash = "sha256:2ba1297f746e55dec5a17fe825689da0613662fb25c004c3965a6c54b1d5be22"},
{file = "opentelemetry_instrumentation_asgi-0.50b0.tar.gz", hash = "sha256:3ca4cb5616ae6a3e8ce86e7d5c360a8d8cc8ed722cf3dc8a5e44300774e87d49"},
]
[package.dependencies]
asgiref = ">=3.0,<4.0"
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[package.extras]
instruments = ["asgiref (>=3.0,<4.0)"]
[[package]]
name = "opentelemetry-instrumentation-dbapi"
version = "0.50b0"
description = "OpenTelemetry Database API instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_dbapi-0.50b0-py3-none-any.whl", hash = "sha256:23a730c3d7372b04b8a9507d2a67c5efbf92ff718eaa002b81ffbaf2b01d270f"},
{file = "opentelemetry_instrumentation_dbapi-0.50b0.tar.gz", hash = "sha256:2603ca39e216893026c185ca8c44c326c0a9a763d5afff2309bd6195c50b7c49"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
wrapt = ">=1.0.0,<2.0.0"
[[package]]
name = "opentelemetry-instrumentation-django"
version = "0.50b0"
description = "OpenTelemetry Instrumentation for Django"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_django-0.50b0-py3-none-any.whl", hash = "sha256:ab7b4cd52b8f12420d968823f6bbfbc2a6ddb2af7a05fcb0d5b6755d338f1915"},
{file = "opentelemetry_instrumentation_django-0.50b0.tar.gz", hash = "sha256:624fd0beb1ac827f2af31709c2da5cb55d8dc899c2449d6e8fcc9fa5538fd56b"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-instrumentation-wsgi = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[package.extras]
asgi = ["opentelemetry-instrumentation-asgi (==0.50b0)"]
instruments = ["django (>=1.10)"]
[[package]]
name = "opentelemetry-instrumentation-fastapi"
version = "0.50b0"
description = "OpenTelemetry FastAPI Instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_fastapi-0.50b0-py3-none-any.whl", hash = "sha256:8f03b738495e4705fbae51a2826389c7369629dace89d0f291c06ffefdff5e52"},
{file = "opentelemetry_instrumentation_fastapi-0.50b0.tar.gz", hash = "sha256:16b9181682136da210295def2bb304a32fb9bdee9a935cdc9da43567f7c1149e"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-instrumentation-asgi = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[package.extras]
instruments = ["fastapi (>=0.58,<1.0)"]
[[package]]
name = "opentelemetry-instrumentation-flask"
version = "0.50b0"
description = "Flask instrumentation for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_flask-0.50b0-py3-none-any.whl", hash = "sha256:db7fb40191145f4356a793922c3fc80a33689e6a7c7c4c6def8aa1eedb0ac42a"},
{file = "opentelemetry_instrumentation_flask-0.50b0.tar.gz", hash = "sha256:e56a820b1d43fdd5a57f7b481c4d6365210a48a1312c83af4185bc636977755f"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-instrumentation-wsgi = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
packaging = ">=21.0"
[package.extras]
instruments = ["flask (>=1.0)"]
[[package]]
name = "opentelemetry-instrumentation-psycopg2"
version = "0.50b0"
description = "OpenTelemetry psycopg2 instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_psycopg2-0.50b0-py3-none-any.whl", hash = "sha256:448297e63320711b5571f64bcf5d67ecf4856454c36d3bff6c3d01a4f8a48d18"},
{file = "opentelemetry_instrumentation_psycopg2-0.50b0.tar.gz", hash = "sha256:86f8e507e98d8824f51bbc3c62121dbd4b8286063362f10b9dfa035a8da49f0b"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-instrumentation-dbapi = "0.50b0"
[package.extras]
instruments = ["psycopg2 (>=2.7.3.1)"]
[[package]]
name = "opentelemetry-instrumentation-requests"
version = "0.50b0"
description = "OpenTelemetry requests instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_requests-0.50b0-py3-none-any.whl", hash = "sha256:2c60a890988d6765de9230004d0af9071b3b2e1ddba4ca3b631cfb8a1722208d"},
{file = "opentelemetry_instrumentation_requests-0.50b0.tar.gz", hash = "sha256:f8088c76f757985b492aad33331d21aec2f99c197472a57091c2e986a4b7ec8b"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[package.extras]
instruments = ["requests (>=2.0,<3.0)"]
[[package]]
name = "opentelemetry-instrumentation-urllib"
version = "0.50b0"
description = "OpenTelemetry urllib instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_urllib-0.50b0-py3-none-any.whl", hash = "sha256:55024940fd41fbdd5a6ab5b6397660900b7a75e23f9ff7f61b4ae1279710a3ec"},
{file = "opentelemetry_instrumentation_urllib-0.50b0.tar.gz", hash = "sha256:af3e9710635c3f8a5ec38adc772dfef0c1022d0196007baf4b74504e920b5d31"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[[package]]
name = "opentelemetry-instrumentation-urllib3"
version = "0.50b0"
description = "OpenTelemetry urllib3 instrumentation"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_urllib3-0.50b0-py3-none-any.whl", hash = "sha256:c679b3908645b7d4d07c36960fe0efef490b403983e314108450146cc89bd675"},
{file = "opentelemetry_instrumentation_urllib3-0.50b0.tar.gz", hash = "sha256:2c4a1d9f128eaf753871b1d90659c744691d039a6601ba546081347ae192bd0e"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
wrapt = ">=1.0.0,<2.0.0"
[package.extras]
instruments = ["urllib3 (>=1.0.0,<3.0.0)"]
[[package]]
name = "opentelemetry-instrumentation-wsgi"
version = "0.50b0"
description = "WSGI Middleware for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_instrumentation_wsgi-0.50b0-py3-none-any.whl", hash = "sha256:4bc0fdf52b603507d6170a25504f0ceea358d7e90a2c0e8794b7b7eca5ea355c"},
{file = "opentelemetry_instrumentation_wsgi-0.50b0.tar.gz", hash = "sha256:c25b5f1b664d984a41546a34cf2f893dcde6cf56922f88c475864e7df37edf4a"},
]
[package.dependencies]
opentelemetry-api = ">=1.12,<2.0"
opentelemetry-instrumentation = "0.50b0"
opentelemetry-semantic-conventions = "0.50b0"
opentelemetry-util-http = "0.50b0"
[[package]]
name = "opentelemetry-resource-detector-azure"
version = "0.1.5"
description = "Azure Resource Detector for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_resource_detector_azure-0.1.5-py3-none-any.whl", hash = "sha256:4dcc5d54ab5c3b11226af39509bc98979a8b9e0f8a24c1b888783755d3bf00eb"},
{file = "opentelemetry_resource_detector_azure-0.1.5.tar.gz", hash = "sha256:e0ba658a87c69eebc806e75398cd0e9f68a8898ea62de99bc1b7083136403710"},
]
[package.dependencies]
opentelemetry-sdk = ">=1.21,<2.0"
[[package]]
name = "opentelemetry-sdk"
version = "1.29.0"
@ -3588,6 +3862,17 @@ files = [
deprecated = ">=1.2.6"
opentelemetry-api = "1.29.0"
[[package]]
name = "opentelemetry-util-http"
version = "0.50b0"
description = "Web util for OpenTelemetry"
optional = false
python-versions = ">=3.8"
files = [
{file = "opentelemetry_util_http-0.50b0-py3-none-any.whl", hash = "sha256:21f8aedac861ffa3b850f8c0a6c373026189eb8630ac6e14a2bf8c55695cc090"},
{file = "opentelemetry_util_http-0.50b0.tar.gz", hash = "sha256:dc4606027e1bc02aabb9533cc330dd43f874fca492e4175c31d7154f341754af"},
]
[[package]]
name = "overrides"
version = "7.7.0"
@ -5874,13 +6159,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
[[package]]
name = "virtualenv"
version = "20.29.0"
version = "20.29.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
files = [
{file = "virtualenv-20.29.0-py3-none-any.whl", hash = "sha256:c12311863497992dc4b8644f8ea82d3b35bb7ef8ee82e6630d76d0197c39baf9"},
{file = "virtualenv-20.29.0.tar.gz", hash = "sha256:6345e1ff19d4b1296954cee076baaf58ff2a12a84a338c62b02eda39f20aa982"},
{file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
{file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
]
[package.dependencies]
@ -6172,4 +6457,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "~3.10"
content-hash = "f7039aa8ec1bb1d1f436703fcdb21f83711fd77c72ead5aa94f3d5fccd47a699"
content-hash = "a7740ff5c5d60171f025a676b625678331eed3ac55c763c8057adb49de56f734"


@ -40,12 +40,11 @@ wikipedia = ">=1.4.0"
[tool.poetry.group.backend.dependencies]
adlfs = ">=2024.7.0"
applicationinsights = ">=0.11.10"
attrs = ">=23.2.0"
azure-core = ">=1.30.1"
azure-cosmos = ">=4.5.1"
azure-identity = ">=1.15.0"
azure-monitor-opentelemetry-exporter = "*"
azure-monitor-opentelemetry = "^1.6.4"
azure-search-documents = ">=11.4.0"
azure-storage-blob = ">=12.19.0"
environs = ">=9.5.0"
@ -54,12 +53,10 @@ fastapi-offline = ">=1.7.3"
fastparquet = ">=2023.10.1"
fsspec = ">=2024.2.0"
graphrag = { git = "https://github.com/microsoft/graphrag.git", branch = "main" }
graspologic = ">=3.3.0"
httpx = ">=0.25.2"
kubernetes = ">=29.0.0"
networkx = ">=3.2.1"
nltk = "*"
opentelemetry-sdk = ">=1.27.0"
pandas = ">=2.2.1"
pyaml-env = ">=1.2.1"
pyarrow = ">=15.0.0"


@ -6,6 +6,6 @@ asyncio_mode=auto
; If executing these pytests locally, users may need to modify the cosmosdb connection string to use http protocol instead of https.
; This depends on how the cosmosdb emulator has been configured (by the user) to run.
env =
COSMOS_CONNECTION_STRING=AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==
COSMOS_CONNECTION_STRING=AccountEndpoint=http://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==
STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;TableEndpoint=http://127.0.0.1:10002/devstoreaccount1;
TESTING=1


@ -1,18 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import asyncio
from src import main # noqa: F401
from src.api.index import _start_indexing_pipeline
parser = argparse.ArgumentParser(description="Kickoff indexing job.")
parser.add_argument("-i", "--index-name", required=True)
args = parser.parse_args()
asyncio.run(
_start_indexing_pipeline(
index_name=args.index_name,
)
)


@ -14,17 +14,17 @@ from fastapi import (
)
from src.api.azure_clients import AzureClientManager
from src.api.common import (
from src.logger import LoggerSingleton
from src.typing.models import (
BaseResponse,
StorageNameList,
)
from src.utils.common import (
delete_blob_container,
delete_cosmos_container_item,
sanitize_name,
validate_blob_container_name,
)
from src.logger import LoggerSingleton
from src.models import (
BaseResponse,
StorageNameList,
)
data_route = APIRouter(
prefix="/data",
@ -53,7 +53,7 @@ async def get_all_data_storage_containers():
items.append(item["human_readable_name"])
except Exception:
reporter = LoggerSingleton().get_instance()
reporter.on_error("Error getting list of blob containers.")
reporter.error("Error getting list of blob containers.")
raise HTTPException(
status_code=500, detail="Error getting list of blob containers."
)
@ -171,7 +171,7 @@ async def upload_files(
return BaseResponse(status="File upload successful.")
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Error uploading files.", details={"files": files})
logger.error("Error uploading files.", details={"files": files})
raise HTTPException(
status_code=500,
detail=f"Error uploading files to container '{storage_name}'.",
@ -197,7 +197,7 @@ async def delete_files(storage_name: str):
delete_cosmos_container_item("container-store", sanitized_storage_name)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
f"Error deleting container {storage_name}.",
details={"Container": storage_name},
)


@ -8,11 +8,11 @@ from fastapi import (
from fastapi.responses import StreamingResponse
from src.api.azure_clients import AzureClientManager
from src.api.common import (
from src.logger import LoggerSingleton
from src.utils.common import (
sanitize_name,
validate_index_file_exist,
)
from src.logger import LoggerSingleton
graph_route = APIRouter(
prefix="/graph",
@ -44,7 +44,7 @@ async def get_graphml_file(index_name: str):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not retrieve graphml file")
logger.error("Could not retrieve graphml file")
raise HTTPException(
status_code=500,
detail=f"Could not retrieve graphml file for index '{index_name}'.",


@ -1,13 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import os
import traceback
from time import time
import graphrag.api as api
import yaml
from azure.identity import DefaultAzureCredential
from azure.search.documents.indexes import SearchIndexClient
from fastapi import (
@ -15,9 +11,6 @@ from fastapi import (
HTTPException,
UploadFile,
)
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.bootstrap import bootstrap
from graphrag.index.create_pipeline_config import create_pipeline_config
from kubernetes import (
client as kubernetes_client,
)
@ -26,23 +19,18 @@ from kubernetes import (
)
from src.api.azure_clients import AzureClientManager
from src.api.common import (
delete_blob_container,
sanitize_name,
validate_blob_container_name,
)
from src.logger import (
LoggerSingleton,
PipelineJobWorkflowCallbacks,
Reporters,
load_pipeline_logger,
)
from src.models import (
from src.logger import LoggerSingleton
from src.typing.models import (
BaseResponse,
IndexNameList,
IndexStatusResponse,
)
from src.typing.pipeline import PipelineJobState
from src.utils.common import (
delete_blob_container,
sanitize_name,
validate_blob_container_name,
)
from src.utils.pipeline import PipelineJob
index_route = APIRouter(
@ -57,7 +45,7 @@ index_route = APIRouter(
response_model=BaseResponse,
responses={200: {"model": BaseResponse}},
)
async def setup_indexing_pipeline(
async def schedule_indexing_job(
storage_name: str,
index_name: str,
entity_extraction_prompt: UploadFile | None = None,
@ -148,173 +136,6 @@ async def setup_indexing_pipeline(
return BaseResponse(status="Indexing job scheduled")
async def _start_indexing_pipeline(index_name: str):
# get sanitized name
sanitized_index_name = sanitize_name(index_name)
# update or create new item in container-store in cosmosDB
azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
if not blob_service_client.get_container_client(sanitized_index_name).exists():
blob_service_client.create_container(sanitized_index_name)
cosmos_container_client = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="container-store"
)
cosmos_container_client.upsert_item({
"id": sanitized_index_name,
"human_readable_name": index_name,
"type": "index",
})
logger = LoggerSingleton().get_instance()
pipelinejob = PipelineJob()
pipeline_job = pipelinejob.load_item(sanitized_index_name)
sanitized_storage_name = pipeline_job.sanitized_storage_name
storage_name = pipeline_job.human_readable_index_name
# download nltk dependencies
bootstrap()
# load custom pipeline settings
this_directory = os.path.dirname(
os.path.abspath(inspect.getfile(inspect.currentframe()))
)
data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
# dynamically set some values
data["input"]["container_name"] = sanitized_storage_name
data["storage"]["container_name"] = sanitized_index_name
data["reporting"]["container_name"] = sanitized_index_name
data["cache"]["container_name"] = sanitized_index_name
if "vector_store" in data["embeddings"]:
data["embeddings"]["vector_store"]["collection_name"] = (
f"{sanitized_index_name}_description_embedding"
)
# set prompt for entity extraction
if pipeline_job.entity_extraction_prompt:
fname = "entity-extraction-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.entity_extraction_prompt)
data["entity_extraction"]["prompt"] = fname
else:
data.pop("entity_extraction")
# set prompt for summarize descriptions
if pipeline_job.summarize_descriptions_prompt:
fname = "summarize-descriptions-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.summarize_descriptions_prompt)
data["summarize_descriptions"]["prompt"] = fname
else:
data.pop("summarize_descriptions")
# set prompt for community report
if pipeline_job.community_report_prompt:
fname = "community-report-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.community_report_prompt)
data["community_reports"]["prompt"] = fname
else:
data.pop("community_reports")
# generate a default GraphRagConfig and override with custom settings
parameters = create_graphrag_config(data, ".")
# reset pipeline job details
pipeline_job.status = PipelineJobState.RUNNING
pipeline_job.all_workflows = []
pipeline_job.completed_workflows = []
pipeline_job.failed_workflows = []
pipeline_config = create_pipeline_config(parameters)
for workflow in pipeline_config.workflows:
pipeline_job.all_workflows.append(workflow.name)
# create new loggers/callbacks just for this job
loggers = []
logger_names = os.getenv("REPORTERS", Reporters.CONSOLE.name.upper()).split(",")
for logger_name in logger_names:
try:
loggers.append(Reporters[logger_name.upper()])
except KeyError:
raise ValueError(f"Unknown logger type: {logger_name}")
workflow_callbacks = load_pipeline_logger(
index_name=index_name,
num_workflow_steps=len(pipeline_job.all_workflows),
reporting_dir=sanitized_index_name,
reporters=loggers,
)
# add pipeline job callback to monitor job progress
pipeline_job_callback = PipelineJobWorkflowCallbacks(pipeline_job)
# run the pipeline
try:
await api.build_index(
config=parameters,
callbacks=[workflow_callbacks, pipeline_job_callback],
)
# if job is done, check if any workflow steps failed
if len(pipeline_job.failed_workflows) > 0:
pipeline_job.status = PipelineJobState.FAILED
workflow_callbacks.on_log(
message=f"Indexing pipeline encountered error for index'{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline encountered error",
},
)
else:
# record the workflow completion
pipeline_job.status = PipelineJobState.COMPLETE
pipeline_job.percent_complete = 100
workflow_callbacks.on_log(
message=f"Indexing pipeline complete for index'{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline complete",
},
)
pipeline_job.progress = (
f"{len(pipeline_job.completed_workflows)} out of "
f"{len(pipeline_job.all_workflows)} workflows completed successfully."
)
del workflow_callbacks # garbage collect
if pipeline_job.status == PipelineJobState.FAILED:
exit(1) # signal to AKS that indexing job failed
except Exception as e:
pipeline_job.status = PipelineJobState.FAILED
# update failed state in cosmos db
error_details = {
"index": index_name,
"storage_name": storage_name,
}
# log error in local index directory logs
workflow_callbacks.on_error(
message=f"Indexing pipeline failed for index '{index_name}'.",
cause=e,
stack=traceback.format_exc(),
details=error_details,
)
# log error in global index directory logs
logger.on_error(
message=f"Indexing pipeline failed for index '{index_name}'.",
cause=e,
stack=traceback.format_exc(),
details=error_details,
)
raise HTTPException(
status_code=500,
detail=f"Error encountered during indexing job for index '{index_name}'.",
)
@index_route.get(
"",
summary="Get all indexes",
@ -336,7 +157,7 @@ async def get_all_indexes():
items.append(item["human_readable_name"])
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Error retrieving index names")
logger.error("Error retrieving index names")
return IndexNameList(index_name=items)
@ -367,7 +188,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
batch_v1 = kubernetes_client.BatchV1Api()
batch_v1.delete_namespaced_job(name=job_name, namespace=namespace)
except Exception:
logger.on_error(
logger.error(
message=f"Error deleting k8s job {job_name}.",
details={"container": job_name},
)
@ -378,7 +199,7 @@ def _delete_k8s_job(job_name: str, namespace: str) -> None:
if job_pod:
core_v1.delete_namespaced_pod(job_pod, namespace=namespace)
except Exception:
logger.on_error(
logger.error(
message=f"Error deleting k8s pod for job {job_name}.",
details={"container": job_name},
)
@ -441,7 +262,7 @@ async def delete_index(index_name: str):
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
message=f"Error encountered while deleting all data for index {index_name}.",
stack=None,
details={"container": index_name},


@ -14,17 +14,13 @@ from fastapi import (
from graphrag.config.create_graphrag_config import create_graphrag_config
from src.api.azure_clients import AzureClientManager
from src.api.common import (
sanitize_name,
)
from src.logger import LoggerSingleton
from src.utils.common import sanitize_name
index_configuration_route = APIRouter(
prefix="/index/config", tags=["Index Configuration"]
)
prompt_tuning_route = APIRouter(prefix="/index/config", tags=["Index Configuration"])
@index_configuration_route.get(
@prompt_tuning_route.get(
"/prompts",
summary="Generate prompts from user-provided data.",
description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
@ -66,7 +62,7 @@ async def generate_prompts(storage_name: str, limit: int = 5):
error_details = {
"storage_name": storage_name,
}
logger.on_error(
logger.error(
message="Auto-prompt generation failed.",
cause=e,
stack=traceback.format_exc(),


@ -26,17 +26,17 @@ from graphrag.vector_stores.base import (
)
from src.api.azure_clients import AzureClientManager
from src.api.common import (
sanitize_name,
validate_index_file_exist,
)
from src.logger import LoggerSingleton
from src.models import (
from src.typing.models import (
GraphRequest,
GraphResponse,
)
from src.typing.pipeline import PipelineJobState
from src.utils import query as query_helper
from src.utils.common import (
get_df,
sanitize_name,
validate_index_file_exist,
)
from src.utils.pipeline import PipelineJob
query_route = APIRouter(
@ -116,7 +116,7 @@ async def global_query(request: GraphRequest):
# read the parquet files into DataFrames and add provenance information
# note that nodes need to be set before communities so that max community id makes sense
nodes_df = query_helper.get_df(nodes_table_path)
nodes_df = get_df(nodes_table_path)
for i in nodes_df["human_readable_id"]:
links["nodes"][i + max_vals["nodes"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -134,7 +134,7 @@ async def global_query(request: GraphRequest):
max_vals["nodes"] = nodes_df["human_readable_id"].max()
nodes_dfs.append(nodes_df)
community_df = query_helper.get_df(community_report_table_path)
community_df = get_df(community_report_table_path)
for i in community_df["community"].astype(int):
links["community"][i + max_vals["community"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -146,7 +146,7 @@ async def global_query(request: GraphRequest):
max_vals["community"] = community_df["community"].astype(int).max()
community_dfs.append(community_df)
entities_df = query_helper.get_df(entities_table_path)
entities_df = get_df(entities_table_path)
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -197,7 +197,7 @@ async def global_query(request: GraphRequest):
return GraphResponse(result=result[0], context_data=context_data)
except Exception as e:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
message="Could not perform global search.",
cause=e,
stack=traceback.format_exc(),
@ -287,7 +287,7 @@ async def local_query(request: GraphRequest):
# read the parquet files into DataFrames and add provenance information
# note that nodes need to be set before communities so that max community id makes sense
nodes_df = query_helper.get_df(nodes_table_path)
nodes_df = get_df(nodes_table_path)
for i in nodes_df["human_readable_id"]:
links["nodes"][i + max_vals["nodes"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -306,7 +306,7 @@ async def local_query(request: GraphRequest):
max_vals["nodes"] = nodes_df["human_readable_id"].max()
nodes_dfs.append(nodes_df)
community_df = query_helper.get_df(community_report_table_path)
community_df = get_df(community_report_table_path)
for i in community_df["community"].astype(int):
links["community"][i + max_vals["community"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -318,7 +318,7 @@ async def local_query(request: GraphRequest):
max_vals["community"] = community_df["community"].astype(int).max()
community_dfs.append(community_df)
entities_df = query_helper.get_df(entities_table_path)
entities_df = get_df(entities_table_path)
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -334,7 +334,7 @@ async def local_query(request: GraphRequest):
max_vals["entities"] = entities_df["human_readable_id"].max()
entities_dfs.append(entities_df)
relationships_df = query_helper.get_df(relationships_table_path)
relationships_df = get_df(relationships_table_path)
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -361,13 +361,13 @@ async def local_query(request: GraphRequest):
)
relationships_dfs.append(relationships_df)
text_units_df = query_helper.get_df(text_units_table_path)
text_units_df = get_df(text_units_table_path)
text_units_df["id"] = text_units_df["id"].apply(lambda x: f"{x}-{index_name}")
text_units_dfs.append(text_units_df)
index_container_client = blob_service_client.get_container_client(index_name)
if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
covariates_df = query_helper.get_df(covariates_table_path)
covariates_df = get_df(covariates_table_path)
if i in covariates_df["human_readable_id"].astype(int):
links["covariates"][i + max_vals["covariates"] + 1] = {
"index_name": sanitized_index_names_link[index_name],


@ -22,14 +22,14 @@ from graphrag.api.query import (
from graphrag.config import create_graphrag_config
from src.api.azure_clients import AzureClientManager
from src.api.common import (
from src.api.query import _is_index_complete
from src.logger import LoggerSingleton
from src.typing.models import GraphRequest
from src.utils.common import (
get_df,
sanitize_name,
validate_index_file_exist,
)
from src.api.query import _is_index_complete
from src.logger import LoggerSingleton
from src.models import GraphRequest
from src.utils import query as query_helper
from .query import _get_embedding_description_store, _update_context
@ -107,7 +107,7 @@ async def global_search_streaming(request: GraphRequest):
# read parquet files into DataFrames and add provenance information
# note that nodes need to be set before communities so that max community id makes sense
nodes_df = query_helper.get_df(nodes_table_path)
nodes_df = get_df(nodes_table_path)
for i in nodes_df["human_readable_id"]:
links["nodes"][i + max_vals["nodes"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -125,7 +125,7 @@ async def global_search_streaming(request: GraphRequest):
max_vals["nodes"] = nodes_df["human_readable_id"].max()
nodes_dfs.append(nodes_df)
community_df = query_helper.get_df(community_report_table_path)
community_df = get_df(community_report_table_path)
for i in community_df["community"].astype(int):
links["community"][i + max_vals["community"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -137,7 +137,7 @@ async def global_search_streaming(request: GraphRequest):
max_vals["community"] = community_df["community"].astype(int).max()
community_dfs.append(community_df)
entities_df = query_helper.get_df(entities_table_path)
entities_df = get_df(entities_table_path)
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -188,7 +188,7 @@ async def global_search_streaming(request: GraphRequest):
)
except Exception as e:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
message="Error encountered while streaming global search response",
cause=e,
stack=traceback.format_exc(),
@ -277,7 +277,7 @@ async def local_search_streaming(request: GraphRequest):
# read the parquet files into DataFrames and add provenance information
# note that nodes need to be set before communities so that max community id makes sense
nodes_df = query_helper.get_df(nodes_table_path)
nodes_df = get_df(nodes_table_path)
for i in nodes_df["human_readable_id"]:
links["nodes"][i + max_vals["nodes"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -296,7 +296,7 @@ async def local_search_streaming(request: GraphRequest):
max_vals["nodes"] = nodes_df["human_readable_id"].max()
nodes_dfs.append(nodes_df)
community_df = query_helper.get_df(community_report_table_path)
community_df = get_df(community_report_table_path)
for i in community_df["community"].astype(int):
links["community"][i + max_vals["community"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -308,7 +308,7 @@ async def local_search_streaming(request: GraphRequest):
max_vals["community"] = community_df["community"].astype(int).max()
community_dfs.append(community_df)
entities_df = query_helper.get_df(entities_table_path)
entities_df = get_df(entities_table_path)
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -326,7 +326,7 @@ async def local_search_streaming(request: GraphRequest):
max_vals["entities"] = entities_df["human_readable_id"].max()
entities_dfs.append(entities_df)
relationships_df = query_helper.get_df(relationships_table_path)
relationships_df = get_df(relationships_table_path)
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -353,7 +353,7 @@ async def local_search_streaming(request: GraphRequest):
)
relationships_dfs.append(relationships_df)
text_units_df = query_helper.get_df(text_units_table_path)
text_units_df = get_df(text_units_table_path)
text_units_df["id"] = text_units_df["id"].apply(
lambda x: f"{x}-{index_name}"
)
@ -363,7 +363,7 @@ async def local_search_streaming(request: GraphRequest):
index_name
)
if index_container_client.get_blob_client(COVARIATES_TABLE).exists():
covariates_df = query_helper.get_df(covariates_table_path)
covariates_df = get_df(covariates_table_path)
if i in covariates_df["human_readable_id"].astype(int):
links["covariates"][i + max_vals["covariates"] + 1] = {
"index_name": sanitized_index_names_link[index_name],
@ -431,7 +431,7 @@ async def local_search_streaming(request: GraphRequest):
)
except Exception as e:
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
message="Error encountered while streaming local search response",
cause=e,
stack=traceback.format_exc(),


@ -5,19 +5,19 @@
import pandas as pd
from fastapi import APIRouter, HTTPException
from src.api.common import (
get_pandas_storage_options,
sanitize_name,
validate_index_file_exist,
)
from src.logger import LoggerSingleton
from src.models import (
from src.typing.models import (
ClaimResponse,
EntityResponse,
RelationshipResponse,
ReportResponse,
TextUnitResponse,
)
from src.utils.common import (
pandas_storage_options,
sanitize_name,
validate_index_file_exist,
)
source_route = APIRouter(
prefix="/source",
@ -46,7 +46,7 @@ async def get_report_info(index_name: str, report_id: str):
try:
report_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
# check if report_id exists in the index
if not report_table["community"].isin([report_id]).any():
@ -62,7 +62,7 @@ async def get_report_info(index_name: str, report_id: str):
return ReportResponse(text=report_content)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not get report.")
logger.error("Could not get report.")
raise HTTPException(
status_code=500,
detail=f"Error retrieving report '{report_id}' from index '{index_name}'.",
@ -83,11 +83,11 @@ async def get_chunk_info(index_name: str, text_unit_id: str):
try:
text_units = pd.read_parquet(
f"abfs://{sanitized_index_name}/{TEXT_UNITS_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
docs = pd.read_parquet(
f"abfs://{sanitized_index_name}/{DOCUMENTS_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
# rename columns for easy joining
docs = docs[["id", "title"]].rename(
@ -115,7 +115,7 @@ async def get_chunk_info(index_name: str, text_unit_id: str):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not get text chunk.")
logger.error("Could not get text chunk.")
raise HTTPException(
status_code=500,
detail=f"Error retrieving text chunk '{text_unit_id}' from index '{index_name}'.",
@ -135,7 +135,7 @@ async def get_entity_info(index_name: str, entity_id: int):
try:
entity_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
# check if entity_id exists in the index
if not entity_table["human_readable_id"].isin([entity_id]).any():
@ -148,7 +148,7 @@ async def get_entity_info(index_name: str, entity_id: int):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not get entity")
logger.error("Could not get entity")
raise HTTPException(
status_code=500,
detail=f"Error retrieving entity '{entity_id}' from index '{index_name}'.",
@ -175,7 +175,7 @@ async def get_claim_info(index_name: str, claim_id: int):
try:
claims_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{COVARIATES_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
claims_table.human_readable_id = claims_table.human_readable_id.astype(
float
@ -193,7 +193,7 @@ async def get_claim_info(index_name: str, claim_id: int):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not get claim.")
logger.error("Could not get claim.")
raise HTTPException(
status_code=500,
detail=f"Error retrieving claim '{claim_id}' from index '{index_name}'.",
@ -214,11 +214,11 @@ async def get_relationship_info(index_name: str, relationship_id: int):
try:
relationship_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{RELATIONSHIPS_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
entity_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
storage_options=get_pandas_storage_options(),
storage_options=pandas_storage_options(),
)
row = relationship_table[
relationship_table.human_readable_id == str(relationship_id)
@ -239,7 +239,7 @@ async def get_relationship_info(index_name: str, relationship_id: int):
)
except Exception:
logger = LoggerSingleton().get_instance()
logger.on_error("Could not get relationship.")
logger.error("Could not get relationship.")
raise HTTPException(
status_code=500,
detail=f"Error retrieving relationship '{relationship_id}' from index '{index_name}'.",


@ -0,0 +1,189 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import asyncio
import inspect
import os
import traceback
import graphrag.api as api
import yaml
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.create_pipeline_config import create_pipeline_config
from src.api.azure_clients import AzureClientManager
from src.logger import (
Logger,
PipelineJobUpdater,
load_pipeline_logger,
)
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name
from src.utils.pipeline import PipelineJob
def start_indexing_job(index_name: str):
print("Start indexing job...")
# get sanitized name
sanitized_index_name = sanitize_name(index_name)
# update or create new item in container-store in cosmosDB
azure_client_manager = AzureClientManager()
blob_service_client = azure_client_manager.get_blob_service_client()
if not blob_service_client.get_container_client(sanitized_index_name).exists():
blob_service_client.create_container(sanitized_index_name)
cosmos_container_client = azure_client_manager.get_cosmos_container_client(
database="graphrag", container="container-store"
)
cosmos_container_client.upsert_item({
"id": sanitized_index_name,
"human_readable_name": index_name,
"type": "index",
})
print("Initialize pipeline job...")
pipelinejob = PipelineJob()
pipeline_job = pipelinejob.load_item(sanitized_index_name)
sanitized_storage_name = pipeline_job.sanitized_storage_name
storage_name = pipeline_job.human_readable_index_name
# load custom pipeline settings
this_directory = os.path.dirname(
os.path.abspath(inspect.getfile(inspect.currentframe()))
)
data = yaml.safe_load(open(f"{this_directory}/settings.yaml"))
# dynamically set some values
data["input"]["container_name"] = sanitized_storage_name
data["storage"]["container_name"] = sanitized_index_name
data["reporting"]["container_name"] = sanitized_index_name
data["cache"]["container_name"] = sanitized_index_name
if "vector_store" in data["embeddings"]:
data["embeddings"]["vector_store"]["collection_name"] = (
f"{sanitized_index_name}_description_embedding"
)
# set prompt for entity extraction
if pipeline_job.entity_extraction_prompt:
fname = "entity-extraction-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.entity_extraction_prompt)
data["entity_extraction"]["prompt"] = fname
else:
data.pop("entity_extraction")
# set prompt for summarize descriptions
if pipeline_job.summarize_descriptions_prompt:
fname = "summarize-descriptions-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.summarize_descriptions_prompt)
data["summarize_descriptions"]["prompt"] = fname
else:
data.pop("summarize_descriptions")
# set prompt for community report
if pipeline_job.community_report_prompt:
fname = "community-report-prompt.txt"
with open(fname, "w") as outfile:
outfile.write(pipeline_job.community_report_prompt)
data["community_reports"]["prompt"] = fname
else:
data.pop("community_reports")
# generate default graphrag config parameters and override with custom settings
parameters = create_graphrag_config(data, ".")
# reset pipeline job details
pipeline_job.status = PipelineJobState.RUNNING
pipeline_job.all_workflows = []
pipeline_job.completed_workflows = []
pipeline_job.failed_workflows = []
pipeline_config = create_pipeline_config(parameters)
for workflow in pipeline_config.workflows:
pipeline_job.all_workflows.append(workflow.name)
# create new loggers/callbacks just for this job
logger_names = []
for logger_type in ["BLOB", "CONSOLE", "APP_INSIGHTS"]:
logger_names.append(Logger[logger_type.upper()])
print("Creating generic loggers...")
logger: WorkflowCallbacks = load_pipeline_logger(
logging_dir=sanitized_index_name,
index_name=index_name,
num_workflow_steps=len(pipeline_job.all_workflows),
loggers=logger_names,
)
# create pipeline job updater to monitor job progress
print("Creating pipeline job updater...")
pipeline_job_updater = PipelineJobUpdater(pipeline_job)
# run the pipeline
try:
print("Building index...")
asyncio.run(
api.build_index(
config=parameters,
callbacks=[logger, pipeline_job_updater],
)
)
print("Index building complete")
# if job is done, check if any pipeline steps failed
if len(pipeline_job.failed_workflows) > 0:
print("Indexing pipeline encountered error.")
pipeline_job.status = PipelineJobState.FAILED
logger.error(
message=f"Indexing pipeline encountered error for index'{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline encountered error",
},
)
else:
print("Indexing pipeline complete.")
# record the pipeline completion
pipeline_job.status = PipelineJobState.COMPLETE
pipeline_job.percent_complete = 100
logger.log(
message=f"Indexing pipeline complete for index'{index_name}'.",
details={
"index": index_name,
"storage_name": storage_name,
"status_message": "indexing pipeline complete",
},
)
pipeline_job.progress = (
f"{len(pipeline_job.completed_workflows)} out of "
f"{len(pipeline_job.all_workflows)} workflows completed successfully."
)
if pipeline_job.status == PipelineJobState.FAILED:
exit(1) # signal to AKS that indexing job failed
except Exception as e:
pipeline_job.status = PipelineJobState.FAILED
# update failed state in cosmos db
error_details = {
"index": index_name,
"storage_name": storage_name,
}
# log error in local index directory logs
logger.error(
message=f"Indexing pipeline failed for index '{index_name}'.",
cause=e,
stack=traceback.format_exc(),
details=error_details,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Build a graphrag index.")
parser.add_argument("-i", "--index-name", required=True)
args = parser.parse_args()
start_indexing_job(index_name=args.index_name)
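The script above is the new entry point that the scheduled AKS job runs (the manifest command earlier in this diff is python src/indexer/indexer.py -i=<index_name>). A minimal sketch of an equivalent in-process call for local debugging, assuming the module is importable as src.indexer.indexer:

# Illustrative sketch only: invoke the refactored entry point directly.
from src.indexer.indexer import start_indexing_job

start_indexing_job(index_name="myindex")  # "myindex" is a placeholder index name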


@ -17,8 +17,8 @@ llm:
model_supports_json: True
tokens_per_minute: 80_000
requests_per_minute: 480
concurrent_requests: 25
max_retries: 25
concurrent_requests: 50
max_retries: 250
max_retry_wait: 60.0
sleep_on_rate_limit_recommendation: True
@ -45,7 +45,7 @@ embeddings:
deployment_name: $GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME
cognitive_services_endpoint: $GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT
tokens_per_minute: 350_000
requests_per_minute: 2100
requests_per_minute: 2_100
###################### Input settings ######################
input:


@ -7,20 +7,20 @@ from src.logger.application_insights_workflow_callbacks import (
from src.logger.console_workflow_callbacks import ConsoleWorkflowCallbacks
from src.logger.load_logger import load_pipeline_logger
from src.logger.logger_singleton import LoggerSingleton
from src.logger.pipeline_job_workflow_callbacks import PipelineJobWorkflowCallbacks
from src.logger.pipeline_job_updater import PipelineJobUpdater
from src.logger.typing import (
Logger,
PipelineAppInsightsReportingConfig,
PipelineReportingConfigTypes,
Reporters,
)
__all__ = [
"Reporters",
"Logger",
"ApplicationInsightsWorkflowCallbacks",
"ConsoleWorkflowCallbacks",
"LoggerSingleton",
"PipelineAppInsightsReportingConfig",
"PipelineJobWorkflowCallbacks",
"PipelineJobUpdater",
"PipelineReportingConfigTypes",
"load_pipeline_logger",
]


@ -1,28 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import hashlib
import logging
import time
# from dataclasses import asdict
from typing import (
Any,
Dict,
Optional,
)
from azure.monitor.opentelemetry.exporter import AzureMonitorLogExporter
from azure.identity import DefaultAzureCredential
from azure.monitor.opentelemetry import configure_azure_monitor
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from opentelemetry._logs import (
get_logger_provider,
set_logger_provider,
)
from opentelemetry.sdk._logs import (
LoggerProvider,
LoggingHandler,
)
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
@ -31,7 +19,6 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
_logger: logging.Logger
_logger_name: str
_logger_level: int
_logger_level_name: str
_properties: Dict[str, Any]
_workflow_name: str
_index_name: str
@ -40,9 +27,7 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
def __init__(
self,
connection_string: str,
logger_name: str | None = None,
logger_level: int = logging.INFO,
logger_name: str = "graphrag-accelerator",
index_name: str = "",
num_workflow_steps: int = 0,
properties: Dict[str, Any] = {},
@ -51,60 +36,31 @@ class ApplicationInsightsWorkflowCallbacks(NoopWorkflowCallbacks):
Initialize the ApplicationInsightsWorkflowCallbacks.
Args:
connection_string (str): The connection string for the App Insights instance.
logger_name (str | None, optional): The name of the logger. Defaults to None.
logger_level (int, optional): The logging level. Defaults to logging.INFO.
index_name (str, optional): The name of an index. Defaults to "".
num_workflow_steps (int, optional): The total number of workflow steps in the indexing pipeline, used for progress reporting. Defaults to 0.
properties (Dict[str, Any], optional): Additional properties to be included in the log. Defaults to {}.
"""
self._logger: logging.Logger
self._logger_name = logger_name
self._logger_level = logger_level
self._logger_level_name: str = logging.getLevelName(logger_level)
self._properties = properties
self._workflow_name = "N/A"
self._index_name = index_name
self._num_workflow_steps = num_workflow_steps
self._processed_workflow_steps = [] # maintain a running list of workflow steps that get processed
"""Create a new logger with an AppInsights handler."""
self.__init_logger(connection_string=connection_string)
self._properties = properties
self._workflow_name = "N/A"
self._processed_workflow_steps = [] # if logger is used in a pipeline job, maintain a running list of workflows that are processed
# initialize a new logger with an AppInsights handler
self.__init_logger()
def __init_logger(self, connection_string, max_logger_init_retries: int = 10):
max_retry = max_logger_init_retries
while not (hasattr(self, "_logger")):
if max_retry == 0:
raise Exception(
"Failed to create logger. Could not disambiguate logger name."
)
# generate a unique logger name
current_time = str(time.time())
unique_hash = hashlib.sha256(current_time.encode()).hexdigest()
self._logger_name = f"{self.__class__.__name__}-{unique_hash}"
if self._logger_name not in logging.Logger.manager.loggerDict:
# attach azure monitor log exporter to logger provider
logger_provider = LoggerProvider()
set_logger_provider(logger_provider)
exporter = AzureMonitorLogExporter(connection_string=connection_string)
get_logger_provider().add_log_record_processor(
BatchLogRecordProcessor(
exporter=exporter,
schedule_delay_millis=60000,
)
)
# instantiate new logger
self._logger = logging.getLogger(self._logger_name)
self._logger.propagate = False
# remove any existing handlers
self._logger.handlers.clear()
# fetch handler from logger provider and attach to class
self._logger.addHandler(LoggingHandler())
# set logging level
self._logger.setLevel(logging.DEBUG)
# reduce sentinel counter value
max_retry -= 1
def __init_logger(self, max_logger_init_retries: int = 10):
# Configure OpenTelemetry to use Azure Monitor with the
# APPLICATIONINSIGHTS_CONNECTION_STRING environment variable
configure_azure_monitor(
logger_name=self._logger_name,
disable_offline_storage=True,
enable_live_metrics=True,
credential=DefaultAzureCredential(),
)
self._logger = logging.getLogger(self._logger_name)
def _format_details(self, details: Dict[str, Any] | None = None) -> Dict[str, Any]:
"""

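The refactor above swaps the hand-rolled AzureMonitorLogExporter/LoggerProvider wiring for the azure-monitor-opentelemetry distro. A minimal sketch of the resulting behavior, assuming APPLICATIONINSIGHTS_CONNECTION_STRING is set in the environment:

# Illustrative sketch only: once configure_azure_monitor() wires the OpenTelemetry
# log pipeline, ordinary logging calls on the named logger flow to Application Insights.
import logging

from azure.monitor.opentelemetry import configure_azure_monitor

configure_azure_monitor(logger_name="graphrag-accelerator", disable_offline_storage=True)
logging.getLogger("graphrag-accelerator").info("indexing job started")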

@ -7,6 +7,7 @@ from typing import List
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
from src.api.azure_clients import AzureClientManager
from src.logger.application_insights_workflow_callbacks import (
@ -14,32 +15,32 @@ from src.logger.application_insights_workflow_callbacks import (
)
from src.logger.blob_workflow_callbacks import BlobWorkflowCallbacks
from src.logger.console_workflow_callbacks import ConsoleWorkflowCallbacks
from src.logger.typing import Reporters
from src.logger.typing import Logger
def load_pipeline_logger(
reporting_dir: str | None,
reporters: List[Reporters] | None = [],
logging_dir: str | None,
index_name: str = "",
num_workflow_steps: int = 0,
loggers: List[Logger] = [],
) -> WorkflowCallbacks:
"""Create and load a list of loggers.
Loggers may be configured as generic loggers or associated with a specified indexing job.
"""
# always register the console logger as a fallback option
if Reporters.CONSOLE not in reporters:
reporters.append(Reporters.CONSOLE)
if Logger.CONSOLE not in loggers:
loggers.append(Logger.CONSOLE)
azure_client_manager = AzureClientManager()
logger_callbacks = []
for reporter in reporters:
match reporter:
case Reporters.BLOB:
callback_manager = WorkflowCallbacksManager()
for logger in loggers:
match logger:
case Logger.BLOB:
# create a dedicated container for logs
container_name = "logs"
if reporting_dir is not None:
container_name = os.path.join(reporting_dir, container_name)
if logging_dir is not None:
container_name = os.path.join(logging_dir, container_name)
# ensure the root directory exists; if not, create it
blob_service_client = azure_client_manager.get_blob_service_client()
container_root = Path(container_name).parts[0]
@ -47,8 +48,7 @@ def load_pipeline_logger(
container_root
).exists():
blob_service_client.create_container(container_root)
# register the blob reporter
logger_callbacks.append(
callback_manager.register(
BlobWorkflowCallbacks(
blob_service_client=blob_service_client,
container_name=container_name,
@ -56,25 +56,25 @@ def load_pipeline_logger(
num_workflow_steps=num_workflow_steps,
)
)
case Reporters.FILE:
logger_callbacks.append(FileWorkflowCallbacks(dir=reporting_dir))
case Reporters.APP_INSIGHTS:
if os.getenv("APP_INSIGHTS_CONNECTION_STRING"):
logger_callbacks.append(
case Logger.FILE:
callback_manager.register(FileWorkflowCallbacks(dir=logging_dir))
case Logger.APP_INSIGHTS:
if os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING"):
callback_manager.register(
ApplicationInsightsWorkflowCallbacks(
connection_string=os.environ[
"APP_INSIGHTS_CONNECTION_STRING"
"APPLICATIONINSIGHTS_CONNECTION_STRING"
],
index_name=index_name,
num_workflow_steps=num_workflow_steps,
)
)
case Reporters.CONSOLE:
logger_callbacks.append(
case Logger.CONSOLE:
callback_manager.register(
ConsoleWorkflowCallbacks(
index_name=index_name, num_workflow_steps=num_workflow_steps
)
)
case _:
print(f"WARNING: unknown reporter type: {reporter}. Skipping.")
return logger_callbacks
print(f"WARNING: unknown logger type: {logger}. Skipping.")
return callback_manager
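For context, a hedged usage sketch of the refactored factory, mirroring the unit test later in this diff; the index name is hypothetical, and the blob and App Insights branches assume the storage and `APPLICATIONINSIGHTS_CONNECTION_STRING` settings expected by `AzureClientManager` are present:

```python
from src.logger.load_logger import load_pipeline_logger
from src.logger.typing import Logger

# Returns a WorkflowCallbacksManager with every requested logger registered;
# the console logger is always appended as a fallback.
callbacks = load_pipeline_logger(
    logging_dir="logs",
    index_name="my-test-index",  # hypothetical index name
    num_workflow_steps=4,
    loggers=[Logger.BLOB, Logger.APP_INSIGHTS],
)
callbacks.error("example error message")
```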

View File

@ -1,13 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from urllib.parse import urlparse
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from src.logger.load_logger import load_pipeline_logger
from src.logger.typing import Reporters
from src.logger.typing import Logger
class LoggerSingleton:
@ -15,23 +13,9 @@ class LoggerSingleton:
@classmethod
def get_instance(cls) -> WorkflowCallbacks:
if cls._instance is None:
# Set up reporters based on environment variable or defaults
if not cls._instance:
reporters = []
for reporter_name in os.getenv(
"REPORTERS", Reporters.CONSOLE.name.upper()
).split(","):
try:
reporters.append(Reporters[reporter_name.upper()])
except KeyError:
raise ValueError(f"Found unknown reporter: {reporter_name}")
cls._instance = load_pipeline_logger(reporting_dir="", reporters=reporters)
for logger_type in ["BLOB", "CONSOLE", "APP_INSIGHTS"]:
reporters.append(Logger[logger_type])
cls._instance = load_pipeline_logger(logging_dir="", loggers=reporters)
return cls._instance
def _is_valid_url(url: str) -> bool:
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False
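A short usage sketch of the singleton as the API and job-manager code in this commit consume it, assuming the same environment configuration the underlying loggers require:

```python
import traceback

from src.logger.logger_singleton import LoggerSingleton

# The first call builds a shared WorkflowCallbacks wired to the blob, console,
# and app-insights loggers; later calls return the same instance.
logger = LoggerSingleton().get_instance()
try:
    raise RuntimeError("example failure")  # placeholder for real work
except Exception as e:
    logger.error(
        message="Index job manager encountered error scheduling indexing job",
        cause=e,
        stack=traceback.format_exc(),
    )
```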

View File

@ -7,12 +7,12 @@ from src.typing.pipeline import PipelineJobState
from src.utils.pipeline import PipelineJob
class PipelineJobWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a stream (sys.stdout)."""
class PipelineJobUpdater(NoopWorkflowCallbacks):
"""A callback that records pipeline updates."""
def __init__(self, pipeline_job: "PipelineJob"):
def __init__(self, pipeline_job: PipelineJob):
"""
This class defines a set of callback methods that can be used to report the progress and status of a workflow job.
This class defines a set of callback methods that can be used to log the progress of a pipeline job.
It inherits from the NoopWorkflowCallbacks class, which provides default implementations for all the callback methods.
Attributes:
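To illustrate the callback pattern described above (not the accelerator's actual implementation), here is a minimal sketch of a NoopWorkflowCallbacks subclass that records completed workflows on a stand-in job object; the `_FakePipelineJob` fields and the workflow name are assumptions made for the example:

```python
from dataclasses import dataclass, field

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks


@dataclass
class _FakePipelineJob:
    # stand-in for src.utils.pipeline.PipelineJob, used only for illustration
    completed_workflows: list[str] = field(default_factory=list)


class ExamplePipelineJobUpdater(NoopWorkflowCallbacks):
    """Record each finished workflow on the job record (illustration only)."""

    def __init__(self, pipeline_job: _FakePipelineJob):
        self._pipeline_job = pipeline_job

    def workflow_end(self, name: str, instance: object) -> None:
        self._pipeline_job.completed_workflows.append(name)


job = _FakePipelineJob()
updater = ExamplePipelineJobUpdater(job)
updater.workflow_end("create_base_text_units", object())  # hypothetical workflow name
```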

View File

@ -12,7 +12,7 @@ from graphrag.index.config.reporting import (
from pydantic import Field as pydantic_Field
class Reporters(Enum):
class Logger(Enum):
BLOB = (1, "blob")
CONSOLE = (2, "console")
FILE = (3, "file")
@ -24,7 +24,7 @@ class PipelineAppInsightsReportingConfig(
):
"""Represents the ApplicationInsights reporting configuration for the pipeline."""
type: Literal["app_insights"] = Reporters.APP_INSIGHTS.name.lower()
type: Literal["app_insights"] = Logger.APP_INSIGHTS.name.lower()
"""The type of reporting."""
connection_string: str = pydantic_Field(

View File

@ -23,7 +23,7 @@ from src.api.azure_clients import AzureClientManager
from src.api.data import data_route
from src.api.graph import graph_route
from src.api.index import index_route
from src.api.index_configuration import index_configuration_route
from src.api.prompt_tuning import prompt_tuning_route
from src.api.query import query_route
from src.api.query_streaming import query_streaming_route
from src.api.source import source_route
@ -37,7 +37,7 @@ async def catch_all_exceptions_middleware(request: Request, call_next):
except Exception as e:
reporter = LoggerSingleton().get_instance()
stack = traceback.format_exc()
reporter.on_error(
reporter.error(
message="Unexpected internal server error",
cause=e,
stack=stack,
@ -82,7 +82,7 @@ async def lifespan(app: FastAPI):
name=pod_name, namespace=os.environ["AKS_NAMESPACE"]
)
# load the cronjob manifest template and update PLACEHOLDER values with correct values using the pod spec
with open("indexing-job-manager-template.yaml", "r") as f:
with open("index-job-manager.yaml", "r") as f:
manifest = yaml.safe_load(f)
manifest["spec"]["jobTemplate"]["spec"]["template"]["spec"]["containers"][0][
"image"
@ -104,7 +104,7 @@ async def lifespan(app: FastAPI):
except Exception as e:
print("Failed to create graphrag cronjob.")
logger = LoggerSingleton().get_instance()
logger.on_error(
logger.error(
message="Failed to create graphrag cronjob",
cause=str(e),
stack=traceback.format_exc(),
@ -133,7 +133,7 @@ app.include_router(data_route)
app.include_router(index_route)
app.include_router(query_route)
app.include_router(query_streaming_route)
app.include_router(index_configuration_route)
app.include_router(prompt_tuning_route)
app.include_router(source_route)
app.include_router(graph_route)
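For readers skimming the hunks above, a condensed, self-contained sketch of the catch-all middleware pattern used here; the JSON error payload is an assumption for the example rather than the accelerator's exact response shape:

```python
import traceback

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from src.logger.logger_singleton import LoggerSingleton

app = FastAPI()


@app.middleware("http")
async def catch_all_exceptions_middleware(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception as e:
        # log through the shared WorkflowCallbacks instance, then return a 500
        LoggerSingleton().get_instance().error(
            message="Unexpected internal server error",
            cause=e,
            stack=traceback.format_exc(),
        )
        return JSONResponse(
            status_code=500, content={"detail": "Unexpected internal server error"}
        )
```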

View File

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

View File

@ -5,6 +5,7 @@ import hashlib
import os
import re
import pandas as pd
from azure.cosmos import exceptions
from azure.identity import DefaultAzureCredential
from fastapi import HTTPException
@ -12,7 +13,17 @@ from fastapi import HTTPException
from src.api.azure_clients import AzureClientManager
def get_pandas_storage_options() -> dict:
def get_df(
table_path: str,
) -> pd.DataFrame:
df = pd.read_parquet(
table_path,
storage_options=pandas_storage_options(),
)
return df
def pandas_storage_options() -> dict:
"""Generate the storage options required by pandas to read parquet files from Storage."""
# For more information on the options available, see: https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
azure_client_manager = AzureClientManager()
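A hedged sketch of what the resulting storage options look like when reading indexing output with `get_df`/`pandas_storage_options`; the account name and blob path below are hypothetical, and the keys follow the adlfs credential options linked in the comment above:

```python
import pandas as pd
from azure.identity import DefaultAzureCredential

# adlfs accepts an account name plus a token credential, which is roughly what
# pandas_storage_options() assembles from the accelerator's configuration.
storage_options = {
    "account_name": "mystorageaccount",  # hypothetical storage account
    "credential": DefaultAzureCredential(),
}

df = pd.read_parquet(
    "abfs://my-index/output/create_final_entities.parquet",  # hypothetical path
    storage_options=storage_options,
)
print(df.head())
```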

View File

@ -9,8 +9,8 @@ from typing import (
from azure.cosmos.exceptions import CosmosHttpResponseError
from src.api.azure_clients import AzureClientManager
from src.api.common import sanitize_name
from src.typing.pipeline import PipelineJobState
from src.utils.common import sanitize_name
@dataclass

View File

@ -1,81 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from src.api.common import get_pandas_storage_options
def get_entities(
entity_table_path: str,
entity_embedding_table_path: str,
community_level: int = 0,
) -> pd.DataFrame:
storage_options = get_pandas_storage_options()
entity_df = pd.read_parquet(
entity_table_path,
storage_options=storage_options,
)
entity_embedding_df = pd.read_parquet(
entity_embedding_table_path,
storage_options=storage_options,
)
return pd.DataFrame(
read_indexer_entities(entity_df, entity_embedding_df, community_level)
)
def get_reports(
entity_table_path: str, community_report_table_path: str, community_level: int
) -> pd.DataFrame:
storage_options = get_pandas_storage_options()
entity_df = pd.read_parquet(
entity_table_path,
storage_options=storage_options,
)
report_df = pd.read_parquet(
community_report_table_path,
storage_options=storage_options,
)
return pd.DataFrame(read_indexer_reports(report_df, entity_df, community_level))
def get_relationships(relationships_table_path: str) -> pd.DataFrame:
relationship_df = pd.read_parquet(
relationships_table_path,
storage_options=get_pandas_storage_options(),
)
return pd.DataFrame(read_indexer_relationships(relationship_df))
def get_covariates(covariate_table_path: str) -> pd.DataFrame:
covariate_df = pd.read_parquet(
covariate_table_path,
storage_options=get_pandas_storage_options(),
)
return pd.DataFrame(read_indexer_covariates(covariate_df))
def get_text_units(text_unit_table_path: str) -> pd.DataFrame:
text_unit_df = pd.read_parquet(
text_unit_table_path,
storage_options=get_pandas_storage_options(),
)
return pd.DataFrame(read_indexer_text_units(text_unit_df))
def get_df(
table_path: str,
) -> pd.DataFrame:
df = pd.read_parquet(
table_path,
storage_options=get_pandas_storage_options(),
)
return df

View File

@ -10,8 +10,8 @@ from azure.cosmos import CosmosClient, PartitionKey
from azure.storage.blob import BlobServiceClient
from fastapi.testclient import TestClient
from src.api.common import sanitize_name
from src.main import app
from src.utils.common import sanitize_name
@pytest.fixture(scope="session")

View File

@ -3,7 +3,7 @@
import pytest
from src.api.common import (
from src.utils.common import (
retrieve_original_blob_container_name,
sanitize_name,
validate_blob_container_name,
@ -12,7 +12,7 @@ from src.api.common import (
def test_validate_blob_container_name():
"""Test the src.api.common.validate_blob_container_name function."""
"""Test the src.utils.common.validate_blob_container_name function."""
# test valid container name
assert validate_blob_container_name("validcontainername") is None
# test invalid container name
@ -33,7 +33,7 @@ def test_validate_blob_container_name():
def test_retrieve_original_blob_container_name(container_with_graphml_file):
"""Test the src.api.common.retrieve_original_blob_container_name function."""
"""Test the src.utils.common.retrieve_original_blob_container_name function."""
# test retrieving a valid container name
original_name = container_with_graphml_file
sanitized_name = sanitize_name(original_name)
@ -43,7 +43,7 @@ def test_retrieve_original_blob_container_name(container_with_graphml_file):
def test_validate_index_file_exist(container_with_graphml_file):
"""Test the src.api.common.validate_index_file_exist function."""
"""Test the src.utils.common.validate_index_file_exist function."""
original_name = container_with_graphml_file
sanitized_name = sanitize_name(original_name)
# test with a valid index and valid file

View File

@ -46,8 +46,8 @@ def test_load_pipeline_logger_with_console(
):
"""Test load_pipeline_logger."""
loggers = load_pipeline_logger(
reporting_dir="logs",
reporters=["app_insights", "blob", "console", "file"],
logging_dir="logs",
loggers=["app_insights", "blob", "console", "file"],
index_name="test-index",
num_workflow_steps=4,
)

View File

@ -37,26 +37,26 @@ def workflow_callbacks(mock_logger):
yield instance
def test_on_workflow_start(workflow_callbacks, mock_logger):
workflow_callbacks.on_workflow_start("test_workflow", object())
def test_workflow_start(workflow_callbacks, mock_logger):
workflow_callbacks.workflow_start("test_workflow", object())
assert mock_logger.info.called
def test_on_workflow_end(workflow_callbacks, mock_logger):
workflow_callbacks.on_workflow_end("test_workflow", object())
def test_workflow_end(workflow_callbacks, mock_logger):
workflow_callbacks.workflow_end("test_workflow", object())
assert mock_logger.info.called
def test_on_log(workflow_callbacks, mock_logger):
workflow_callbacks.on_log("test_log_message")
def test_log(workflow_callbacks, mock_logger):
workflow_callbacks.log("test_log_message")
assert mock_logger.info.called
def test_on_warning(workflow_callbacks, mock_logger):
workflow_callbacks.on_warning("test_warning")
def test_warning(workflow_callbacks, mock_logger):
workflow_callbacks.warning("test_warning")
assert mock_logger.warning.called
def test_on_error(workflow_callbacks, mock_logger):
workflow_callbacks.on_error("test_error", Exception("test_exception"))
def test_error(workflow_callbacks, mock_logger):
workflow_callbacks.error("test_error", Exception("test_exception"))
assert mock_logger.error.called

View File

@ -34,16 +34,16 @@ def workflow_callbacks(mock_blob_service_client):
def test_on_workflow_start(workflow_callbacks):
workflow_callbacks.on_workflow_start("test_workflow", object())
workflow_callbacks.workflow_start("test_workflow", object())
# check if blob workflow callbacks _write_log() method was called
assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called
def test_on_workflow_end(workflow_callbacks):
workflow_callbacks.on_workflow_end("test_workflow", object())
workflow_callbacks.workflow_end("test_workflow", object())
assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called
def test_on_error(workflow_callbacks):
workflow_callbacks.on_error("test_error", Exception("test_exception"))
workflow_callbacks.error("test_error", Exception("test_exception"))
assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called

View File

@ -34,26 +34,26 @@ def workflow_callbacks(mock_logger):
yield instance
def test_on_workflow_start(workflow_callbacks, mock_logger):
workflow_callbacks.on_workflow_start("test_workflow", object())
def test_workflow_start(workflow_callbacks, mock_logger):
workflow_callbacks.workflow_start("test_workflow", object())
assert mock_logger.info.called
def test_on_workflow_end(workflow_callbacks, mock_logger):
workflow_callbacks.on_workflow_end("test_workflow", object())
def test_workflow_end(workflow_callbacks, mock_logger):
workflow_callbacks.workflow_end("test_workflow", object())
assert mock_logger.info.called
def test_on_log(workflow_callbacks, mock_logger):
workflow_callbacks.on_log("test_log_message")
def test_log(workflow_callbacks, mock_logger):
workflow_callbacks.log("test_log_message")
assert mock_logger.info.called
def test_on_warning(workflow_callbacks, mock_logger):
workflow_callbacks.on_warning("test_warning")
def test_warning(workflow_callbacks, mock_logger):
workflow_callbacks.warning("test_warning")
assert mock_logger.warning.called
def test_on_error(workflow_callbacks, mock_logger):
workflow_callbacks.on_error("test_error", Exception("test_exception"))
def test_error(workflow_callbacks, mock_logger):
workflow_callbacks.error("test_error", Exception("test_exception"))
assert mock_logger.error.called

View File

@ -1,8 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# For more information about the base image visit:
# https://mcr.microsoft.com/en-us/artifact/mar/devcontainers/python/about
# For more information about the base image: https://mcr.microsoft.com/en-us/artifact/mar/devcontainers/python/about
FROM mcr.microsoft.com/devcontainers/python:3.10-bookworm
# default graphrag version will be 0.0.0 unless overridden by --build-arg
@ -22,7 +21,6 @@ RUN cd backend \
# download all nltk data that graphrag requires
RUN python -c "import nltk;nltk.download(['punkt','averaged_perceptron_tagger','maxent_ne_chunker','words','wordnet'])"
# download tiktoken model encodings
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt-3.5-turbo'); tiktoken.encoding_for_model('gpt-4'); tiktoken.encoding_for_model('gpt-4o');"

View File

@ -95,7 +95,6 @@ In the `deploy.parameters.json` file, provide values for the following required
`RESOURCE_BASE_NAME` | | No | Suffix to apply to all Azure resource names. If not provided, a unique suffix will be generated.
`AISEARCH_ENDPOINT_SUFFIX` | | No | Suffix to apply to AI search endpoint. Will default to `search.windows.net` for Azure Commercial cloud but should be overridden for deployments in other Azure clouds.
`AISEARCH_AUDIENCE` | | No | Audience for AAD for AI Search. Will default to `https://search.azure.com/` for Azure Commercial cloud but should be overridden for deployments in other Azure clouds.
`REPORTERS` | blob,console,app_insights | No | The type of logging to enable. A comma separated string containing any of the following values: `[blob,console,file,app_insights]`. Will default to `"blob,console,app_insights"`.
### 5. Deploy solution accelerator to the resource group
```shell

View File

@ -16,7 +16,6 @@ GRAPHRAG_IMAGE=""
PUBLISHER_EMAIL=""
PUBLISHER_NAME=""
RESOURCE_BASE_NAME=""
REPORTERS=""
GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT=""
CONTAINER_REGISTRY_NAME=""
@ -240,10 +239,6 @@ populateOptionalParams () {
if [ ! -z "$RESOURCE_BASE_NAME" ]; then
printf "\tsetting RESOURCE_BASE_NAME=$RESOURCE_BASE_NAME\n"
fi
if [ -z "$REPORTERS" ]; then
REPORTERS="blob,console,app_insights"
printf "\tsetting REPORTERS=blob,console,app_insights\n"
fi
if [ -z "$GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT" ]; then
GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT="https://cognitiveservices.azure.com/.default"
printf "\tsetting GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT=$GRAPHRAG_COGNITIVE_SERVICES_ENDPOINT\n"
@ -440,7 +435,6 @@ installGraphRAGHelmChart () {
exitIfValueEmpty "$graphragImageName" "Unable to parse graphrag image name, exiting..."
exitIfValueEmpty "$graphragImageVersion" "Unable to parse graphrag image version, exiting..."
local escapedReporters=$(sed "s/,/\\\,/g" <<< "$REPORTERS")
reset_x=true
if ! [ -o xtrace ]; then
set -x
@ -455,7 +449,7 @@ installGraphRAGHelmChart () {
--set "master.image.repository=$containerRegistryName/$graphragImageName" \
--set "master.image.tag=$graphragImageVersion" \
--set "ingress.host=$graphragHostname" \
--set "graphragConfig.APP_INSIGHTS_CONNECTION_STRING=$appInsightsConnectionString" \
--set "graphragConfig.APPLICATIONINSIGHTS_CONNECTION_STRING=$appInsightsConnectionString" \
--set "graphragConfig.AI_SEARCH_URL=https://$aiSearchName.$AISEARCH_ENDPOINT_SUFFIX" \
--set "graphragConfig.AI_SEARCH_AUDIENCE=$AISEARCH_AUDIENCE" \
--set "graphragConfig.COSMOS_URI_ENDPOINT=$cosmosEndpoint" \
@ -466,7 +460,6 @@ installGraphRAGHelmChart () {
--set "graphragConfig.GRAPHRAG_LLM_DEPLOYMENT_NAME=$GRAPHRAG_LLM_DEPLOYMENT_NAME" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_MODEL=$GRAPHRAG_EMBEDDING_MODEL" \
--set "graphragConfig.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME=$GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME" \
--set "graphragConfig.REPORTERS=$escapedReporters" \
--set "graphragConfig.STORAGE_ACCOUNT_BLOB_URL=$storageAccountBlobUrl"
local helmResult=$?

View File

@ -32,7 +32,9 @@ ingress:
graphragConfig:
AI_SEARCH_AUDIENCE: ""
AI_SEARCH_URL: ""
APP_INSIGHTS_CONNECTION_STRING: ""
APPLICATIONINSIGHTS_CONNECTION_STRING: ""
# Must set hidden env variable to true to disable statsbeat. For more information: https://github.com/Azure/azure-sdk-for-python/issues/34804
APPLICATIONINSIGHTS_STATSBEAT_DISABLED_ALL: "True"
COSMOS_URI_ENDPOINT: ""
GRAPHRAG_API_BASE: ""
GRAPHRAG_API_VERSION: ""
@ -41,7 +43,6 @@ graphragConfig:
GRAPHRAG_LLM_DEPLOYMENT_NAME: ""
GRAPHRAG_EMBEDDING_MODEL: ""
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME: ""
REPORTERS: "blob,console,app_insights"
STORAGE_ACCOUNT_BLOB_URL: ""
master:
@ -54,10 +55,12 @@ master:
tag: ""
podAnnotations: {}
podLabels: {}
podSecurityContext: {}
podSecurityContext:
{}
# fsGroup: 2000
securityContext: {}
securityContext:
{}
# capabilities:
# drop:
# - ALL
@ -125,8 +128,8 @@ master:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: workload
operator: In
values:
- graphrag
- matchExpressions:
- key: workload
operator: In
values:
- graphrag