mirror of https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-10-12 01:08:54 +00:00

refactor variable names to be more generic and add integration tests

This commit is contained in:
parent ff5714af1f
commit 0252646d16
@@ -3,16 +3,15 @@

 import inspect
 import os
-import shutil
 import traceback

+import graphrag.api as api
 import yaml
 from fastapi import (
     APIRouter,
     HTTPException,
 )
-from fastapi.responses import StreamingResponse
-from graphrag.prompt_tune.cli import prompt_tune as generate_fine_tune_prompts
+from graphrag.config.create_graphrag_config import create_graphrag_config

 from src.api.azure_clients import AzureClientManager
 from src.api.common import (
@@ -27,7 +26,7 @@ index_configuration_route = APIRouter(

 @index_configuration_route.get(
     "/prompts",
-    summary="Generate graphrag prompts from user-provided data.",
+    summary="Generate prompts from user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
 )
 async def generate_prompts(storage_name: str, limit: int = 5):
@@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             status_code=500,
             detail=f"Data container '{storage_name}' does not exist.",
         )

     # load pipeline configuration file (settings.yaml) for input data and other settings
     this_directory = os.path.dirname(
         os.path.abspath(inspect.getfile(inspect.currentframe()))
     )

-    # write custom settings.yaml to a file and store in a temporary directory
     data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
     data["input"]["container_name"] = sanitized_storage_name
-    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
-    shutil.rmtree(temp_dir, ignore_errors=True)
-    os.makedirs(temp_dir, exist_ok=True)
-    with open(f"{temp_dir}/settings.yaml", "w") as f:
-        yaml.dump(data, f, default_flow_style=False)
+    graphrag_config = create_graphrag_config(values=data, root_dir=".")

     # generate prompts
     try:
-        await generate_fine_tune_prompts(
-            config=f"{temp_dir}/settings.yaml",
-            root=temp_dir,
-            domain="",
-            selection_method="random",
+        # NOTE: we need to call api.generate_indexing_prompts
+        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
+            config=graphrag_config,
+            root=".",
             limit=limit,
             skip_entity_types=True,
-            output=f"{temp_dir}/prompts",
+            selection_method="random",
         )
     except Exception as e:
         logger = LoggerSingleton().get_instance()
@@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
         )

-    # zip up the generated prompt files and return the zip file
-    temp_archive = (
-        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
-    )
-    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")
-
-    def iterfile(file_path: str):
-        with open(file_path, mode="rb") as file_like:
-            yield from file_like
-
-    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
+    content = {
+        "entity_extraction_prompt": prompts[0],
+        "entity_summarization_prompt": prompts[1],
+        "community_summarization_prompt": prompts[2],
+    }
+    return content  # return a fastapi.responses.JSONResponse object
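
With this change, the /prompts endpoint returns the three generated prompts as a JSON object instead of streaming a prompts.zip archive. A minimal client sketch of the new contract follows; the gateway URL, route prefix, subscription-key header, and output filenames are assumptions for illustration, not taken from this commit.

import requests

# Assumed APIM gateway and subscription key; adjust to your deployment.
API_BASE = "https://<apim-gateway>.azure-api.net"
HEADERS = {"Ocp-Apim-Subscription-Key": "<subscription-key>"}

# The prefix ahead of "/prompts" comes from index_configuration_route's
# APIRouter configuration and is assumed here.
response = requests.get(
    f"{API_BASE}/index/config/prompts",
    params={"storage_name": "my-data-container", "limit": 5},
    headers=HEADERS,
)
response.raise_for_status()

# The body is now JSON with one key per generated prompt (previously a zip file).
prompts = response.json()
for key in (
    "entity_extraction_prompt",
    "entity_summarization_prompt",
    "community_summarization_prompt",
):
    with open(f"{key}.txt", "w") as f:
        f.write(prompts[key])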
@@ -42,7 +42,9 @@ def cosmos_index_job_entry(cosmos_client) -> Generator[str, None, None]:


 def test_pipeline_job_interface(cosmos_index_job_entry):
+    """Test the src.utils.pipeline.PipelineJob class interface."""
     pipeline_job = PipelineJob()
+
     # test creating a new entry
     pipeline_job.create_item(
         id="synthetic_id",
@@ -69,3 +71,36 @@ def test_pipeline_job_interface(cosmos_index_job_entry):
     assert pipeline_job.status == PipelineJobState.COMPLETE
     assert pipeline_job.percent_complete == 50.0
     assert pipeline_job.progress == "some progress"
+    assert pipeline_job.calculate_percent_complete() == 50.0
+
+    # test setters and getters
+    pipeline_job.id = "newID"
+    assert pipeline_job.id == "newID"
+    pipeline_job.epoch_request_time = 1
+    assert pipeline_job.epoch_request_time == 1
+
+    pipeline_job.human_readable_index_name = "new_human_readable_index_name"
+    assert pipeline_job.human_readable_index_name == "new_human_readable_index_name"
+    pipeline_job.sanitized_index_name = "new_sanitized_index_name"
+    assert pipeline_job.sanitized_index_name == "new_sanitized_index_name"
+
+    pipeline_job.human_readable_storage_name = "new_human_readable_storage_name"
+    assert pipeline_job.human_readable_storage_name == "new_human_readable_storage_name"
+    pipeline_job.sanitized_storage_name = "new_sanitized_storage_name"
+    assert pipeline_job.sanitized_storage_name == "new_sanitized_storage_name"
+
+    pipeline_job.entity_extraction_prompt = "new_entity_extraction_prompt"
+    assert pipeline_job.entity_extraction_prompt == "new_entity_extraction_prompt"
+    pipeline_job.community_report_prompt = "new_community_report_prompt"
+    assert pipeline_job.community_report_prompt == "new_community_report_prompt"
+    pipeline_job.summarize_descriptions_prompt = "new_summarize_descriptions_prompt"
+    assert pipeline_job.summarize_descriptions_prompt == "new_summarize_descriptions_prompt"
+
+    pipeline_job.all_workflows = ["new_workflow1", "new_workflow2", "new_workflow3"]
+    assert len(pipeline_job.all_workflows) == 3
+
+    pipeline_job.completed_workflows = ["new_workflow1", "new_workflow2"]
+    assert len(pipeline_job.completed_workflows) == 2
+
+    pipeline_job.failed_workflows = ["new_workflow3"]
+    assert len(pipeline_job.failed_workflows) == 1
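
The new assertions above exercise the full getter/setter surface of PipelineJob. As a reading aid, here is a minimal sketch of the interface being tested; it assumes percent completion is the share of all_workflows found in completed_workflows, and it omits the Cosmos DB persistence that the real src.utils.pipeline.PipelineJob performs on each write.

from dataclasses import dataclass, field


@dataclass
class PipelineJobSketch:
    # Hypothetical stand-in for src.utils.pipeline.PipelineJob.
    id: str = ""
    epoch_request_time: int = 0
    human_readable_index_name: str = ""
    sanitized_index_name: str = ""
    human_readable_storage_name: str = ""
    sanitized_storage_name: str = ""
    entity_extraction_prompt: str = ""
    community_report_prompt: str = ""
    summarize_descriptions_prompt: str = ""
    all_workflows: list[str] = field(default_factory=list)
    completed_workflows: list[str] = field(default_factory=list)
    failed_workflows: list[str] = field(default_factory=list)

    def calculate_percent_complete(self) -> float:
        # Assumed definition: fraction of all workflows that have completed.
        if not self.all_workflows:
            return 0.0
        return len(self.completed_workflows) / len(self.all_workflows) * 100

Under this sketch, a job with two of four workflows completed reports calculate_percent_complete() == 50.0, matching the assertion above.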
@@ -347,12 +347,12 @@ deployAzureResources () {
         --resource-group $RESOURCE_GROUP \
         --template-file ./main.bicep \
         --parameters "resourceBaseName=$RESOURCE_BASE_NAME" \
-        --parameters "graphRagName=$RESOURCE_GROUP" \
+        --parameters "resourceGroupName=$RESOURCE_GROUP" \
         --parameters "apimName=$APIM_NAME" \
         --parameters "apimTier=$APIM_TIER" \
-        --parameters "publisherName=$PUBLISHER_NAME" \
+        --parameters "apiPublisherName=$PUBLISHER_NAME" \
+        --parameters "apiPublisherEmail=$PUBLISHER_EMAIL" \
         --parameters "aksSshRsaPublicKey=$SSH_PUBLICKEY" \
-        --parameters "publisherEmail=$PUBLISHER_EMAIL" \
         --parameters "enablePrivateEndpoints=$ENABLE_PRIVATE_ENDPOINTS" \
         --parameters "acrName=$CONTAINER_REGISTRY_NAME" \
        --parameters "deployerPrincipalId=$deployerPrincipalId" \
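
The renames in this hunk are mechanical: graphRagName becomes resourceGroupName, and publisherName/publisherEmail become apiPublisherName/apiPublisherEmail, making clear that the latter two configure the API publisher in APIM. The shell variables passed in ($RESOURCE_GROUP, $PUBLISHER_NAME, $PUBLISHER_EMAIL) are unchanged; only the parameter names expected by main.bicep differ.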