{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import tiktoken\n",
    "\n",
    "from graphrag.config.enums import ModelType\n",
    "from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
    "from graphrag.config.models.language_model_config import LanguageModelConfig\n",
    "from graphrag.language_model.manager import ModelManager\n",
    "from graphrag.query.indexer_adapters import (\n",
    "    read_indexer_entities,\n",
    "    read_indexer_relationships,\n",
    "    read_indexer_report_embeddings,\n",
    "    read_indexer_reports,\n",
    "    read_indexer_text_units,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.drift_context import (\n",
    "    DRIFTSearchContextBuilder,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
    "from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
    "\n",
    "INPUT_DIR = \"./inputs/operation dulce\"\n",
    "LANCEDB_URI = f\"{INPUT_DIR}/lancedb\"\n",
    "\n",
    "COMMUNITY_REPORT_TABLE = \"community_reports\"\n",
    "COMMUNITY_TABLE = \"communities\"\n",
    "ENTITY_TABLE = \"entities\"\n",
    "RELATIONSHIP_TABLE = \"relationships\"\n",
    "COVARIATE_TABLE = \"covariates\"\n",
    "TEXT_UNIT_TABLE = \"text_units\"\n",
    "COMMUNITY_LEVEL = 2\n",
    "\n",
    "\n",
    "# read the entities and communities tables to get community and degree data\n",
    "entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
    "community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
    "\n",
    "print(f\"Entity df columns: {entity_df.columns}\")\n",
    "\n",
    "entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)\n",
    "\n",
    "# connect to the lancedb vector store that holds the description embeddings;\n",
    "# to connect to a remote db, specify url and port values.\n",
    "description_embedding_store = LanceDBVectorStore(\n",
    "    collection_name=\"default-entity-description\",\n",
    ")\n",
    "description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
    "\n",
    "full_content_embedding_store = LanceDBVectorStore(\n",
    "    collection_name=\"default-community-full_content\",\n",
    ")\n",
    "full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
    "\n",
    "print(f\"Entity count: {len(entity_df)}\")\n",
    "entity_df.head()\n",
    "\n",
    "relationship_df = pd.read_parquet(f\"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet\")\n",
    "relationships = read_indexer_relationships(relationship_df)\n",
    "\n",
    "print(f\"Relationship count: {len(relationship_df)}\")\n",
    "relationship_df.head()\n",
    "\n",
    "text_unit_df = pd.read_parquet(f\"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet\")\n",
    "text_units = read_indexer_text_units(text_unit_df)\n",
    "\n",
    "print(f\"Text unit records: {len(text_unit_df)}\")\n",
    "text_unit_df.head()"
   ]
  },
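  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the original example): list which of the\n",
    "# expected index output tables are present under INPUT_DIR. Covariates are\n",
    "# optional and are not used by DRIFT search in this notebook.\n",
    "for table_name in [\n",
    "    COMMUNITY_REPORT_TABLE,\n",
    "    COMMUNITY_TABLE,\n",
    "    ENTITY_TABLE,\n",
    "    RELATIONSHIP_TABLE,\n",
    "    COVARIATE_TABLE,\n",
    "    TEXT_UNIT_TABLE,\n",
    "]:\n",
    "    table_path = Path(INPUT_DIR) / f\"{table_name}.parquet\"\n",
    "    print(f\"{table_name}: {'found' if table_path.exists() else 'missing'}\")"
   ]
  },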
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
    "llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
    "embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
    "\n",
    "chat_config = LanguageModelConfig(\n",
    "    api_key=api_key,\n",
    "    type=ModelType.OpenAIChat,\n",
    "    model=llm_model,\n",
    "    max_retries=20,\n",
    ")\n",
    "chat_model = ModelManager().get_or_create_chat_model(\n",
    "    name=\"drift_search\",\n",
    "    model_type=ModelType.OpenAIChat,\n",
    "    config=chat_config,\n",
    ")\n",
    "\n",
    "token_encoder = tiktoken.encoding_for_model(llm_model)\n",
    "\n",
    "embedding_config = LanguageModelConfig(\n",
    "    api_key=api_key,\n",
    "    type=ModelType.OpenAIEmbedding,\n",
    "    model=embedding_model,\n",
    "    max_retries=20,\n",
    ")\n",
    "\n",
    "text_embedder = ModelManager().get_or_create_embedding_model(\n",
    "    name=\"drift_search_embedding\",\n",
    "    model_type=ModelType.OpenAIEmbedding,\n",
    "    config=embedding_config,\n",
    ")"
   ]
  },
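  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note (assumption, not in the original example): tiktoken only recognizes\n",
    "# OpenAI model names. If GRAPHRAG_LLM_MODEL is set to a deployment alias or a\n",
    "# non-OpenAI model, encoding_for_model raises a KeyError; a fixed encoding is\n",
    "# a reasonable fallback for token counting.\n",
    "try:\n",
    "    token_encoder = tiktoken.encoding_for_model(llm_model)\n",
    "except KeyError:\n",
    "    token_encoder = tiktoken.get_encoding(\"cl100k_base\")"
   ]
  },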
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_community_reports(\n",
    "    input_dir: str,\n",
    "    community_report_table: str = COMMUNITY_REPORT_TABLE,\n",
    "):\n",
    "    \"\"\"Read the community reports parquet table from the index output directory.\"\"\"\n",
    "    input_path = Path(input_dir) / f\"{community_report_table}.parquet\"\n",
    "    return pd.read_parquet(input_path)\n",
    "\n",
    "\n",
    "report_df = read_community_reports(INPUT_DIR)\n",
    "reports = read_indexer_reports(\n",
    "    report_df,\n",
    "    community_df,\n",
    "    COMMUNITY_LEVEL,\n",
    "    content_embedding_col=\"full_content_embeddings\",\n",
    ")\n",
    "read_indexer_report_embeddings(reports, full_content_embedding_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "drift_params = DRIFTSearchConfig(\n",
    "    temperature=0,\n",
    "    max_tokens=12_000,\n",
    "    primer_folds=1,\n",
    "    drift_k_followups=3,\n",
    "    n_depth=3,\n",
    "    n=1,\n",
    ")\n",
    "\n",
    "context_builder = DRIFTSearchContextBuilder(\n",
    "    model=chat_model,\n",
    "    text_embedder=text_embedder,\n",
    "    entities=entities,\n",
    "    relationships=relationships,\n",
    "    reports=reports,\n",
    "    entity_text_embeddings=description_embedding_store,\n",
    "    text_units=text_units,\n",
    "    token_encoder=token_encoder,\n",
    "    config=drift_params,\n",
    ")\n",
    "\n",
    "search = DRIFTSearch(\n",
    "    model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
    ")"
   ]
  },
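  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (assumption): like the other graphrag config models,\n",
    "# DRIFTSearchConfig is a Pydantic model, so the effective settings (defaults\n",
    "# plus the overrides above) can be inspected with model_dump().\n",
    "print(drift_params.model_dump())"
   ]
  },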
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp = await search.search(\"Who is agent Mercer?\")"
   ]
  },
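  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: `await` works at the top level above because Jupyter already runs an\n",
    "# event loop. In a plain Python script, drive the coroutine with asyncio\n",
    "# instead (sketch only, left commented out so it does not clash with the\n",
    "# notebook's running loop):\n",
    "#\n",
    "# import asyncio\n",
    "#\n",
    "# resp = asyncio.run(search.search(\"Who is agent Mercer?\"))"
   ]
  },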
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp.response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(resp.context_data)"
   ]
  },
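  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (assumption): the exact layout of context_data varies between\n",
    "# search modes, so inspect it defensively -- report DataFrame shapes where\n",
    "# present and fall back to the value's type otherwise.\n",
    "if isinstance(resp.context_data, dict):\n",
    "    for key, value in resp.context_data.items():\n",
    "        shape = getattr(value, \"shape\", None)\n",
    "        print(key, shape if shape is not None else type(value))\n",
    "else:\n",
    "    print(type(resp.context_data))"
   ]
  }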
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}