mirror of https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-10-13 17:59:37 +00:00

commit 0252646d16 (parent ff5714af1f)

refactor variable names to be more generic and add integration tests

@@ -3,16 +3,15 @@
 
 import inspect
 import os
-import shutil
 import traceback
 
+import graphrag.api as api
 import yaml
 from fastapi import (
     APIRouter,
     HTTPException,
 )
-from fastapi.responses import StreamingResponse
-from graphrag.prompt_tune.cli import prompt_tune as generate_fine_tune_prompts
+from graphrag.config.create_graphrag_config import create_graphrag_config
 
 from src.api.azure_clients import AzureClientManager
 from src.api.common import (

@@ -27,7 +26,7 @@ index_configuration_route = APIRouter(
 
 @index_configuration_route.get(
     "/prompts",
-    summary="Generate graphrag prompts from user-provided data.",
+    summary="Generate prompts from user-provided data.",
     description="Generating custom prompts from user-provided data may take several minutes to run based on the amount of data used.",
 )
 async def generate_prompts(storage_name: str, limit: int = 5):

@@ -44,29 +43,23 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             status_code=500,
             detail=f"Data container '{storage_name}' does not exist.",
         )
 
+    # load pipeline configuration file (settings.yaml) for input data and other settings
     this_directory = os.path.dirname(
         os.path.abspath(inspect.getfile(inspect.currentframe()))
     )
-    # write custom settings.yaml to a file and store in a temporary directory
     data = yaml.safe_load(open(f"{this_directory}/pipeline-settings.yaml"))
     data["input"]["container_name"] = sanitized_storage_name
-    temp_dir = f"/tmp/{sanitized_storage_name}_prompt_tuning"
-    shutil.rmtree(temp_dir, ignore_errors=True)
-    os.makedirs(temp_dir, exist_ok=True)
-    with open(f"{temp_dir}/settings.yaml", "w") as f:
-        yaml.dump(data, f, default_flow_style=False)
+    graphrag_config = create_graphrag_config(values=data, root_dir=".")
 
     # generate prompts
     try:
-        await generate_fine_tune_prompts(
-            config=f"{temp_dir}/settings.yaml",
-            root=temp_dir,
-            domain="",
-            selection_method="random",
+        # NOTE: we need to call api.generate_indexing_prompts
+        prompts: tuple[str, str, str] = await api.generate_indexing_prompts(
+            config=graphrag_config,
+            root=".",
             limit=limit,
-            skip_entity_types=True,
-            output=f"{temp_dir}/prompts",
+            selection_method="random",
         )
     except Exception as e:
         logger = LoggerSingleton().get_instance()

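Note: the hunk above swaps the CLI-oriented prompt_tune entry point for graphrag's
programmatic API, so the route no longer writes a temporary settings.yaml under
/tmp and cleans it up afterwards. A minimal standalone sketch of the new call path
(the settings file path and container name below are illustrative, not taken from
this commit):

    import asyncio

    import graphrag.api as api
    import yaml
    from graphrag.config.create_graphrag_config import create_graphrag_config

    async def main() -> None:
        # build a GraphRagConfig object in memory instead of dumping a
        # temporary settings.yaml to disk, as the old code did
        with open("pipeline-settings.yaml") as f:  # illustrative path
            data = yaml.safe_load(f)
        data["input"]["container_name"] = "my-sanitized-container"  # hypothetical
        config = create_graphrag_config(values=data, root_dir=".")

        # per the diff, generate_indexing_prompts returns a 3-tuple of prompt
        # strings (entity extraction, entity summarization, community summarization)
        prompts = await api.generate_indexing_prompts(
            config=config,
            root=".",
            limit=5,
            selection_method="random",
        )
        print(prompts[0][:200])

    asyncio.run(main())
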
@@ -84,14 +77,9 @@ async def generate_prompts(storage_name: str, limit: int = 5):
             detail=f"Error generating prompts for data in '{storage_name}'. Please try a lower limit.",
         )
 
-    # zip up the generated prompt files and return the zip file
-    temp_archive = (
-        f"{temp_dir}/prompts"  # will become a zip file with the name prompts.zip
-    )
-    shutil.make_archive(temp_archive, "zip", root_dir=temp_dir, base_dir="prompts")
-
-    def iterfile(file_path: str):
-        with open(file_path, mode="rb") as file_like:
-            yield from file_like
-
-    return StreamingResponse(iterfile(f"{temp_archive}.zip"))
+    content = {
+        "entity_extraction_prompt": prompts[0],
+        "entity_summarization_prompt": prompts[1],
+        "community_summarization_prompt": prompts[2],
+    }
+    return content  # return a fastapi.responses.JSONResponse object

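Note: the endpoint now returns the three generated prompts as a JSON object
rather than zipping a prompts/ folder and streaming it back, which is why the
StreamingResponse import is dropped. A hedged sketch of how a caller's handling
changes (host and route prefix are placeholders, not from this commit):

    import requests

    resp = requests.get(
        "http://localhost/index/config/prompts",  # hypothetical URL
        params={"storage_name": "my-data", "limit": 5},
    )
    resp.raise_for_status()
    prompts = resp.json()  # previously: a prompts.zip archive to unpack
    print(prompts["entity_extraction_prompt"][:200])
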
@@ -42,7 +42,9 @@ def cosmos_index_job_entry(cosmos_client) -> Generator[str, None, None]:
 
 
 def test_pipeline_job_interface(cosmos_index_job_entry):
+    """Test the src.utils.pipeline.PipelineJob class interface."""
     pipeline_job = PipelineJob()
 
     # test creating a new entry
     pipeline_job.create_item(
         id="synthetic_id",

@@ -69,3 +71,36 @@ def test_pipeline_job_interface(cosmos_index_job_entry):
     assert pipeline_job.status == PipelineJobState.COMPLETE
     assert pipeline_job.percent_complete == 50.0
     assert pipeline_job.progress == "some progress"
+    assert pipeline_job.calculate_percent_complete() == 50.0
+
+    # test setters and getters
+    pipeline_job.id = "newID"
+    assert pipeline_job.id == "newID"
+    pipeline_job.epoch_request_time = 1
+    assert pipeline_job.epoch_request_time == 1
+
+    pipeline_job.human_readable_index_name = "new_human_readable_index_name"
+    assert pipeline_job.human_readable_index_name == "new_human_readable_index_name"
+    pipeline_job.sanitized_index_name = "new_sanitized_index_name"
+    assert pipeline_job.sanitized_index_name == "new_sanitized_index_name"
+
+    pipeline_job.human_readable_storage_name = "new_human_readable_storage_name"
+    assert pipeline_job.human_readable_storage_name == "new_human_readable_storage_name"
+    pipeline_job.sanitized_storage_name = "new_sanitized_storage_name"
+    assert pipeline_job.sanitized_storage_name == "new_sanitized_storage_name"
+
+    pipeline_job.entity_extraction_prompt = "new_entity_extraction_prompt"
+    assert pipeline_job.entity_extraction_prompt == "new_entity_extraction_prompt"
+    pipeline_job.community_report_prompt = "new_community_report_prompt"
+    assert pipeline_job.community_report_prompt == "new_community_report_prompt"
+    pipeline_job.summarize_descriptions_prompt = "new_summarize_descriptions_prompt"
+    assert pipeline_job.summarize_descriptions_prompt == "new_summarize_descriptions_prompt"
+
+    pipeline_job.all_workflows = ["new_workflow1", "new_workflow2", "new_workflow3"]
+    assert len(pipeline_job.all_workflows) == 3
+
+    pipeline_job.completed_workflows = ["new_workflow1", "new_workflow2"]
+    assert len(pipeline_job.completed_workflows) == 2
+
+    pipeline_job.failed_workflows = ["new_workflow3"]
+    assert len(pipeline_job.failed_workflows) == 1

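Note: the new assertions exercise plain property setters and getters on
PipelineJob. An illustrative-only sketch of that pattern (the real class in
src.utils.pipeline also writes each update through to Cosmos DB, omitted here):

    class PipelineJobSketch:
        """Toy stand-in for PipelineJob; names mirror the test, logic is assumed."""

        def __init__(self) -> None:
            self._id = ""
            self.all_workflows: list[str] = []
            self.completed_workflows: list[str] = []

        @property
        def id(self) -> str:
            return self._id

        @id.setter
        def id(self, value: str) -> None:
            # the real setter would also persist the change to Cosmos DB
            self._id = value

        def calculate_percent_complete(self) -> float:
            if not self.all_workflows:
                return 0.0
            return len(self.completed_workflows) / len(self.all_workflows) * 100

    job = PipelineJobSketch()
    job.id = "newID"
    assert job.id == "newID"
    job.all_workflows = ["w1", "w2"]
    job.completed_workflows = ["w1"]
    assert job.calculate_percent_complete() == 50.0
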
@@ -347,12 +347,12 @@ deployAzureResources () {
         --resource-group $RESOURCE_GROUP \
         --template-file ./main.bicep \
         --parameters "resourceBaseName=$RESOURCE_BASE_NAME" \
-        --parameters "graphRagName=$RESOURCE_GROUP" \
+        --parameters "resourceGroupName=$RESOURCE_GROUP" \
         --parameters "apimName=$APIM_NAME" \
         --parameters "apimTier=$APIM_TIER" \
-        --parameters "publisherName=$PUBLISHER_NAME" \
+        --parameters "apiPublisherName=$PUBLISHER_NAME" \
+        --parameters "apiPublisherEmail=$PUBLISHER_EMAIL" \
         --parameters "aksSshRsaPublicKey=$SSH_PUBLICKEY" \
-        --parameters "publisherEmail=$PUBLISHER_EMAIL" \
         --parameters "enablePrivateEndpoints=$ENABLE_PRIVATE_ENDPOINTS" \
         --parameters "acrName=$CONTAINER_REGISTRY_NAME" \
         --parameters "deployerPrincipalId=$deployerPrincipalId" \

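Note: the renamed deployment parameters (graphRagName -> resourceGroupName,
publisherName -> apiPublisherName, publisherEmail -> apiPublisherEmail) have to
match the param declarations in main.bicep. An illustrative helper that assembles
the same az CLI invocation with the new names (all values are placeholders):

    # builds the argument list for `az deployment group create` using the
    # renamed parameters; pass it to subprocess.run to actually deploy
    params = {
        "resourceGroupName": "my-rg",           # was: graphRagName
        "apiPublisherName": "My Publisher",     # was: publisherName
        "apiPublisherEmail": "me@example.com",  # was: publisherEmail
    }
    cmd = ["az", "deployment", "group", "create", "--template-file", "./main.bicep"]
    for name, value in params.items():
        cmd += ["--parameters", f"{name}={value}"]
    print(" ".join(cmd))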