diff --git a/backend/src/api/index_configuration.py b/backend/src/api/index_configuration.py
index 43efa34..55c24c3 100644
--- a/backend/src/api/index_configuration.py
+++ b/backend/src/api/index_configuration.py
@@ -3,16 +3,15 @@
 
 import inspect
 import os
-import shutil
 import traceback
 
+import graphrag.api as api
 import yaml
 from fastapi import (
     APIRouter,
     HTTPException,
 )
-from fastapi.responses import StreamingResponse
-from graphrag.prompt_tune.cli import prompt_tune as generate_fine_tune_prompts
+from graphrag.config.create_graphrag_config import create_graphrag_config
 
 from src.api.azure_clients import AzureClientManager
 from src.api.common import (
@@ -27,7 +26,7 @@ index_configuration_route = APIRouter(
 
 @index_configuration_route.get(
     "/prompts",
-    summary="Generate graphrag prompts from user-provided data.",
+    summary="Generate prompts from user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
 )
 async def generate_prompts(storage_name: str, limit: int = 5):
@@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             status_code=500,
             detail=f"Data container '{storage_name}' does not exist.",
         )
+
+    # load pipeline configuration file (settings.yaml) for input data and other settings
     this_directory = os.path.dirname(
         os.path.abspath(inspect.getfile(inspect.currentframe()))
     )
-
-    # write custom settings.yaml to a file and store in a temporary directory
     data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
     data["input"]["container_name"] = sanitized_storage_name
-    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
-    shutil.rmtree(temp_dir, ignore_errors=True)
-    os.makedirs(temp_dir, exist_ok=True)
-    with open(f"{temp_dir}/settings.yaml", "w") as f:
-        yaml.dump(data, f, default_flow_style=False)
+    graphrag_config = create_graphrag_config(values=data, root_dir=".")
 
     # generate prompts
     try:
-        await generate_fine_tune_prompts(
-            config=f"{temp_dir}/settings.yaml",
-            root=temp_dir,
-            domain="",
-            selection_method="random",
+        # NOTE: generate_indexing_prompts() is async and returns three prompts
+        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
+            config=graphrag_config,
+            root=".",
             limit=limit,
-            skip_entity_types=True,
-            output=f"{temp_dir}/prompts",
+            selection_method="random",
         )
     except Exception as e:
         logger = LoggerSingleton().get_instance()
@@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
         )
 
-    # zip up the generated prompt files and return the zip file
-    temp_archive = (
-        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
-    )
-    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")
-
-    def iterfile(file_path: str):
-        with open(file_path, mode="rb") as file_like:
-            yield from file_like
-
-    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
+    content = {
+        "entity_extraction_prompt": prompts[0],
+        "entity_summarization_prompt": prompts[1],
+        "community_summarization_prompt": prompts[2],
+    }
+    return content  # FastAPI will serialize this dict as a JSONResponse
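With this change, callers of the prompts endpoint receive a JSON object containing three prompt strings instead of a streamed zip archive. Below is a minimal client sketch of the new contract; the host and route prefix are assumptions (only the router and the "/prompts" path appear in this diff), while the three response keys are taken from the handler above.

    # Hypothetical client for the updated endpoint. The base URL is an
    # assumption; the response keys are defined by the handler in this diff.
    import requests

    response = requests.get(
        "http://localhost:8080/index/config/prompts",  # assumed host + route prefix
        params={"storage_name": "my-data-container", "limit": 5},
    )
    response.raise_for_status()
    prompts = response.json()
    for key in (
        "entity_extraction_prompt",
        "entity_summarization_prompt",
        "community_summarization_prompt",
    ):
        # save each prompt to its own text file for later use
        with open(f"{key}.txt", "w") as f:
            f.write(prompts[key])
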
diff --git a/backend/tests/integration/test_utils_pipeline.py b/backend/tests/integration/test_utils_pipeline.py
index a36e774..8cc40c9 100644
--- a/backend/tests/integration/test_utils_pipeline.py
+++ b/backend/tests/integration/test_utils_pipeline.py
@@ -42,7 +42,9 @@ def cosmos_index_job_entry(cosmos_client) -> Generator[str, None, None]:
 
 
 def test_pipeline_job_interface(cosmos_index_job_entry):
+    """Test the src.utils.pipeline.PipelineJob class interface."""
     pipeline_job = PipelineJob()
+
     # test creating a new entry
     pipeline_job.create_item(
         id="synthetic_id",
@@ -69,3 +71,36 @@ def test_pipeline_job_interface(cosmos_index_job_entry):
     assert pipeline_job.status == PipelineJobState.COMPLETE
     assert pipeline_job.percent_complete == 50.0
     assert pipeline_job.progress == "some progress"
+    assert pipeline_job.calculate_percent_complete() == 50.0
+
+    # test setters and getters
+    pipeline_job.id = "newID"
+    assert pipeline_job.id == "newID"
+    pipeline_job.epoch_request_time = 1
+    assert pipeline_job.epoch_request_time == 1
+
+    pipeline_job.human_readable_index_name = "new_human_readable_index_name"
+    assert pipeline_job.human_readable_index_name == "new_human_readable_index_name"
+    pipeline_job.sanitized_index_name = "new_sanitized_index_name"
+    assert pipeline_job.sanitized_index_name == "new_sanitized_index_name"
+
+    pipeline_job.human_readable_storage_name = "new_human_readable_storage_name"
+    assert pipeline_job.human_readable_storage_name == "new_human_readable_storage_name"
+    pipeline_job.sanitized_storage_name = "new_sanitized_storage_name"
+    assert pipeline_job.sanitized_storage_name == "new_sanitized_storage_name"
+
+    pipeline_job.entity_extraction_prompt = "new_entity_extraction_prompt"
+    assert pipeline_job.entity_extraction_prompt == "new_entity_extraction_prompt"
+    pipeline_job.community_report_prompt = "new_community_report_prompt"
+    assert pipeline_job.community_report_prompt == "new_community_report_prompt"
+    pipeline_job.summarize_descriptions_prompt = "new_summarize_descriptions_prompt"
+    assert pipeline_job.summarize_descriptions_prompt == "new_summarize_descriptions_prompt"
+
+    pipeline_job.all_workflows = ["new_workflow1", "new_workflow2", "new_workflow3"]
+    assert len(pipeline_job.all_workflows) == 3
+
+    pipeline_job.completed_workflows = ["new_workflow1", "new_workflow2"]
+    assert len(pipeline_job.completed_workflows) == 2
+
+    pipeline_job.failed_workflows = ["new_workflow3"]
+    assert len(pipeline_job.failed_workflows) == 1
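The three prompt properties exercised above line up with the three prompts now returned by api.generate_indexing_prompts in the API change earlier in this diff. Below is a minimal sketch of how the two could be wired together; the import path comes from the test docstring, but the mapping of entity summarization to summarize_descriptions and community summarization to community_report is an assumption this diff does not itself confirm.

    # Hypothetical glue code (not part of this diff). Placeholder strings stand
    # in for the tuple returned by api.generate_indexing_prompts.
    from src.utils.pipeline import PipelineJob

    prompts = ("entity extraction...", "entity summarization...", "community summarization...")

    pipeline_job = PipelineJob()
    pipeline_job.entity_extraction_prompt = prompts[0]
    # assumed mapping: entity summarization -> summarize_descriptions,
    # community summarization -> community_report
    pipeline_job.summarize_descriptions_prompt = prompts[1]
    pipeline_job.community_report_prompt = prompts[2]
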
"apimName=$APIM_NAME" \ --parameters "apimTier=$APIM_TIER" \ - --parameters "publisherName=$PUBLISHER_NAME" \ + --parameters "apiPublisherName=$PUBLISHER_NAME" \ + --parameters "apiPublisherEmail=$PUBLISHER_EMAIL" \ --parameters "aksSshRsaPublicKey=$SSH_PUBLICKEY" \ - --parameters "publisherEmail=$PUBLISHER_EMAIL" \ --parameters "enablePrivateEndpoints=$ENABLE_PRIVATE_ENDPOINTS" \ --parameters "acrName=$CONTAINER_REGISTRY_NAME" \ --parameters "deployerPrincipalId=$deployerPrincipalId" \