Community children (#1704)

* Add children to the community tables

* Replace NaN children with empty list

* Replace subcommunity logic with built-in parent/child fields

* Remove restore_community_hierarchy

* Add children and frequency to migration notebook

* Format

* Semver

* Add children to reports

* Update tests

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Nathan Evans 2025-02-13 17:03:51 -08:00 committed by GitHub
parent 35b639399b
commit 981fd31963
23 changed files with 118 additions and 127 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "major",
"description": "Add children to communities to avoid re-compute."
}

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -25,17 +25,17 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the directory that has your settings.yaml\n",
"PROJECT_DIRECTORY = \"<your project directory\""
"PROJECT_DIRECTORY = \"<your project directory>\""
]
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -96,6 +96,30 @@
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
")\n",
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
"# we're also persistint the frequency column\n",
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].count()\n",
"\n",
"\n",
"# we added children to communities to eliminate query-time reconstruction\n",
"parent_grouped = final_communities.groupby(\"parent\").agg(\n",
" children=(\"community\", \"unique\")\n",
")\n",
"final_communities = final_communities.merge(\n",
" parent_grouped,\n",
" left_on=\"community\",\n",
" right_on=\"parent\",\n",
" how=\"left\",\n",
")\n",
"\n",
"# add children to the reports as well\n",
"final_community_reports = final_community_reports.merge(\n",
" parent_grouped,\n",
" left_on=\"community\",\n",
" right_on=\"parent\",\n",
" how=\"left\",\n",
")\n",
"\n",
"# copy children into the reports as well\n",
"\n",
"# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
"await write_table_to_storage(final_documents, \"documents\", storage)\n",

View File

@@ -4,8 +4,10 @@
"""All the steps to transform final communities."""
from datetime import datetime, timezone
from typing import cast
from uuid import uuid4
import numpy as np
import pandas as pd
from graphrag.index.operations.cluster_graph import cluster_graph
@@ -92,7 +94,21 @@ def create_communities(
str
)
final_communities["parent"] = final_communities["parent"].astype(int)
# collect the children so we have a tree going both ways
parent_grouped = cast(
"pd.DataFrame",
final_communities.groupby("parent").agg(children=("community", "unique")),
)
final_communities = final_communities.merge(
parent_grouped,
left_on="community",
right_on="parent",
how="left",
)
# replace NaN children with empty list
final_communities["children"] = final_communities["children"].apply(
lambda x: x if isinstance(x, np.ndarray) else [] # type: ignore
)
# add fields for incremental update tracking
final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)
@@ -103,8 +119,9 @@ def create_communities(
"id",
"human_readable_id",
"community",
"parent",
"level",
"parent",
"children",
"title",
"entity_ids",
"relationship_ids",

View File

@@ -62,6 +62,7 @@ async def create_community_reports(
community_reports = await summarize_communities(
nodes,
communities,
local_contexts,
build_level_context,
callbacks,

View File

@@ -53,6 +53,7 @@ async def create_community_reports_text(
community_reports = await summarize_communities(
nodes,
communities,
local_contexts,
build_level_context,
callbacks,

View File

@@ -13,9 +13,9 @@ def finalize_community_reports(
communities: pd.DataFrame,
) -> pd.DataFrame:
"""All the steps to transform final community reports."""
# Merge with communities to add size and period
# Merge with communities to add shared fields
community_reports = reports.merge(
communities.loc[:, ["community", "parent", "size", "period"]],
communities.loc[:, ["community", "parent", "children", "size", "period"]],
on="community",
how="left",
copy=False,
@@ -31,8 +31,9 @@ def finalize_community_reports(
"id",
"human_readable_id",
"community",
"parent",
"level",
"parent",
"children",
"title",
"summary",
"full_content",

View File

@@ -20,7 +20,6 @@ from graphrag.index.operations.summarize_communities.typing import (
)
from graphrag.index.operations.summarize_communities.utils import (
get_levels,
restore_community_hierarchy,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
@@ -30,6 +29,7 @@ log = logging.getLogger(__name__)
async def summarize_communities(
nodes: pd.DataFrame,
communities: pd.DataFrame,
local_contexts,
level_context_builder: Callable,
callbacks: WorkflowCallbacks,
@@ -49,7 +49,12 @@ async def summarize_communities(
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(nodes)
community_hierarchy = restore_community_hierarchy(nodes)
community_hierarchy = (
communities.explode("children")
.rename({"children": "sub_community"}, axis=1)
.loc[:, ["community", "level", "sub_community"]]
).dropna()
levels = get_levels(nodes)
level_contexts = []
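With children persisted on the communities table, the hierarchy needed here is a single explode away. A toy sketch of the shape the replacement produces:

import pandas as pd

communities = pd.DataFrame({
    "community": [0, 1, 2],
    "level": [0, 1, 1],
    "children": [[1, 2], [], []],
})

# one (community, level, sub_community) row per child; childless rows
# explode to NaN and are dropped
community_hierarchy = (
    communities.explode("children")
    .rename({"children": "sub_community"}, axis=1)
    .loc[:, ["community", "level", "sub_community"]]
).dropna()
print(community_hierarchy)  # rows (0, 0, 1) and (0, 0, 2)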

View File

@@ -3,8 +3,6 @@
"""A module containing community report generation utilities."""
from itertools import pairwise
import pandas as pd
import graphrag.model.schemas as schemas
@@ -17,48 +15,3 @@ def get_levels(
levels = df[level_column].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
return sorted(levels, reverse=True)
def restore_community_hierarchy(
input: pd.DataFrame,
name_column: str = schemas.TITLE,
community_column: str = schemas.COMMUNITY_ID,
level_column: str = schemas.COMMUNITY_LEVEL,
) -> pd.DataFrame:
"""Restore the community hierarchy from the node data."""
# Group by community and level, aggregate names as lists
community_df = (
input.groupby([community_column, level_column])[name_column]
.apply(set)
.reset_index()
)
# Build dictionary with levels as integers
community_levels = {
level: group.set_index(community_column)[name_column].to_dict()
for level, group in community_df.groupby(level_column)
}
# get unique levels, sorted in ascending order
levels = sorted(community_levels.keys()) # type: ignore
community_hierarchy = []
# Iterate through adjacent levels
for current_level, next_level in pairwise(levels):
current_communities = community_levels[current_level]
next_communities = community_levels[next_level]
# Find sub-communities
for curr_comm, curr_entities in current_communities.items():
for next_comm, next_entities in next_communities.items():
if next_entities.issubset(curr_entities):
community_hierarchy.append({
community_column: curr_comm,
schemas.COMMUNITY_LEVEL: current_level,
schemas.SUB_COMMUNITY: next_comm,
schemas.SUB_COMMUNITY_SIZE: len(next_entities),
})
return pd.DataFrame(
community_hierarchy,
)

View File

@@ -13,9 +13,15 @@ from graphrag.model.named import Named
class Community(Named):
"""A protocol for a community in the system."""
level: str = ""
level: str
"""Community level."""
parent: str
"""Community ID of the parent node of this community."""
children: list[str]
"""List of community IDs of the child nodes of this community."""
entity_ids: list[str] | None = None
"""List of entity IDs related to the community (optional)."""
@@ -25,9 +31,6 @@ class Community(Named):
covariate_ids: dict[str, list[str]] | None = None
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
sub_community_ids: list[str] | None = None
"""List of community IDs of the child nodes of this community (optional)."""
attributes: dict[str, Any] | None = None
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
@@ -48,7 +51,8 @@ class Community(Named):
entities_key: str = "entity_ids",
relationships_key: str = "relationship_ids",
covariates_key: str = "covariate_ids",
sub_communities_key: str = "sub_community_ids",
parent_key: str = "parent",
children_key: str = "children",
attributes_key: str = "attributes",
size_key: str = "size",
period_key: str = "period",
@@ -57,12 +61,13 @@
return Community(
id=d[id_key],
title=d[title_key],
short_id=d.get(short_id_key),
level=d[level_key],
parent=d[parent_key],
children=d[children_key],
short_id=d.get(short_id_key),
entity_ids=d.get(entities_key),
relationship_ids=d.get(relationships_key),
covariate_ids=d.get(covariates_key),
sub_community_ids=d.get(sub_communities_key),
attributes=d.get(attributes_key),
size=d.get(size_key),
period=d.get(period_key),
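A sketch of the new round-trip contract for Community.from_dict: level, parent, and children are now required record keys, while the ID lists stay optional. Key names other than parent and children are assumed defaults (id, title, short_id, level); values are toy data:

from graphrag.model.community import Community

record = {
    "id": "c1-uuid",  # toy values throughout
    "title": "Community 1",
    "short_id": "1",
    "level": "1",
    "parent": "0",
    "children": ["3", "4"],
}
community = Community.from_dict(record)
print(community.parent, community.children)  # 0 ['3', '4']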

View File

@@ -29,7 +29,6 @@ CLAIM_DETAILS = "claim_details"
# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_community"
SUB_COMMUNITY_SIZE = "sub_community_size"
COMMUNITY_LEVEL = "level"
# COMMUNITY CONTEXT TABLE SCHEMA

View File

@@ -56,23 +56,7 @@ class DynamicCommunitySelection:
self.llm_kwargs = llm_kwargs
self.reports = {report.community_id: report for report in community_reports}
# mapping from community to sub communities
self.node2children = {
community.short_id: (
[]
if community.sub_community_ids is None
else [str(x) for x in community.sub_community_ids]
)
for community in communities
if community.short_id is not None
}
# mapping from community to parent community
self.node2parent: dict[str, str] = {
sub_community: community
for community, sub_communities in self.node2children.items()
for sub_community in sub_communities
}
self.communities = {community.short_id: community for community in communities}
# mapping from level to communities
self.levels: dict[str, list[str]] = {}
@@ -140,18 +124,18 @@
relevant_communities.add(community)
# find children nodes of the current node and append them to the queue
# TODO check why some sub_communities are NOT in report_df
if community in self.node2children:
for sub_community in self.node2children[community]:
if sub_community in self.reports:
communities_to_rate.append(sub_community)
if community in self.communities:
for child in self.communities[community].children:
if child in self.reports:
communities_to_rate.append(child)
else:
log.debug(
"dynamic community selection: cannot find community %s in reports",
sub_community,
child,
)
# remove parent node if the current node is deemed relevant
if not self.keep_parent and community in self.node2parent:
relevant_communities.discard(self.node2parent[community])
if not self.keep_parent and community in self.communities:
relevant_communities.discard(self.communities[community].parent)
queue = communities_to_rate
level += 1
if (
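The selection loop above is now a plain breadth-first walk over the stored links. A simplified, self-contained sketch of the expand-children / discard-parent step, with a hypothetical rate_community standing in for the LLM relevance rating:

from collections import deque

# toy community tree: community -> (parent, children)
tree = {
    "0": ("", ["1", "2"]),
    "1": ("0", []),
    "2": ("0", ["3"]),
    "3": ("2", []),
}

def rate_community(community: str) -> bool:
    """Hypothetical stand-in for the LLM relevance rating."""
    return community in {"0", "2", "3"}

relevant: set[str] = set()
queue = deque(["0"])  # start at the root community
while queue:
    community = queue.popleft()
    if rate_community(community):
        relevant.add(community)
        parent, children = tree[community]
        queue.extend(children)    # descend via the stored children field
        relevant.discard(parent)  # the keep_parent=False behavior
print(sorted(relevant))  # ['3']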

View File

@@ -12,9 +12,6 @@ from typing import cast
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.summarize_communities.utils import (
restore_community_hierarchy,
)
from graphrag.model.community import Community
from graphrag.model.community_report import CommunityReport
from graphrag.model.covariate import Covariate
@@ -197,27 +194,6 @@ def read_indexer_communities(
]
nodes_df = nodes_df.loc[nodes_df.community.isin(reports_df.community.unique())]
# reconstruct the community hierarchy
# note that restore_community_hierarchy only returns communities with sub-communities
community_hierarchy = restore_community_hierarchy(input=nodes_df)
# small datasets can result in hierarchies that are only one level deep, so the hierarchy will have no rows
if not community_hierarchy.empty:
community_hierarchy = (
community_hierarchy.groupby(["community"])
.agg({"sub_community": list})
.reset_index()
.rename(columns={"sub_community": "sub_community_ids"})
)
# add sub community IDs to community DataFrame
communities_df = communities_df.merge(
community_hierarchy, on="community", how="left"
)
# replace NaN sub community IDs with empty list
communities_df.sub_community_ids = communities_df.sub_community_ids.apply(
lambda x: x if isinstance(x, list) else []
)
return read_communities(
communities_df,
id_col="id",
@@ -227,7 +203,8 @@
entities_col=None,
relationships_col=None,
covariates_col=None,
sub_communities_col="sub_community_ids",
parent_col="parent",
children_col="children",
attributes_cols=None,
)
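Putting the loader change together, a hedged usage sketch: the communities frame must now carry parent and children columns. The module path and the default column names (id, title, community, level) are assumptions, and size and period are included because the communities table carries them:

import pandas as pd

from graphrag.query.input.loaders.dfs import read_communities  # assumed module path

communities_df = pd.DataFrame({
    "id": ["a1", "a2"],
    "title": ["Community 0", "Community 1"],
    "community": ["0", "1"],
    "level": ["0", "1"],
    "parent": ["-1", "0"],
    "children": [["1"], []],
    "size": [10, 4],
    "period": ["2025-02-13", "2025-02-13"],
})
communities = read_communities(
    communities_df,
    entities_col=None,
    relationships_col=None,
    covariates_col=None,
    parent_col="parent",
    children_col="children",
    attributes_cols=None,
)
print(communities[0].children, communities[1].parent)  # ['1'] 0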

View File

@@ -12,6 +12,7 @@ from graphrag.model.entity import Entity
from graphrag.model.relationship import Relationship
from graphrag.model.text_unit import TextUnit
from graphrag.query.input.loaders.utils import (
to_list,
to_optional_dict,
to_optional_float,
to_optional_int,
@@ -154,7 +155,8 @@ def read_communities(
entities_col: str | None = "entity_ids",
relationships_col: str | None = "relationship_ids",
covariates_col: str | None = "covariate_ids",
sub_communities_col: str | None = "sub_community_ids",
parent_col: str | None = "parent",
children_col: str | None = "children",
attributes_cols: list[str] | None = None,
) -> list[Community]:
"""Read communities from a dataframe using pre-converted records."""
@@ -172,7 +174,8 @@
covariate_ids=to_optional_dict(
row, covariates_col, key_type=str, value_type=str
),
sub_community_ids=to_optional_list(row, sub_communities_col),
parent=to_str(row, parent_col),
children=to_list(row, children_col),
attributes=(
{col: row.get(col) for col in attributes_cols}
if attributes_cols

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -11,6 +11,7 @@ from .util import (
compare_outputs,
create_test_context,
load_test_table,
update_document_metadata,
)
@@ -43,6 +44,8 @@ async def test_create_base_text_units_metadata():
config.input.metadata = ["title"]
config.chunks.prepend_metadata = True
await update_document_metadata(config.input.metadata, context)
await run_workflow(
config,
context,
@@ -65,6 +68,8 @@ async def test_create_base_text_units_metadata_included_in_chunk():
config.chunks.prepend_metadata = True
config.chunks.chunk_size_includes_metadata = True
await update_document_metadata(config.input.metadata, context)
await run_workflow(
config,
context,

View File

@@ -13,6 +13,7 @@ from .util import (
compare_outputs,
create_test_context,
load_test_table,
update_document_metadata,
)
@@ -37,8 +38,6 @@ async def test_create_final_documents():
async def test_create_final_documents_with_metadata_column():
expected = load_test_table("documents")
context = await create_test_context(
storage=["text_units"],
)
@@ -46,6 +45,11 @@ async def test_create_final_documents_with_metadata_column():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
config.input.metadata = ["title"]
# simulate the metadata construction during initial input loading
await update_document_metadata(config.input.metadata, context)
expected = await load_table_from_storage("documents", context.storage)
await run_workflow(
config,
context,
@@ -54,12 +58,12 @@
actual = await load_table_from_storage("documents", context.storage)
# we should have dropped "title" and added "attributes"
# our test dataframe does not have attributes, so we'll assert without it
# our test dataframe does not have metadata, so we'll assert without it
# and separately confirm it is in the output
compare_outputs(
actual, expected, columns=["id", "human_readable_id", "text", "text_unit_ids"]
actual, expected, columns=["id", "human_readable_id", "text", "metadata"]
)
assert len(actual.columns) == 6
assert len(actual.columns) == 7
assert "title" in actual.columns
assert "text_unit_ids" in actual.columns
assert "metadata" in actual.columns

View File

@@ -7,7 +7,7 @@ from pandas.testing import assert_series_equal
import graphrag.config.defaults as defs
from graphrag.index.context import PipelineRunContext
from graphrag.index.run.utils import create_run_context
from graphrag.utils.storage import write_table_to_storage
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
pd.set_option("display.max_columns", None)
@@ -43,7 +43,6 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunContext:
if storage:
for name in storage:
table = load_test_table(name)
# normal storage interface insists on bytes
await write_table_to_storage(table, name, context.storage)
return context
@@ -83,3 +82,12 @@ def compare_outputs(
print("Actual:")
print(actual[column])
raise
async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
"""Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows."""
documents = await load_table_from_storage("documents", context.storage)
documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
await write_table_to_storage(
documents, "documents", context.storage
) # write to the runtime context storage only
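
The heart of the helper is the one-liner that collapses the configured metadata columns into a single dict-valued "metadata" column. A standalone pandas sketch with a toy documents frame:

import pandas as pd

documents = pd.DataFrame({
    "id": ["d1", "d2"],
    "title": ["Doc One", "Doc Two"],
    "text": ["first document text", "second document text"],
})

metadata = ["title"]  # the configured metadata columns
# build one dict per row, keyed by the selected columns
documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
print(documents["metadata"].tolist())
# [{'title': 'Doc One'}, {'title': 'Doc Two'}]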