update notebook

This commit is contained in:
Josh Bradley 2025-01-26 03:14:30 -05:00
parent d0273c3d75
commit 46b2a8c2ec

View File

@ -61,18 +61,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import json\n",
"import os\n",
"import sys\n",
"import time\n",
"from pathlib import Path\n",
"from zipfile import ZipFile\n",
"\n",
"import magic\n",
"import pandas as pd\n",
@ -102,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "7",
"metadata": {
"tags": []
@ -134,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "9",
"metadata": {
"tags": []
@ -161,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "10",
"metadata": {},
"outputs": [],
@ -183,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "12",
"metadata": {
"tags": []
@ -274,28 +272,24 @@
"def build_index(\n",
" storage_name: str,\n",
" index_name: str,\n",
" entity_extraction_prompt_filepath: str = None,\n",
" community_prompt_filepath: str = None,\n",
" summarize_description_prompt_filepath: str = None,\n",
" entity_extraction_prompt: str = None,\n",
" entity_summarization_prompt: str = None,\n",
" community_summarization_prompt: str = None,\n",
") -> requests.Response:\n",
" \"\"\"Create a search index.\n",
" This function kicks off a job that builds a knowledge graph (KG) index from files located in a blob storage container.\n",
" \"\"\"Build a graphrag index.\n",
" This function submits a job that builds a graphrag index (i.e. a knowledge graph) from data files located in a blob storage container.\n",
" \"\"\"\n",
" url = endpoint + \"/index\"\n",
" prompt_files = dict()\n",
" if entity_extraction_prompt_filepath:\n",
" prompt_files[\"entity_extraction_prompt\"] = open(\n",
" entity_extraction_prompt_filepath, \"r\"\n",
" )\n",
" if community_prompt_filepath:\n",
" prompt_files[\"community_report_prompt\"] = open(community_prompt_filepath, \"r\")\n",
" if summarize_description_prompt_filepath:\n",
" prompt_files[\"summarize_descriptions_prompt\"] = open(\n",
" summarize_description_prompt_filepath, \"r\"\n",
" )\n",
" prompts = dict()\n",
" if entity_extraction_prompt:\n",
" prompts[\"entity_extraction_prompt\"] = entity_extraction_prompt\n",
" if entity_summarization_prompt:\n",
" prompts[\"summarize_descriptions_prompt\"] = entity_summarization_prompt\n",
" if community_summarization_prompt:\n",
" prompts[\"community_report_prompt\"] = community_summarization_prompt\n",
" return requests.post(\n",
" url,\n",
" files=prompt_files if len(prompt_files) > 0 else None,\n",
" files=prompts if len(prompts) > 0 else None,\n",
" params={\"index_name\": index_name, \"storage_name\": storage_name},\n",
" headers=headers,\n",
" )\n",
@ -475,15 +469,11 @@
" return response\n",
"\n",
"\n",
"def generate_prompts(storage_name: str, zip_file_name: str, limit: int = 1) -> None:\n",
"def generate_prompts(storage_name: str, limit: int = 1) -> None:\n",
" \"\"\"Generate graphrag prompts using data provided in a specific storage container.\"\"\"\n",
" url = endpoint + \"/index/config/prompts\"\n",
" params = {\"storage_name\": storage_name, \"limit\": limit}\n",
" with requests.get(url, params=params, headers=headers, stream=True) as r:\n",
" r.raise_for_status()\n",
" with open(zip_file_name, \"wb\") as f:\n",
" for chunk in r.iter_content():\n",
" f.write(chunk)"
" return requests.get(url, params=params, headers=headers)"
]
},
{
@ -573,14 +563,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "20",
"metadata": {},
"outputs": [],
"source": [
"generate_prompts(storage_name=storage_name, limit=1, zip_file_name=\"prompts.zip\")\n",
"with ZipFile(\"prompts.zip\", \"r\") as zip_ref:\n",
" zip_ref.extractall()"
"auto_template_response = generate_prompts(storage_name=storage_name, limit=1)\n",
"if auto_template_response.ok:\n",
" prompts = auto_template_response.json()"
]
},
{
@ -618,30 +608,20 @@
},
"outputs": [],
"source": [
"# check if prompt files exist\n",
"entity_extraction_prompt_filepath = \"prompts/entity_extraction.txt\"\n",
"community_prompt_filepath = \"prompts/community_report.txt\"\n",
"summarize_description_prompt_filepath = \"prompts/summarize_descriptions.txt\"\n",
"entity_prompt = (\n",
" entity_extraction_prompt_filepath\n",
" if os.path.isfile(entity_extraction_prompt_filepath)\n",
" else None\n",
")\n",
"community_prompt = (\n",
" community_prompt_filepath if os.path.isfile(community_prompt_filepath) else None\n",
")\n",
"summarize_prompt = (\n",
" summarize_description_prompt_filepath\n",
" if os.path.isfile(summarize_description_prompt_filepath)\n",
" else None\n",
")\n",
"# check if custom prompts were generated\n",
"if \"auto_template_response\" in locals() and auto_template_response.ok:\n",
" entity_extraction_prompt = prompts[\"entity_extraction_prompt\"]\n",
" community_summarization_prompt = prompts[\"community_summarization_prompt\"]\n",
" summarize_description_prompt = prompts[\"entity_summarization_prompt\"]\n",
"else:\n",
" entity_extraction_prompt = community_summarization_prompt = summarize_description_prompt = None\n",
"\n",
"response = build_index(\n",
" storage_name=storage_name,\n",
" index_name=index_name,\n",
" entity_extraction_prompt_filepath=entity_prompt,\n",
" community_prompt_filepath=community_prompt,\n",
" summarize_description_prompt_filepath=summarize_prompt,\n",
" entity_extraction_prompt=entity_extraction_prompt,\n",
" community_summarization_prompt=community_summarization_prompt,\n",
" entity_summarization_prompt=summarize_description_prompt,\n",
")\n",
"if response.ok:\n",
" pprint(response.json())\n",
@ -717,7 +697,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 40,
"id": "31",
"metadata": {},
"outputs": [],