mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-06-27 04:39:57 +00:00
357 lines
13 KiB
Python
357 lines
13 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
from dataclasses import dataclass, field
|
|
from time import time
|
|
from typing import (
|
|
List,
|
|
)
|
|
|
|
from azure.cosmos.exceptions import CosmosHttpResponseError
|
|
|
|
from graphrag_app.typing.pipeline import PipelineJobState
|
|
from graphrag_app.utils.azure_clients import AzureClientManager
|
|
from graphrag_app.utils.common import sanitize_name
|
|
|
|
|
|
@dataclass
class PipelineJob:
    """Indexing pipeline job metadata.

    This is a custom class whose attributes are stored in / retrieved from the
    "jobs" container of the "graphrag" Cosmos DB database. Instances are meant
    to be obtained through :meth:`create_item` (new job) or :meth:`load_item`
    (existing job). Every field is declared with ``init=False``, so the
    generated ``__init__`` takes no arguments.

    Except for ``id`` and ``epoch_request_time`` (which are write-once and do
    not touch the database), every property setter persists the full record
    via :meth:`update_db` right after changing the in-memory value.

    # TODO: fix the class so initialization is not required
    """

    _id: str | None = field(default=None, init=False)
    _epoch_request_time: int | None = field(default=None, init=False)
    _index_name: str | None = field(default=None, init=False)
    _human_readable_index_name: str | None = field(default=None, init=False)
    _sanitized_index_name: str | None = field(default=None, init=False)
    _human_readable_storage_name: str | None = field(default=None, init=False)
    _sanitized_storage_name: str | None = field(default=None, init=False)

    _all_workflows: List[str] = field(default_factory=list, init=False)
    _completed_workflows: List[str] = field(default_factory=list, init=False)
    _failed_workflows: List[str] = field(default_factory=list, init=False)

    _status: PipelineJobState | None = field(default=None, init=False)
    _percent_complete: float = field(default=0, init=False)
    _progress: str = field(default="", init=False)

    _entity_extraction_prompt: str | None = field(default=None, init=False)
    _entity_summarization_prompt: str | None = field(default=None, init=False)
    _community_summarization_prompt: str | None = field(default=None, init=False)

    @staticmethod
    def _jobs_container():
        """Return a Cosmos container client for the "jobs" container.

        A new ``AzureClientManager`` is created on every call and asked for the
        container client of database "graphrag", container "jobs".
        """
        azure_storage_client = AzureClientManager()
        return azure_storage_client.get_cosmos_container_client(
            database="graphrag", container="jobs"
        )

    @staticmethod
    def _require_text(value: str | None, label: str) -> None:
        """Raise ``ValueError`` if ``value`` is None or has zero length."""
        if value is None:
            raise ValueError(f"{label} cannot be None.")
        if len(value) == 0:
            raise ValueError(f"{label} cannot be empty.")

    @classmethod
    def create_item(
        cls,
        id: str,
        human_readable_index_name: str,
        human_readable_storage_name: str,
        entity_extraction_prompt: str | None = None,
        entity_summarization_prompt: str | None = None,
        community_summarization_prompt: str | None = None,
        **kwargs,
    ) -> "PipelineJob":
        """Create a new PipelineJob and add it to the database.

        Args:
            id (str): The ID of the pipeline job.
            human_readable_index_name (str): The name of the index; a sanitized
                form (``sanitize_name``) is stored alongside it.
            human_readable_storage_name (str): The name of the storage; a
                sanitized form (``sanitize_name``) is stored alongside it.
            entity_extraction_prompt (str | None): The entity extraction prompt.
            entity_summarization_prompt (str | None): The entity summarization
                prompt.
            community_summarization_prompt (str | None): The community
                summarization prompt.
            **kwargs: Optional initial values:
                all_workflows (List[str]): List of all workflows.
                completed_workflows (List[str]): List of completed workflows.
                failed_workflows (List[str]): List of failed workflows.
                status: Value accepted by ``PipelineJobState(...)``; defaults
                    to ``PipelineJobState.SCHEDULED.value``.
                percent_complete (float): The percentage of completion.
                progress (str): The progress of the pipeline job.

        Returns:
            PipelineJob: The created pipeline job instance.

        Raises:
            ValueError: If ``id`` is None, if either name is None or empty, or
                if a job with this ID already exists.
        """
        # Validate with explicit raises (asserts vanish under ``python -O``) and
        # do it before the existence check so bad input never reaches the DB.
        if id is None:
            raise ValueError("ID cannot be None.")
        cls._require_text(human_readable_index_name, "index_name")
        cls._require_text(human_readable_storage_name, "storage_name")

        if PipelineJob.item_exist(id):
            raise ValueError(
                f"Pipeline job with ID {id} already exists. "
                "Use PipelineJob.load_item() to load the existing pipeline job."
            )

        # All fields are init=False, so the generated __init__ takes no
        # arguments and simply applies the field defaults.
        instance = cls()
        instance._id = id
        instance._epoch_request_time = int(time())
        instance._human_readable_index_name = human_readable_index_name
        instance._sanitized_index_name = sanitize_name(human_readable_index_name)
        instance._human_readable_storage_name = human_readable_storage_name
        instance._sanitized_storage_name = sanitize_name(human_readable_storage_name)

        instance._all_workflows = kwargs.get("all_workflows", [])
        instance._completed_workflows = kwargs.get("completed_workflows", [])
        instance._failed_workflows = kwargs.get("failed_workflows", [])

        instance._status = PipelineJobState(
            kwargs.get("status", PipelineJobState.SCHEDULED.value)
        )
        instance._percent_complete = kwargs.get("percent_complete", 0.0)
        instance._progress = kwargs.get("progress", "")

        instance._entity_extraction_prompt = entity_extraction_prompt
        instance._entity_summarization_prompt = entity_summarization_prompt
        instance._community_summarization_prompt = community_summarization_prompt

        # Create the item in the database
        instance.update_db()
        return instance

    @classmethod
    def load_item(cls, id: str) -> "PipelineJob":
        """Load an existing pipeline job from the database.

        Args:
            id (str): The ID of the pipeline job (also used as partition key).

        Returns:
            PipelineJob: The loaded pipeline job instance.

        Raises:
            ValueError: If reading the item raises ``CosmosHttpResponseError``.
        """
        try:
            db_item = PipelineJob._jobs_container().read_item(item=id, partition_key=id)
        except CosmosHttpResponseError as err:
            # NOTE(review): every Cosmos HTTP error is reported as "does not
            # exist", not only a missing item; the original error is chained
            # so the real cause stays visible in the traceback.
            raise ValueError(
                f"Pipeline job with ID {id} does not exist. "
                "Use PipelineJob.create_item() to create a new pipeline job."
            ) from err

        instance = cls()
        instance._id = db_item.get("id")
        instance._epoch_request_time = db_item.get("epoch_request_time")
        instance._index_name = db_item.get("index_name")
        instance._human_readable_index_name = db_item.get("human_readable_index_name")
        instance._sanitized_index_name = db_item.get("sanitized_index_name")
        instance._human_readable_storage_name = db_item.get(
            "human_readable_storage_name"
        )
        instance._sanitized_storage_name = db_item.get("sanitized_storage_name")

        instance._all_workflows = db_item.get("all_workflows", [])
        instance._completed_workflows = db_item.get("completed_workflows", [])
        instance._failed_workflows = db_item.get("failed_workflows", [])

        instance._status = PipelineJobState(db_item.get("status"))
        instance._percent_complete = db_item.get("percent_complete", 0.0)
        instance._progress = db_item.get("progress", "")

        instance._entity_extraction_prompt = db_item.get("entity_extraction_prompt")
        instance._entity_summarization_prompt = db_item.get(
            "entity_summarization_prompt"
        )
        instance._community_summarization_prompt = db_item.get(
            "community_summarization_prompt"
        )
        return instance

    @staticmethod
    def item_exist(id: str) -> bool:
        """Return True if an item with this ID can be read from the jobs container.

        Any ``CosmosHttpResponseError`` raised by the read is reported as
        ``False``.
        """
        try:
            PipelineJob._jobs_container().read_item(item=id, partition_key=id)
            return True
        except CosmosHttpResponseError:
            return False

    def calculate_percent_complete(self) -> float:
        """Calculate the percentage of completion of the pipeline job.

        Returns:
            float: ``completed / all * 100`` rounded to two decimals, or 0.0
            when either workflow list is empty.
        """
        if len(self.completed_workflows) == 0 or len(self.all_workflows) == 0:
            return 0.0
        return round(
            (len(self.completed_workflows) / len(self.all_workflows)) * 100, ndigits=2
        )

    def dump_model(self) -> dict:
        """Return the job as the dict that is stored in the database.

        The three prompt entries are only included when their value is truthy.
        ``index_name`` is not part of the dumped record.
        """
        model = {
            "id": self._id,
            "epoch_request_time": self._epoch_request_time,
            "human_readable_index_name": self._human_readable_index_name,
            "sanitized_index_name": self._sanitized_index_name,
            "human_readable_storage_name": self._human_readable_storage_name,
            "sanitized_storage_name": self._sanitized_storage_name,
            "all_workflows": self._all_workflows,
            "completed_workflows": self._completed_workflows,
            "failed_workflows": self._failed_workflows,
            "status": self._status.value,
            "percent_complete": self._percent_complete,
            "progress": self._progress,
        }
        optional_prompts = {
            "entity_extraction_prompt": self._entity_extraction_prompt,
            "entity_summarization_prompt": self._entity_summarization_prompt,
            "community_summarization_prompt": self._community_summarization_prompt,
        }
        for key, prompt in optional_prompts.items():
            if prompt:
                model[key] = prompt
        return model

    def update_db(self):
        """Upsert the current state (``dump_model()``) into the jobs container."""
        PipelineJob._jobs_container().upsert_item(body=self.dump_model())

    @property
    def id(self) -> str:
        return self._id

    @id.setter
    def id(self, id: str) -> None:
        # Write-once: only an unset ID may be assigned. Not persisted here.
        if self._id is not None:
            raise ValueError("ID cannot be changed once set.")
        self._id = id

    @property
    def epoch_request_time(self) -> int:
        return self._epoch_request_time

    @epoch_request_time.setter
    def epoch_request_time(self, epoch_request_time: int) -> None:
        # Write-once: only an unset request time may be assigned. Not persisted here.
        if self._epoch_request_time is not None:
            raise ValueError("epoch_request_time cannot be changed once set.")
        self._epoch_request_time = epoch_request_time

    # Every setter below updates the in-memory value and then upserts the whole
    # record into the database.

    @property
    def human_readable_index_name(self) -> str:
        return self._human_readable_index_name

    @human_readable_index_name.setter
    def human_readable_index_name(self, human_readable_index_name: str) -> None:
        self._human_readable_index_name = human_readable_index_name
        self.update_db()

    @property
    def sanitized_index_name(self) -> str:
        return self._sanitized_index_name

    @sanitized_index_name.setter
    def sanitized_index_name(self, sanitized_index_name: str) -> None:
        self._sanitized_index_name = sanitized_index_name
        self.update_db()

    @property
    def human_readable_storage_name(self) -> str:
        return self._human_readable_storage_name

    @human_readable_storage_name.setter
    def human_readable_storage_name(self, human_readable_storage_name: str) -> None:
        self._human_readable_storage_name = human_readable_storage_name
        self.update_db()

    @property
    def sanitized_storage_name(self) -> str:
        return self._sanitized_storage_name

    @sanitized_storage_name.setter
    def sanitized_storage_name(self, sanitized_storage_name: str) -> None:
        self._sanitized_storage_name = sanitized_storage_name
        self.update_db()

    @property
    def entity_extraction_prompt(self) -> str:
        return self._entity_extraction_prompt

    @entity_extraction_prompt.setter
    def entity_extraction_prompt(self, entity_extraction_prompt: str) -> None:
        self._entity_extraction_prompt = entity_extraction_prompt
        self.update_db()

    @property
    def entity_summarization_prompt(self) -> str:
        return self._entity_summarization_prompt

    @entity_summarization_prompt.setter
    def entity_summarization_prompt(self, entity_summarization_prompt: str) -> None:
        self._entity_summarization_prompt = entity_summarization_prompt
        self.update_db()

    @property
    def community_summarization_prompt(self) -> str:
        return self._community_summarization_prompt

    @community_summarization_prompt.setter
    def community_summarization_prompt(
        self, community_summarization_prompt: str
    ) -> None:
        self._community_summarization_prompt = community_summarization_prompt
        self.update_db()

    @property
    def all_workflows(self) -> List[str]:
        return self._all_workflows

    @all_workflows.setter
    def all_workflows(self, all_workflows: List[str]) -> None:
        self._all_workflows = all_workflows
        self.update_db()

    @property
    def completed_workflows(self) -> List[str]:
        return self._completed_workflows

    @completed_workflows.setter
    def completed_workflows(self, completed_workflows: List[str]) -> None:
        self._completed_workflows = completed_workflows
        self.update_db()

    @property
    def failed_workflows(self) -> List[str]:
        return self._failed_workflows

    @failed_workflows.setter
    def failed_workflows(self, failed_workflows: List[str]) -> None:
        self._failed_workflows = failed_workflows
        self.update_db()

    @property
    def status(self) -> PipelineJobState:
        return self._status

    @status.setter
    def status(self, status: PipelineJobState) -> None:
        self._status = status
        self.update_db()

    @property
    def percent_complete(self) -> float:
        return self._percent_complete

    @percent_complete.setter
    def percent_complete(self, percent_complete: float) -> None:
        self._percent_complete = percent_complete
        self.update_db()

    @property
    def progress(self) -> str:
        return self._progress

    @progress.setter
    def progress(self, progress: str) -> None:
        self._progress = progress
        self.update_db()
|