update notebook

This commit is contained in:
Josh Bradley 2025-01-26 03:14:30 -05:00
parent d0273c3d75
commit 46b2a8c2ec

View File

@ -61,18 +61,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "4", "id": "4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import getpass\n", "import getpass\n",
"import json\n", "import json\n",
"import os\n",
"import sys\n", "import sys\n",
"import time\n", "import time\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from zipfile import ZipFile\n",
"\n", "\n",
"import magic\n", "import magic\n",
"import pandas as pd\n", "import pandas as pd\n",
@ -102,7 +100,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"id": "7", "id": "7",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -134,7 +132,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 26,
"id": "9", "id": "9",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -161,7 +159,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 27,
"id": "10", "id": "10",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -183,7 +181,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 28,
"id": "12", "id": "12",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -274,28 +272,24 @@
"def build_index(\n", "def build_index(\n",
" storage_name: str,\n", " storage_name: str,\n",
" index_name: str,\n", " index_name: str,\n",
" entity_extraction_prompt_filepath: str = None,\n", " entity_extraction_prompt: str = None,\n",
" community_prompt_filepath: str = None,\n", " entity_summarization_prompt: str = None,\n",
" summarize_description_prompt_filepath: str = None,\n", " community_summarization_prompt: str = None,\n",
") -> requests.Response:\n", ") -> requests.Response:\n",
" \"\"\"Create a search index.\n", " \"\"\"Build a graphrag index.\n",
" This function kicks off a job that builds a knowledge graph (KG) index from files located in a blob storage container.\n", " This function submits a job that builds a graphrag index (i.e. a knowledge graph) from data files located in a blob storage container.\n",
" \"\"\"\n", " \"\"\"\n",
" url = endpoint + \"/index\"\n", " url = endpoint + \"/index\"\n",
" prompt_files = dict()\n", " prompts = dict()\n",
" if entity_extraction_prompt_filepath:\n", " if entity_extraction_prompt:\n",
" prompt_files[\"entity_extraction_prompt\"] = open(\n", " prompts[\"entity_extraction_prompt\"] = entity_extraction_prompt\n",
" entity_extraction_prompt_filepath, \"r\"\n", " if entity_summarization_prompt:\n",
" )\n", " prompts[\"summarize_descriptions_prompt\"] = entity_summarization_prompt\n",
" if community_prompt_filepath:\n", " if community_summarization_prompt:\n",
" prompt_files[\"community_report_prompt\"] = open(community_prompt_filepath, \"r\")\n", " prompts[\"community_report_prompt\"] = community_summarization_prompt\n",
" if summarize_description_prompt_filepath:\n",
" prompt_files[\"summarize_descriptions_prompt\"] = open(\n",
" summarize_description_prompt_filepath, \"r\"\n",
" )\n",
" return requests.post(\n", " return requests.post(\n",
" url,\n", " url,\n",
" files=prompt_files if len(prompt_files) > 0 else None,\n", " files=prompts if len(prompts) > 0 else None,\n",
" params={\"index_name\": index_name, \"storage_name\": storage_name},\n", " params={\"index_name\": index_name, \"storage_name\": storage_name},\n",
" headers=headers,\n", " headers=headers,\n",
" )\n", " )\n",
@ -475,15 +469,11 @@
" return response\n", " return response\n",
"\n", "\n",
"\n", "\n",
"def generate_prompts(storage_name: str, zip_file_name: str, limit: int = 1) -> None:\n", "def generate_prompts(storage_name: str, limit: int = 1) -> None:\n",
" \"\"\"Generate graphrag prompts using data provided in a specific storage container.\"\"\"\n", " \"\"\"Generate graphrag prompts using data provided in a specific storage container.\"\"\"\n",
" url = endpoint + \"/index/config/prompts\"\n", " url = endpoint + \"/index/config/prompts\"\n",
" params = {\"storage_name\": storage_name, \"limit\": limit}\n", " params = {\"storage_name\": storage_name, \"limit\": limit}\n",
" with requests.get(url, params=params, headers=headers, stream=True) as r:\n", " return requests.get(url, params=params, headers=headers)"
" r.raise_for_status()\n",
" with open(zip_file_name, \"wb\") as f:\n",
" for chunk in r.iter_content():\n",
" f.write(chunk)"
] ]
}, },
{ {
@ -573,14 +563,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"id": "20", "id": "20",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"generate_prompts(storage_name=storage_name, limit=1, zip_file_name=\"prompts.zip\")\n", "auto_template_response = generate_prompts(storage_name=storage_name, limit=1)\n",
"with ZipFile(\"prompts.zip\", \"r\") as zip_ref:\n", "if auto_template_response.ok:\n",
" zip_ref.extractall()" " prompts = auto_template_response.json()"
] ]
}, },
{ {
@ -618,30 +608,20 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# check if prompt files exist\n", "# check if custom prompts were generated\n",
"entity_extraction_prompt_filepath = \"prompts/entity_extraction.txt\"\n", "if \"auto_template_response\" in locals() and auto_template_response.ok:\n",
"community_prompt_filepath = \"prompts/community_report.txt\"\n", " entity_extraction_prompt = prompts[\"entity_extraction_prompt\"]\n",
"summarize_description_prompt_filepath = \"prompts/summarize_descriptions.txt\"\n", " community_summarization_prompt = prompts[\"community_summarization_prompt\"]\n",
"entity_prompt = (\n", " summarize_description_prompt = prompts[\"entity_summarization_prompt\"]\n",
" entity_extraction_prompt_filepath\n", "else:\n",
" if os.path.isfile(entity_extraction_prompt_filepath)\n", " entity_extraction_prompt = community_summarization_prompt = summarize_description_prompt = None\n",
" else None\n",
")\n",
"community_prompt = (\n",
" community_prompt_filepath if os.path.isfile(community_prompt_filepath) else None\n",
")\n",
"summarize_prompt = (\n",
" summarize_description_prompt_filepath\n",
" if os.path.isfile(summarize_description_prompt_filepath)\n",
" else None\n",
")\n",
"\n", "\n",
"response = build_index(\n", "response = build_index(\n",
" storage_name=storage_name,\n", " storage_name=storage_name,\n",
" index_name=index_name,\n", " index_name=index_name,\n",
" entity_extraction_prompt_filepath=entity_prompt,\n", " entity_extraction_prompt=entity_extraction_prompt,\n",
" community_prompt_filepath=community_prompt,\n", " community_summarization_prompt=community_summarization_prompt,\n",
" summarize_description_prompt_filepath=summarize_prompt,\n", " entity_summarization_prompt=summarize_description_prompt,\n",
")\n", ")\n",
"if response.ok:\n", "if response.ok:\n",
" pprint(response.json())\n", " pprint(response.json())\n",
@ -717,7 +697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 40,
"id": "31", "id": "31",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],