{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi Index Search\n",
    "This notebook demonstrates multi-index search using the GraphRAG API.\n",
    "\n",
    "Indexes created from Wikipedia state articles for Alaska, California, DC, Maryland, NY and Washington are used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from graphrag.api.query import (\n",
    "    multi_index_basic_search,\n",
    "    multi_index_drift_search,\n",
    "    multi_index_global_search,\n",
    "    multi_index_local_search,\n",
    ")\n",
    "from graphrag.config.create_graphrag_config import create_graphrag_config\n",
    "\n",
    "# Index names double as the sub-directory names under inputs/; kept sorted so the\n",
    "# per-index table lists built later always line up with this list.\n",
    "indexes = sorted([\"alaska\", \"california\", \"dc\", \"maryland\", \"ny\", \"washington\"])\n",
    "\n",
    "print(indexes)\n",
    "\n",
    "# One LanceDB vector-store configuration per index, keyed by index name.\n",
    "vector_store_configs = {\n",
    "    index: {\n",
    "        \"type\": \"lancedb\",\n",
    "        \"db_uri\": f\"inputs/{index}/lancedb\",\n",
    "        \"container_name\": \"default\",\n",
    "        \"overwrite\": True,\n",
    "        \"index_name\": f\"{index}\",\n",
    "    }\n",
    "    for index in indexes\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): both \"api_base\" values are blank; presumably they must be set to\n",
    "# your Azure OpenAI endpoint before the searches below can run - confirm.\n",
    "config_data = {\n",
    "    \"models\": {\n",
    "        \"default_chat_model\": {\n",
    "            \"model_supports_json\": True,\n",
    "            \"parallelization_num_threads\": 50,\n",
    "            \"parallelization_stagger\": 0.3,\n",
    "            \"async_mode\": \"threaded\",\n",
    "            \"type\": \"azure_openai_chat\",\n",
    "            \"model\": \"gpt-4o\",\n",
    "            \"auth_type\": \"azure_managed_identity\",\n",
    "            \"api_base\": \"\",\n",
    "            \"api_version\": \"2024-02-15-preview\",\n",
    "            \"deployment_name\": \"gpt-4o\",\n",
    "        },\n",
    "        \"default_embedding_model\": {\n",
    "            \"parallelization_num_threads\": 50,\n",
    "            \"parallelization_stagger\": 0.3,\n",
    "            \"async_mode\": \"threaded\",\n",
    "            \"type\": \"azure_openai_embedding\",\n",
    "            \"model\": \"text-embedding-3-large\",\n",
    "            \"auth_type\": \"azure_managed_identity\",\n",
    "            \"api_base\": \"\",\n",
    "            \"api_version\": \"2024-02-15-preview\",\n",
    "            \"deployment_name\": \"text-embedding-3-large\",\n",
    "        },\n",
    "    },\n",
    "    \"vector_store\": vector_store_configs,\n",
    "    \"local_search\": {\n",
    "        \"prompt\": \"prompts/local_search_system_prompt.txt\",\n",
    "        \"llm_max_tokens\": 12000,\n",
    "    },\n",
    "    \"global_search\": {\n",
    "        \"map_prompt\": \"prompts/global_search_map_system_prompt.txt\",\n",
    "        \"reduce_prompt\": \"prompts/global_search_reduce_system_prompt.txt\",\n",
    "        \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n",
    "    },\n",
    "    \"drift_search\": {\n",
    "        \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n",
    "        \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n",
    "    },\n",
    "    \"basic_search\": {\"prompt\": \"prompts/basic_search_system_prompt.txt\"},\n",
    "}\n",
    "parameters = create_graphrag_config(config_data, \".\")\n",
    "loop = asyncio.get_event_loop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_index_tables(table_name):\n",
    "    \"\"\"Read inputs/<index>/<table_name>.parquet for every index.\n",
    "\n",
    "    Returns a list of DataFrames in the same order as the module-level `indexes`.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        pd.read_parquet(f\"inputs/{index}/{table_name}.parquet\") for index in indexes\n",
    "    ]\n",
    "\n",
    "\n",
    "def find_context_record(records, record_id):\n",
    "    \"\"\"Return the first record whose \"id\" equals str(record_id), or None.\n",
    "\n",
    "    `records` is an iterable of dict-like context entries; ids are compared as\n",
    "    strings because the context data stores them that way.\n",
    "    \"\"\"\n",
    "    wanted = str(record_id)\n",
    "    return next((record for record in records if record[\"id\"] == wanted), None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Global Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = load_index_tables(\"entities\")\n",
    "communities = load_index_tables(\"communities\")\n",
    "community_reports = load_index_tables(\"community_reports\")\n",
    "\n",
    "# NOTE(review): the trailing positional arguments are presumably community_level,\n",
    "# dynamic_community_selection, response_type, streaming and query - confirm\n",
    "# against the graphrag.api.query signature.\n",
    "task = loop.create_task(\n",
    "    multi_index_global_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        indexes,\n",
    "        1,\n",
    "        False,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"Describe this dataset.\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the search above reads community_reports.parquet, while this cell\n",
    "# reads create_final_community_reports.parquet - confirm both files exist under\n",
    "# inputs/<index>/.\n",
    "# The ids below are hard-coded examples; ids missing from this run's context\n",
    "# data are skipped instead of raising.\n",
    "for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n",
    "    report = find_context_record(results[1][\"reports\"], report_id)\n",
    "    if report is None:\n",
    "        print(f\"report {report_id} not found in context data; skipping\")\n",
    "        continue\n",
    "    index_name = report[\"index_name\"]\n",
    "    index_id = report[\"index_id\"]\n",
    "    print(report_id, index_name, index_id)\n",
    "    index_reports = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "    )\n",
    "    # Title from the merged context, then the title from the original index.\n",
    "    print(report[\"title\"])\n",
    "    matching_reports = index_reports[index_reports[\"community\"] == int(index_id)]\n",
    "    print(matching_reports[\"title\"].to_numpy()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Local Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = load_index_tables(\"entities\")\n",
    "communities = load_index_tables(\"communities\")\n",
    "community_reports = load_index_tables(\"community_reports\")\n",
    "covariates = load_index_tables(\"covariates\")\n",
    "text_units = load_index_tables(\"text_units\")\n",
    "relationships = load_index_tables(\"relationships\")\n",
    "\n",
    "# NOTE(review): the trailing positional arguments are presumably community_level,\n",
    "# response_type, streaming and query - confirm against the graphrag.api.query\n",
    "# signature.\n",
    "task = loop.create_task(\n",
    "    multi_index_local_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        text_units,\n",
    "        relationships,\n",
    "        covariates,\n",
    "        indexes,\n",
    "        1,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"weather\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each loop prints the merged-context record next to the matching row of the\n",
    "# original index. Hard-coded ids that are missing from this run's context data\n",
    "# are skipped instead of raising.\n",
    "for report_id in [47, 213]:\n",
    "    report = find_context_record(results[1][\"reports\"], report_id)\n",
    "    if report is None:\n",
    "        print(f\"report {report_id} not found in context data; skipping\")\n",
    "        continue\n",
    "    index_name = report[\"index_name\"]\n",
    "    index_id = report[\"index_id\"]\n",
    "    print(report_id, index_name, index_id)\n",
    "    index_reports = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "    )\n",
    "    print(report[\"title\"])\n",
    "    matching_reports = index_reports[index_reports[\"community\"] == int(index_id)]\n",
    "    print(matching_reports[\"title\"].to_numpy()[0])\n",
    "\n",
    "for entity_id in [500, 502, 506, 1960, 1961, 1962]:\n",
    "    entity = find_context_record(results[1][\"entities\"], entity_id)\n",
    "    if entity is None:\n",
    "        print(f\"entity {entity_id} not found in context data; skipping\")\n",
    "        continue\n",
    "    index_name = entity[\"index_name\"]\n",
    "    index_id = entity[\"index_id\"]\n",
    "    print(entity_id, index_name, index_id)\n",
    "    index_entities = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_entities.parquet\"\n",
    "    )\n",
    "    print(entity[\"description\"][:100])\n",
    "    matching_entities = index_entities[\n",
    "        index_entities[\"human_readable_id\"] == int(index_id)\n",
    "    ]\n",
    "    print(matching_entities[\"description\"].to_numpy()[0][:100])\n",
    "\n",
    "for relationship_id in [1805, 1806]:\n",
    "    relationship = find_context_record(results[1][\"relationships\"], relationship_id)\n",
    "    if relationship is None:\n",
    "        print(f\"relationship {relationship_id} not found in context data; skipping\")\n",
    "        continue\n",
    "    index_name = relationship[\"index_name\"]\n",
    "    index_id = relationship[\"index_id\"]\n",
    "    print(relationship_id, index_name, index_id)\n",
    "    index_relationships = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_relationships.parquet\"\n",
    "    )\n",
    "    print(relationship[\"description\"])\n",
    "    matching_relationships = index_relationships[\n",
    "        index_relationships[\"human_readable_id\"] == int(index_id)\n",
    "    ]\n",
    "    print(matching_relationships[\"description\"].to_numpy()[0])\n",
    "\n",
    "for claim_id in [100]:\n",
    "    claim = find_context_record(results[1][\"claims\"], claim_id)\n",
    "    if claim is None:\n",
    "        print(f\"claim {claim_id} not found in context data; skipping\")\n",
    "        continue\n",
    "    index_name = claim[\"index_name\"]\n",
    "    index_id = claim[\"index_id\"]\n",
    "    # Print the claim id itself (not the id left over from the relationship loop).\n",
    "    print(claim_id, index_name, index_id)\n",
    "    index_claims = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_covariates.parquet\"\n",
    "    )\n",
    "    print(claim[\"description\"])\n",
    "    matching_claims = index_claims[index_claims[\"human_readable_id\"] == int(index_id)]\n",
    "    print(matching_claims[\"description\"].to_numpy()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Drift Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = load_index_tables(\"entities\")\n",
    "communities = load_index_tables(\"communities\")\n",
    "community_reports = load_index_tables(\"community_reports\")\n",
    "text_units = load_index_tables(\"text_units\")\n",
    "relationships = load_index_tables(\"relationships\")\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_drift_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        text_units,\n",
    "        relationships,\n",
    "        indexes,\n",
    "        1,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"agriculture\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results[1] is keyed by question here; for each id, report only the first\n",
    "# question whose context contains it, then move on to the next id.\n",
    "for report_id in [47, 236]:\n",
    "    for question in results[1]:\n",
    "        resq = results[1][question]\n",
    "        report = find_context_record(resq[\"reports\"], report_id)\n",
    "        if report is None:\n",
    "            continue\n",
    "        index_name = report[\"index_name\"]\n",
    "        index_id = report[\"index_id\"]\n",
    "        print(question, report_id, index_name, index_id)\n",
    "        index_reports = pd.read_parquet(\n",
    "            f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "        )\n",
    "        print(report[\"title\"])\n",
    "        matching_reports = index_reports[index_reports[\"community\"] == int(index_id)]\n",
    "        print(matching_reports[\"title\"].to_numpy()[0])\n",
    "        break\n",
    "\n",
    "for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n",
    "    for question in results[1]:\n",
    "        resq = results[1][question]\n",
    "        source = find_context_record(resq[\"sources\"], source_id)\n",
    "        if source is None:\n",
    "            continue\n",
    "        index_name = source[\"index_name\"]\n",
    "        index_id = source[\"index_id\"]\n",
    "        print(question, source_id, index_name, index_id)\n",
    "        index_sources = pd.read_parquet(\n",
    "            f\"inputs/{index_name}/create_final_text_units.parquet\"\n",
    "        )\n",
    "        print(source[\"text\"][:250])\n",
    "        # Text units are looked up by DataFrame index label, not by a column.\n",
    "        print(index_sources.loc[int(index_id)][\"text\"][:250])\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Basic Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_units = load_index_tables(\"text_units\")\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_basic_search(\n",
    "        parameters, text_units, indexes, False, \"industry in maryland\"\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original text\n",
    "\n",
    "Note that original index name is not saved in context data for basic search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sources are addressed by list position here, since basic-search context data\n",
    "# carries no index name to link back through.\n",
    "for source_id in [0, 1]:\n",
    "    print(results[1][\"sources\"][source_id][\"text\"][:250])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}