graphrag/docs/examples_notebooks/multi_index_search.ipynb
Nathan Evans 321d479ab6
Update notebooks for 2.0 (#1785)
* Update API overview

* Fix global search example

* Fix local search example

* Fix global dynamic example

* Fix drift example

* Update multi-index example

* Semver
2025-03-11 17:23:49 -07:00

559 lines
17 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi Index Search\n",
"This notebook demonstrates multi-index search using the GraphRAG API.\n",
"\n",
"It uses indexes built from the Wikipedia articles about Alaska, California, DC, Maryland, New York, and Washington."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"import pandas as pd\n",
"\n",
"from graphrag.api.query import (\n",
" multi_index_basic_search,\n",
" multi_index_drift_search,\n",
" multi_index_global_search,\n",
" multi_index_local_search,\n",
")\n",
"from graphrag.config.create_graphrag_config import create_graphrag_config\n",
"\n",
"# Index names, kept sorted so every per-index list built from them (the\n",
"# DataFrame lists and the vector store configs) follows the same order.\n",
"indexes = sorted([\"alaska\", \"california\", \"dc\", \"maryland\", \"ny\", \"washington\"])\n",
"\n",
"print(indexes)\n",
"\n",
"# One LanceDB vector store per index, keyed by index name and pointing at\n",
"# that index's lancedb folder under inputs/.\n",
"vector_store_configs = {\n",
"    name: {\n",
"        \"type\": \"lancedb\",\n",
"        \"db_uri\": f\"inputs/{name}/lancedb\",\n",
"        \"container_name\": \"default\",\n",
"        \"overwrite\": True,\n",
"        \"index_name\": f\"{name}\",\n",
"    }\n",
"    for name in indexes\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Chat and embedding model settings shared by all search methods (Azure\n",
"# OpenAI with managed-identity auth). Replace <API_BASE_URL> before running.\n",
"config_data = {\n",
"    \"models\": {\n",
"        \"default_chat_model\": {\n",
"            \"model_supports_json\": True,\n",
"            \"parallelization_num_threads\": 50,\n",
"            \"parallelization_stagger\": 0.3,\n",
"            \"async_mode\": \"threaded\",\n",
"            \"type\": \"azure_openai_chat\",\n",
"            \"model\": \"gpt-4o\",\n",
"            \"auth_type\": \"azure_managed_identity\",\n",
"            \"api_base\": \"<API_BASE_URL>\",\n",
"            \"api_version\": \"2024-02-15-preview\",\n",
"            \"deployment_name\": \"gpt-4o\",\n",
"        },\n",
"        \"default_embedding_model\": {\n",
"            \"parallelization_num_threads\": 50,\n",
"            \"parallelization_stagger\": 0.3,\n",
"            \"async_mode\": \"threaded\",\n",
"            \"type\": \"azure_openai_embedding\",\n",
"            \"model\": \"text-embedding-3-large\",\n",
"            \"auth_type\": \"azure_managed_identity\",\n",
"            \"api_base\": \"<API_BASE_URL>\",\n",
"            \"api_version\": \"2024-02-15-preview\",\n",
"            \"deployment_name\": \"text-embedding-3-large\",\n",
"        },\n",
"    },\n",
"}\n",
"# Per-index vector stores from the previous cell.\n",
"config_data[\"vector_store\"] = vector_store_configs\n",
"# Prompt files for each search method (plus an LLM max-tokens setting for\n",
"# local search).\n",
"config_data[\"local_search\"] = {\n",
"    \"prompt\": \"prompts/local_search_system_prompt.txt\",\n",
"    \"llm_max_tokens\": 12000,\n",
"}\n",
"config_data[\"global_search\"] = {\n",
"    \"map_prompt\": \"prompts/global_search_map_system_prompt.txt\",\n",
"    \"reduce_prompt\": \"prompts/global_search_reduce_system_prompt.txt\",\n",
"    \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n",
"}\n",
"config_data[\"drift_search\"] = {\n",
"    \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n",
"    \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n",
"}\n",
"config_data[\"basic_search\"] = {\"prompt\": \"prompts/basic_search_system_prompt.txt\"}\n",
"\n",
"# Build the GraphRAG config object passed to every search call below; \".\" is\n",
"# presumably the root directory used to resolve relative paths.\n",
"parameters = create_graphrag_config(config_data, \".\")\n",
"# The search cells below schedule their coroutines as tasks on this loop.\n",
"loop = asyncio.get_event_loop()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-index Global Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the per-index tables global search needs. Each list is ordered like\n",
"# `indexes`, so position i holds the table for indexes[i].\n",
"entities, communities, community_reports = (\n",
"    [pd.read_parquet(f\"inputs/{index}/{table}.parquet\") for index in indexes]\n",
"    for table in (\"entities\", \"communities\", \"community_reports\")\n",
")\n",
"\n",
"task = loop.create_task(\n",
"    multi_index_global_search(\n",
"        parameters,\n",
"        entities,\n",
"        communities,\n",
"        community_reports,\n",
"        indexes,\n",
"        1,  # community level\n",
"        False,  # dynamic community selection\n",
"        \"Multiple Paragraphs\",  # response type\n",
"        False,  # streaming\n",
"        \"Describe this dataset.\",  # query\n",
"    )\n",
")\n",
"results = await task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Print report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# results[0] holds the response; results[1] holds the context data used below.\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Show context links back to the original index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trace report ids from the search context back to their source index, and\n",
"# compare the context title with the title in that index's own\n",
"# community_reports table (the 2.0 output name, as loaded above).\n",
"# The ids are hard-coded (presumably from an earlier run). Ids missing from\n",
"# the current context are reported and skipped instead of raising.\n",
"context_reports = results[1][\"reports\"]\n",
"for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n",
"    report = next((r for r in context_reports if r[\"id\"] == str(report_id)), None)\n",
"    if report is None:\n",
"        print(f\"report {report_id} not found in search context; skipping\")\n",
"        continue\n",
"    index_name, index_id = report[\"index_name\"], report[\"index_id\"]\n",
"    print(report_id, index_name, index_id)\n",
"    index_reports = pd.read_parquet(f\"inputs/{index_name}/community_reports.parquet\")\n",
"    print(report[\"title\"])\n",
"    print(\n",
"        index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[0]\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-index Local Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the per-index tables local search needs. Each list is ordered like\n",
"# `indexes`, so position i holds the table for indexes[i].\n",
"entities, communities, community_reports, covariates, text_units, relationships = (\n",
"    [pd.read_parquet(f\"inputs/{index}/{table}.parquet\") for index in indexes]\n",
"    for table in (\n",
"        \"entities\",\n",
"        \"communities\",\n",
"        \"community_reports\",\n",
"        \"covariates\",\n",
"        \"text_units\",\n",
"        \"relationships\",\n",
"    )\n",
")\n",
"\n",
"task = loop.create_task(\n",
"    multi_index_local_search(\n",
"        parameters,\n",
"        entities,\n",
"        communities,\n",
"        community_reports,\n",
"        text_units,\n",
"        relationships,\n",
"        covariates,\n",
"        indexes,\n",
"        1,  # community level\n",
"        \"Multiple Paragraphs\",  # response type\n",
"        False,  # streaming\n",
"        \"weather\",  # query\n",
"    )\n",
")\n",
"results = await task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Print report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# results[0] holds the response; results[1] holds the context data used below.\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Show context links back to the original index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trace records of each context type back to their source index, and compare\n",
"# a field from the search context with the same field in that index's table.\n",
"# Tables use the 2.0 output names, as loaded above.\n",
"# Each entry: (context key, ids to look up, source table, id column in that\n",
"# table, field to compare, max characters to print or None for all).\n",
"# The ids are hard-coded (presumably from an earlier run). Ids missing from\n",
"# the current context are reported and skipped instead of raising.\n",
"context_lookups = [\n",
"    (\"reports\", [47, 213], \"community_reports\", \"community\", \"title\", None),\n",
"    (\n",
"        \"entities\",\n",
"        [500, 502, 506, 1960, 1961, 1962],\n",
"        \"entities\",\n",
"        \"human_readable_id\",\n",
"        \"description\",\n",
"        100,\n",
"    ),\n",
"    (\n",
"        \"relationships\",\n",
"        [1805, 1806],\n",
"        \"relationships\",\n",
"        \"human_readable_id\",\n",
"        \"description\",\n",
"        None,\n",
"    ),\n",
"    (\"claims\", [100], \"covariates\", \"human_readable_id\", \"description\", None),\n",
"]\n",
"for context_key, record_ids, table, id_column, field, max_chars in context_lookups:\n",
"    context_records = results[1].get(context_key, [])\n",
"    for record_id in record_ids:\n",
"        record = next(\n",
"            (r for r in context_records if r[\"id\"] == str(record_id)), None\n",
"        )\n",
"        if record is None:\n",
"            print(f\"{context_key} {record_id} not found in search context; skipping\")\n",
"            continue\n",
"        index_name, index_id = record[\"index_name\"], record[\"index_id\"]\n",
"        print(record_id, index_name, index_id)\n",
"        index_table = pd.read_parquet(f\"inputs/{index_name}/{table}.parquet\")\n",
"        source_rows = index_table[index_table[id_column] == int(index_id)]\n",
"        print(str(record[field])[:max_chars])\n",
"        print(str(source_rows[field].to_numpy()[0])[:max_chars])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-index Drift Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the per-index tables DRIFT search needs. Each list is ordered like\n",
"# `indexes`, so position i holds the table for indexes[i].\n",
"entities, communities, community_reports, text_units, relationships = (\n",
"    [pd.read_parquet(f\"inputs/{index}/{table}.parquet\") for index in indexes]\n",
"    for table in (\n",
"        \"entities\",\n",
"        \"communities\",\n",
"        \"community_reports\",\n",
"        \"text_units\",\n",
"        \"relationships\",\n",
"    )\n",
")\n",
"\n",
"task = loop.create_task(\n",
"    multi_index_drift_search(\n",
"        parameters,\n",
"        entities,\n",
"        communities,\n",
"        community_reports,\n",
"        text_units,\n",
"        relationships,\n",
"        indexes,\n",
"        1,  # community level\n",
"        \"Multiple Paragraphs\",  # response type\n",
"        False,  # streaming\n",
"        \"agriculture\",  # query\n",
"    )\n",
")\n",
"results = await task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Print report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# results[0] holds the response; results[1] holds the context data used below.\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Show context links back to the original index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The DRIFT context maps each question to that question's own context\n",
"# (reports, sources). For each id, use the first question whose context\n",
"# contains it; the for-else branch reports ids that no question contains.\n",
"# The ids are hard-coded (presumably from an earlier run). Source tables use\n",
"# the 2.0 output names, as loaded above.\n",
"for report_id in [47, 236]:\n",
"    for question, question_context in results[1].items():\n",
"        report = next(\n",
"            (r for r in question_context[\"reports\"] if r[\"id\"] == str(report_id)),\n",
"            None,\n",
"        )\n",
"        if report is None:\n",
"            continue\n",
"        index_name, index_id = report[\"index_name\"], report[\"index_id\"]\n",
"        print(question, report_id, index_name, index_id)\n",
"        index_reports = pd.read_parquet(f\"inputs/{index_name}/community_reports.parquet\")\n",
"        print(report[\"title\"])\n",
"        print(\n",
"            index_reports[index_reports[\"community\"] == int(index_id)][\n",
"                \"title\"\n",
"            ].to_numpy()[0]\n",
"        )\n",
"        break\n",
"    else:  # no question's context contained this id\n",
"        print(f\"report {report_id} not found in search context; skipping\")\n",
"for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n",
"    for question, question_context in results[1].items():\n",
"        source = next(\n",
"            (s for s in question_context[\"sources\"] if s[\"id\"] == str(source_id)),\n",
"            None,\n",
"        )\n",
"        if source is None:\n",
"            continue\n",
"        index_name, index_id = source[\"index_name\"], source[\"index_id\"]\n",
"        print(question, source_id, index_name, index_id)\n",
"        index_sources = pd.read_parquet(f\"inputs/{index_name}/text_units.parquet\")\n",
"        print(source[\"text\"][:250])\n",
"        # NOTE(review): text units are looked up by row label here, whereas the\n",
"        # entity/relationship/claim lookups for local search match on\n",
"        # human_readable_id; confirm which one index_id refers to for sources.\n",
"        print(index_sources.loc[int(index_id)][\"text\"][:250])\n",
"        break\n",
"    else:  # no question's context contained this id\n",
"        print(f\"source {source_id} not found in search context; skipping\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-index Basic Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Basic search only needs the raw text units of each index, ordered like\n",
"# `indexes`.\n",
"text_units = [\n",
"    pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n",
"]\n",
"\n",
"task = loop.create_task(\n",
"    multi_index_basic_search(\n",
"        parameters,\n",
"        text_units,\n",
"        indexes,\n",
"        False,  # streaming\n",
"        \"industry in maryland\",  # query\n",
"    )\n",
")\n",
"results = await task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Print report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# results[0] holds the response; results[1] holds the context data used below.\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Show context links back to original text\n",
"\n",
"Note that the original index name is not saved in the context data for basic search, so only the retrieved text is shown."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The basic search context carries no index name, so only the first 250\n",
"# characters of the first two context sources are shown.\n",
"for position in range(2):\n",
"    print(results[1][\"sources\"][position][\"text\"][:250])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}