{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DRIFT Search example\n",
    "\n",
    "Runs a DRIFT search query over the indexer output tables stored under `INPUT_DIR`.\n",
    "\n",
    "The notebook reads `GRAPHRAG_API_KEY`, `GRAPHRAG_LLM_MODEL` and `GRAPHRAG_EMBEDDING_MODEL` from the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import tiktoken\n",
    "\n",
    "from graphrag.config.enums import ModelType\n",
    "from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
    "from graphrag.config.models.language_model_config import LanguageModelConfig\n",
    "from graphrag.language_model.manager import ModelManager\n",
    "from graphrag.query.indexer_adapters import (\n",
    "    read_indexer_entities,\n",
    "    read_indexer_relationships,\n",
    "    read_indexer_report_embeddings,\n",
    "    read_indexer_reports,\n",
    "    read_indexer_text_units,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.drift_context import (\n",
    "    DRIFTSearchContextBuilder,\n",
    ")\n",
    "from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
    "from graphrag.vector_stores.lancedb import LanceDBVectorStore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# directory holding the parquet tables and the LanceDB database\n",
    "INPUT_DIR = \"./inputs/operation dulce\"\n",
    "LANCEDB_URI = f\"{INPUT_DIR}/lancedb\"\n",
    "\n",
    "# parquet table names (file name without the .parquet suffix)\n",
    "COMMUNITY_REPORT_TABLE = \"community_reports\"\n",
    "COMMUNITY_TABLE = \"communities\"\n",
    "ENTITY_TABLE = \"entities\"\n",
    "RELATIONSHIP_TABLE = \"relationships\"\n",
    "COVARIATE_TABLE = \"covariates\"\n",
    "TEXT_UNIT_TABLE = \"text_units\"\n",
    "\n",
    "# community level passed to the entity and report readers\n",
    "COMMUNITY_LEVEL = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load tables\n",
    "\n",
    "### Entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the entity and community tables\n",
    "entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
    "community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
    "\n",
    "print(f\"Entity df columns: {entity_df.columns}\")\n",
    "\n",
    "entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)\n",
    "\n",
    "print(f\"Entity count: {len(entity_df)}\")\n",
    "entity_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding stores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# connect to the LanceDB collections stored under LANCEDB_URI:\n",
    "# one for entity descriptions, one for community report full content\n",
    "description_embedding_store = LanceDBVectorStore(\n",
    "    collection_name=\"default-entity-description\",\n",
    ")\n",
    "description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
    "\n",
    "full_content_embedding_store = LanceDBVectorStore(\n",
    "    collection_name=\"default-community-full_content\",\n",
    ")\n",
    "full_content_embedding_store.connect(db_uri=LANCEDB_URI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Relationships"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relationship_df = pd.read_parquet(f\"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet\")\n",
    "relationships = read_indexer_relationships(relationship_df)\n",
    "\n",
    "print(f\"Relationship count: {len(relationship_df)}\")\n",
    "relationship_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text units"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_unit_df = pd.read_parquet(f\"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet\")\n",
    "text_units = read_indexer_text_units(text_unit_df)\n",
    "\n",
    "print(f\"Text unit records: {len(text_unit_df)}\")\n",
    "text_unit_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language and embedding models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# credentials and model names come from the environment; a missing variable\n",
    "# raises KeyError here rather than failing later inside a request\n",
    "api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
    "llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
    "embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
    "\n",
    "chat_config = LanguageModelConfig(\n",
    "    api_key=api_key,\n",
    "    type=ModelType.OpenAIChat,\n",
    "    model=llm_model,\n",
    "    max_retries=20,\n",
    ")\n",
    "chat_model = ModelManager().get_or_create_chat_model(\n",
    "    name=\"local_search\",\n",
    "    model_type=ModelType.OpenAIChat,\n",
    "    config=chat_config,\n",
    ")\n",
    "\n",
    "try:\n",
    "    token_encoder = tiktoken.encoding_for_model(llm_model)\n",
    "except KeyError:\n",
    "    # tiktoken cannot map this model name to an encoding (e.g. a custom\n",
    "    # deployment name); fall back to cl100k_base so the notebook still runs.\n",
    "    # Token counts may then be approximate for the actual model.\n",
    "    token_encoder = tiktoken.get_encoding(\"cl100k_base\")\n",
    "\n",
    "embedding_config = LanguageModelConfig(\n",
    "    api_key=api_key,\n",
    "    type=ModelType.OpenAIEmbedding,\n",
    "    model=embedding_model,\n",
    "    max_retries=20,\n",
    ")\n",
    "\n",
    "text_embedder = ModelManager().get_or_create_embedding_model(\n",
    "    name=\"local_search_embedding\",\n",
    "    model_type=ModelType.OpenAIEmbedding,\n",
    "    config=embedding_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Community reports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_community_reports(\n",
    "    input_dir: str,\n",
    "    community_report_table: str = COMMUNITY_REPORT_TABLE,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Read the community report table from a parquet file.\n",
    "\n",
    "    Args:\n",
    "        input_dir: Directory containing the parquet file.\n",
    "        community_report_table: Table name; the file read is\n",
    "            ``<input_dir>/<community_report_table>.parquet``.\n",
    "\n",
    "    Returns:\n",
    "        The table contents as a DataFrame.\n",
    "    \"\"\"\n",
    "    input_path = Path(input_dir) / f\"{community_report_table}.parquet\"\n",
    "    return pd.read_parquet(input_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "report_df = read_community_reports(INPUT_DIR)\n",
    "reports = read_indexer_reports(\n",
    "    report_df,\n",
    "    community_df,\n",
    "    COMMUNITY_LEVEL,\n",
    "    content_embedding_col=\"full_content_embeddings\",\n",
    ")\n",
    "# NOTE(review): presumably fills in each report's embedding from the\n",
    "# full-content vector store - confirm against graphrag.query.indexer_adapters\n",
    "read_indexer_report_embeddings(reports, full_content_embedding_store)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the DRIFT search engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "drift_params = DRIFTSearchConfig(\n",
    "    temperature=0,\n",
    "    max_tokens=12_000,\n",
    "    primer_folds=1,\n",
    "    drift_k_followups=3,\n",
    "    n_depth=3,\n",
    "    n=1,\n",
    ")\n",
    "\n",
    "context_builder = DRIFTSearchContextBuilder(\n",
    "    model=chat_model,\n",
    "    text_embedder=text_embedder,\n",
    "    entities=entities,\n",
    "    relationships=relationships,\n",
    "    reports=reports,\n",
    "    entity_text_embeddings=description_embedding_store,\n",
    "    text_units=text_units,\n",
    "    token_encoder=token_encoder,\n",
    "    config=drift_params,\n",
    ")\n",
    "\n",
    "search = DRIFTSearch(\n",
    "    model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run a query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp = await search.search(\"Who is agent Mercer?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resp.response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(resp.context_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}