mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Community children (#1704)
* Add children to the community tables * Replace NaN children with empty list * Replace subcommunity logic with built-in parent/child fields * Remove restore_community_hierarchy * Add children and frequency to migration notebook * Format * Semver * Add children to reports * Update tests --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
35b639399b
commit
981fd31963
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "major",
|
||||
"description": "Add children to communities to avoid re-compute."
|
||||
}
|
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -25,17 +25,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is the directory that has your settings.yaml\n",
|
||||
"PROJECT_DIRECTORY = \"<your project directory\""
|
||||
"PROJECT_DIRECTORY = \"<your project directory>\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -54,7 +54,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -96,6 +96,30 @@
|
||||
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
|
||||
")\n",
|
||||
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
|
||||
"# we're also persisting the frequency column: one value per entity, the length of its text_unit_ids\n",
|
||||
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].apply(len)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# we added children to communities to eliminate query-time reconstruction\n",
|
||||
"parent_grouped = final_communities.groupby(\"parent\").agg(\n",
|
||||
" children=(\"community\", \"unique\")\n",
|
||||
")\n",
|
||||
"final_communities = final_communities.merge(\n",
|
||||
" parent_grouped,\n",
|
||||
" left_on=\"community\",\n",
|
||||
" right_on=\"parent\",\n",
|
||||
" how=\"left\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# add children to the reports as well\n",
|
||||
"final_community_reports = final_community_reports.merge(\n",
|
||||
" parent_grouped,\n",
|
||||
" left_on=\"community\",\n",
|
||||
" right_on=\"parent\",\n",
|
||||
" how=\"left\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# copy children into the reports as well\n",
|
||||
"\n",
|
||||
"# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
|
||||
"await write_table_to_storage(final_documents, \"documents\", storage)\n",
|
||||
|
@ -4,8 +4,10 @@
|
||||
"""All the steps to transform final communities."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.index.operations.cluster_graph import cluster_graph
|
||||
@ -92,7 +94,21 @@ def create_communities(
|
||||
str
|
||||
)
|
||||
final_communities["parent"] = final_communities["parent"].astype(int)
|
||||
|
||||
# collect the children so we have a tree going both ways
|
||||
parent_grouped = cast(
|
||||
"pd.DataFrame",
|
||||
final_communities.groupby("parent").agg(children=("community", "unique")),
|
||||
)
|
||||
final_communities = final_communities.merge(
|
||||
parent_grouped,
|
||||
left_on="community",
|
||||
right_on="parent",
|
||||
how="left",
|
||||
)
|
||||
# replace NaN children with empty list
|
||||
final_communities["children"] = final_communities["children"].apply(
|
||||
lambda x: x if isinstance(x, np.ndarray) else [] # type: ignore
|
||||
)
|
||||
# add fields for incremental update tracking
|
||||
final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
|
||||
final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)
|
||||
@ -103,8 +119,9 @@ def create_communities(
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"parent",
|
||||
"level",
|
||||
"parent",
|
||||
"children",
|
||||
"title",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
|
@ -62,6 +62,7 @@ async def create_community_reports(
|
||||
|
||||
community_reports = await summarize_communities(
|
||||
nodes,
|
||||
communities,
|
||||
local_contexts,
|
||||
build_level_context,
|
||||
callbacks,
|
||||
|
@ -53,6 +53,7 @@ async def create_community_reports_text(
|
||||
|
||||
community_reports = await summarize_communities(
|
||||
nodes,
|
||||
communities,
|
||||
local_contexts,
|
||||
build_level_context,
|
||||
callbacks,
|
||||
|
@ -13,9 +13,9 @@ def finalize_community_reports(
|
||||
communities: pd.DataFrame,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform final community reports."""
|
||||
# Merge with communities to add size and period
|
||||
# Merge with communities to add shared fields
|
||||
community_reports = reports.merge(
|
||||
communities.loc[:, ["community", "parent", "size", "period"]],
|
||||
communities.loc[:, ["community", "parent", "children", "size", "period"]],
|
||||
on="community",
|
||||
how="left",
|
||||
copy=False,
|
||||
@ -31,8 +31,9 @@ def finalize_community_reports(
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"parent",
|
||||
"level",
|
||||
"parent",
|
||||
"children",
|
||||
"title",
|
||||
"summary",
|
||||
"full_content",
|
||||
|
@ -20,7 +20,6 @@ from graphrag.index.operations.summarize_communities.typing import (
|
||||
)
|
||||
from graphrag.index.operations.summarize_communities.utils import (
|
||||
get_levels,
|
||||
restore_community_hierarchy,
|
||||
)
|
||||
from graphrag.index.run.derive_from_rows import derive_from_rows
|
||||
from graphrag.logger.progress import progress_ticker
|
||||
@ -30,6 +29,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
async def summarize_communities(
|
||||
nodes: pd.DataFrame,
|
||||
communities: pd.DataFrame,
|
||||
local_contexts,
|
||||
level_context_builder: Callable,
|
||||
callbacks: WorkflowCallbacks,
|
||||
@ -49,7 +49,12 @@ async def summarize_communities(
|
||||
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
|
||||
strategy_config["llm"]["max_retries"] = len(nodes)
|
||||
|
||||
community_hierarchy = restore_community_hierarchy(nodes)
|
||||
community_hierarchy = (
|
||||
communities.explode("children")
|
||||
.rename({"children": "sub_community"}, axis=1)
|
||||
.loc[:, ["community", "level", "sub_community"]]
|
||||
).dropna()
|
||||
|
||||
levels = get_levels(nodes)
|
||||
|
||||
level_contexts = []
|
||||
|
@ -3,8 +3,6 @@
|
||||
|
||||
"""A module containing community report generation utilities."""
|
||||
|
||||
from itertools import pairwise
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.model.schemas as schemas
|
||||
@ -17,48 +15,3 @@ def get_levels(
|
||||
levels = df[level_column].dropna().unique()
|
||||
levels = [int(lvl) for lvl in levels if lvl != -1]
|
||||
return sorted(levels, reverse=True)
|
||||
|
||||
|
||||
def restore_community_hierarchy(
    input: pd.DataFrame,
    name_column: str = schemas.TITLE,
    community_column: str = schemas.COMMUNITY_ID,
    level_column: str = schemas.COMMUNITY_LEVEL,
) -> pd.DataFrame:
    """Rebuild community / sub-community links from per-node membership rows.

    The set of ``name_column`` values is collected for every
    (community, level) pair. Then, for each pair of consecutive levels in
    ascending order, a row is emitted whenever the member set of a community
    at the later level is a subset of the member set of a community at the
    earlier level.

    Returns a DataFrame with one row per match holding the community, its
    level, the matching sub-community and that sub-community's member count.
    The frame has no rows when nothing matches or fewer than two levels exist.
    """
    # One row per (community, level) carrying the set of member names.
    grouped_names = input.groupby([community_column, level_column])[name_column]
    members_df = grouped_names.apply(set).reset_index()

    # level -> {community id -> set of member names}
    members_by_level = {}
    for level, level_rows in members_df.groupby(level_column):
        indexed = level_rows.set_index(community_column)
        members_by_level[level] = indexed[name_column].to_dict()

    ordered_levels = sorted(members_by_level)  # type: ignore

    # Compare every community with every community on the next level up the
    # sorted list; subset membership is what marks a sub-community.
    hierarchy_rows = [
        {
            community_column: parent_id,
            schemas.COMMUNITY_LEVEL: upper_level,
            schemas.SUB_COMMUNITY: child_id,
            schemas.SUB_COMMUNITY_SIZE: len(child_members),
        }
        for upper_level, lower_level in pairwise(ordered_levels)
        for parent_id, parent_members in members_by_level[upper_level].items()
        for child_id, child_members in members_by_level[lower_level].items()
        if child_members.issubset(parent_members)
    ]

    return pd.DataFrame(hierarchy_rows)
|
||||
|
@ -13,9 +13,15 @@ from graphrag.model.named import Named
|
||||
class Community(Named):
|
||||
"""A protocol for a community in the system."""
|
||||
|
||||
level: str = ""
|
||||
level: str
|
||||
"""Community level."""
|
||||
|
||||
parent: str
|
||||
"""Community ID of the parent node of this community."""
|
||||
|
||||
children: list[str]
|
||||
"""List of community IDs of the child nodes of this community."""
|
||||
|
||||
entity_ids: list[str] | None = None
|
||||
"""List of entity IDs related to the community (optional)."""
|
||||
|
||||
@ -25,9 +31,6 @@ class Community(Named):
|
||||
covariate_ids: dict[str, list[str]] | None = None
|
||||
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
|
||||
|
||||
sub_community_ids: list[str] | None = None
|
||||
"""List of community IDs of the child nodes of this community (optional)."""
|
||||
|
||||
attributes: dict[str, Any] | None = None
|
||||
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
|
||||
|
||||
@ -48,7 +51,8 @@ class Community(Named):
|
||||
entities_key: str = "entity_ids",
|
||||
relationships_key: str = "relationship_ids",
|
||||
covariates_key: str = "covariate_ids",
|
||||
sub_communities_key: str = "sub_community_ids",
|
||||
parent_key: str = "parent",
|
||||
children_key: str = "children",
|
||||
attributes_key: str = "attributes",
|
||||
size_key: str = "size",
|
||||
period_key: str = "period",
|
||||
@ -57,12 +61,13 @@ class Community(Named):
|
||||
return Community(
|
||||
id=d[id_key],
|
||||
title=d[title_key],
|
||||
short_id=d.get(short_id_key),
|
||||
level=d[level_key],
|
||||
parent=d[parent_key],
|
||||
children=d[children_key],
|
||||
short_id=d.get(short_id_key),
|
||||
entity_ids=d.get(entities_key),
|
||||
relationship_ids=d.get(relationships_key),
|
||||
covariate_ids=d.get(covariates_key),
|
||||
sub_community_ids=d.get(sub_communities_key),
|
||||
attributes=d.get(attributes_key),
|
||||
size=d.get(size_key),
|
||||
period=d.get(period_key),
|
||||
|
@ -29,7 +29,6 @@ CLAIM_DETAILS = "claim_details"
|
||||
|
||||
# COMMUNITY HIERARCHY TABLE SCHEMA
|
||||
SUB_COMMUNITY = "sub_community"
|
||||
SUB_COMMUNITY_SIZE = "sub_community_size"
|
||||
COMMUNITY_LEVEL = "level"
|
||||
|
||||
# COMMUNITY CONTEXT TABLE SCHEMA
|
||||
|
@ -56,23 +56,7 @@ class DynamicCommunitySelection:
|
||||
self.llm_kwargs = llm_kwargs
|
||||
|
||||
self.reports = {report.community_id: report for report in community_reports}
|
||||
# mapping from community to sub communities
|
||||
self.node2children = {
|
||||
community.short_id: (
|
||||
[]
|
||||
if community.sub_community_ids is None
|
||||
else [str(x) for x in community.sub_community_ids]
|
||||
)
|
||||
for community in communities
|
||||
if community.short_id is not None
|
||||
}
|
||||
|
||||
# mapping from community to parent community
|
||||
self.node2parent: dict[str, str] = {
|
||||
sub_community: community
|
||||
for community, sub_communities in self.node2children.items()
|
||||
for sub_community in sub_communities
|
||||
}
|
||||
self.communities = {community.short_id: community for community in communities}
|
||||
|
||||
# mapping from level to communities
|
||||
self.levels: dict[str, list[str]] = {}
|
||||
@ -140,18 +124,18 @@ class DynamicCommunitySelection:
|
||||
relevant_communities.add(community)
|
||||
# find children nodes of the current node and append them to the queue
|
||||
# TODO check why some sub_communities are NOT in report_df
|
||||
if community in self.node2children:
|
||||
for sub_community in self.node2children[community]:
|
||||
if sub_community in self.reports:
|
||||
communities_to_rate.append(sub_community)
|
||||
if community in self.communities:
|
||||
for child in self.communities[community].children:
|
||||
if child in self.reports:
|
||||
communities_to_rate.append(child)
|
||||
else:
|
||||
log.debug(
|
||||
"dynamic community selection: cannot find community %s in reports",
|
||||
sub_community,
|
||||
child,
|
||||
)
|
||||
# remove parent node if the current node is deemed relevant
|
||||
if not self.keep_parent and community in self.node2parent:
|
||||
relevant_communities.discard(self.node2parent[community])
|
||||
if not self.keep_parent and community in self.communities:
|
||||
relevant_communities.discard(self.communities[community].parent)
|
||||
queue = communities_to_rate
|
||||
level += 1
|
||||
if (
|
||||
|
@ -12,9 +12,6 @@ from typing import cast
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.summarize_communities.utils import (
|
||||
restore_community_hierarchy,
|
||||
)
|
||||
from graphrag.model.community import Community
|
||||
from graphrag.model.community_report import CommunityReport
|
||||
from graphrag.model.covariate import Covariate
|
||||
@ -197,27 +194,6 @@ def read_indexer_communities(
|
||||
]
|
||||
nodes_df = nodes_df.loc[nodes_df.community.isin(reports_df.community.unique())]
|
||||
|
||||
# reconstruct the community hierarchy
|
||||
# note that restore_community_hierarchy only return communities with sub communities
|
||||
community_hierarchy = restore_community_hierarchy(input=nodes_df)
|
||||
|
||||
# small datasets can result in hierarchies that are only one deep, so the hierarchy will have no rows
|
||||
if not community_hierarchy.empty:
|
||||
community_hierarchy = (
|
||||
community_hierarchy.groupby(["community"])
|
||||
.agg({"sub_community": list})
|
||||
.reset_index()
|
||||
.rename(columns={"sub_community": "sub_community_ids"})
|
||||
)
|
||||
# add sub community IDs to community DataFrame
|
||||
communities_df = communities_df.merge(
|
||||
community_hierarchy, on="community", how="left"
|
||||
)
|
||||
# replace NaN sub community IDs with empty list
|
||||
communities_df.sub_community_ids = communities_df.sub_community_ids.apply(
|
||||
lambda x: x if isinstance(x, list) else []
|
||||
)
|
||||
|
||||
return read_communities(
|
||||
communities_df,
|
||||
id_col="id",
|
||||
@ -227,7 +203,8 @@ def read_indexer_communities(
|
||||
entities_col=None,
|
||||
relationships_col=None,
|
||||
covariates_col=None,
|
||||
sub_communities_col="sub_community_ids",
|
||||
parent_col="parent",
|
||||
children_col="children",
|
||||
attributes_cols=None,
|
||||
)
|
||||
|
||||
|
@ -12,6 +12,7 @@ from graphrag.model.entity import Entity
|
||||
from graphrag.model.relationship import Relationship
|
||||
from graphrag.model.text_unit import TextUnit
|
||||
from graphrag.query.input.loaders.utils import (
|
||||
to_list,
|
||||
to_optional_dict,
|
||||
to_optional_float,
|
||||
to_optional_int,
|
||||
@ -154,7 +155,8 @@ def read_communities(
|
||||
entities_col: str | None = "entity_ids",
|
||||
relationships_col: str | None = "relationship_ids",
|
||||
covariates_col: str | None = "covariate_ids",
|
||||
sub_communities_col: str | None = "sub_community_ids",
|
||||
parent_col: str | None = "parent",
|
||||
children_col: str | None = "children",
|
||||
attributes_cols: list[str] | None = None,
|
||||
) -> list[Community]:
|
||||
"""Read communities from a dataframe using pre-converted records."""
|
||||
@ -172,7 +174,8 @@ def read_communities(
|
||||
covariate_ids=to_optional_dict(
|
||||
row, covariates_col, key_type=str, value_type=str
|
||||
),
|
||||
sub_community_ids=to_optional_list(row, sub_communities_col),
|
||||
parent=to_str(row, parent_col),
|
||||
children=to_list(row, children_col),
|
||||
attributes=(
|
||||
{col: row.get(col) for col in attributes_cols}
|
||||
if attributes_cols
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -11,6 +11,7 @@ from .util import (
|
||||
compare_outputs,
|
||||
create_test_context,
|
||||
load_test_table,
|
||||
update_document_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -43,6 +44,8 @@ async def test_create_base_text_units_metadata():
|
||||
config.input.metadata = ["title"]
|
||||
config.chunks.prepend_metadata = True
|
||||
|
||||
await update_document_metadata(config.input.metadata, context)
|
||||
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
@ -65,6 +68,8 @@ async def test_create_base_text_units_metadata_included_in_chunk():
|
||||
config.chunks.prepend_metadata = True
|
||||
config.chunks.chunk_size_includes_metadata = True
|
||||
|
||||
await update_document_metadata(config.input.metadata, context)
|
||||
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
|
@ -13,6 +13,7 @@ from .util import (
|
||||
compare_outputs,
|
||||
create_test_context,
|
||||
load_test_table,
|
||||
update_document_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -37,8 +38,6 @@ async def test_create_final_documents():
|
||||
|
||||
|
||||
async def test_create_final_documents_with_metadata_column():
|
||||
expected = load_test_table("documents")
|
||||
|
||||
context = await create_test_context(
|
||||
storage=["text_units"],
|
||||
)
|
||||
@ -46,6 +45,11 @@ async def test_create_final_documents_with_metadata_column():
|
||||
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
||||
config.input.metadata = ["title"]
|
||||
|
||||
# simulate the metadata construction during initial input loading
|
||||
await update_document_metadata(config.input.metadata, context)
|
||||
|
||||
expected = await load_table_from_storage("documents", context.storage)
|
||||
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
@ -54,12 +58,12 @@ async def test_create_final_documents_with_metadata_column():
|
||||
|
||||
actual = await load_table_from_storage("documents", context.storage)
|
||||
|
||||
# we should have dropped "title" and added "attributes"
|
||||
# our test dataframe does not have attributes, so we'll assert without it
|
||||
# our test dataframe does not have metadata, so we'll assert without it
|
||||
# and separately confirm it is in the output
|
||||
compare_outputs(
|
||||
actual, expected, columns=["id", "human_readable_id", "text", "text_unit_ids"]
|
||||
actual, expected, columns=["id", "human_readable_id", "text", "metadata"]
|
||||
)
|
||||
assert len(actual.columns) == 6
|
||||
assert len(actual.columns) == 7
|
||||
assert "title" in actual.columns
|
||||
assert "text_unit_ids" in actual.columns
|
||||
assert "metadata" in actual.columns
|
||||
|
@ -7,7 +7,7 @@ from pandas.testing import assert_series_equal
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.utils.storage import write_table_to_storage
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
@ -43,7 +43,6 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo
|
||||
if storage:
|
||||
for name in storage:
|
||||
table = load_test_table(name)
|
||||
# normal storage interface insists on bytes
|
||||
await write_table_to_storage(table, name, context.storage)
|
||||
|
||||
return context
|
||||
@ -83,3 +82,12 @@ def compare_outputs(
|
||||
print("Actual:")
|
||||
print(actual[column])
|
||||
raise
|
||||
|
||||
|
||||
async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
    """Add a "metadata" column to the stored documents table.

    Loads the "documents" table from ``context.storage``, builds for each row
    a dict of the columns named in ``metadata``, stores it under "metadata",
    and writes the table back under the same name. Intended to provide the
    metadata column for later parsing by the text units and final documents
    workflows.
    """
    documents = await load_table_from_storage("documents", context.storage)

    # Row-wise conversion: each row of the selected columns becomes one dict.
    selected_columns = documents[metadata]
    documents["metadata"] = selected_columns.apply(
        lambda row: row.to_dict(), axis=1
    )

    # write to the runtime context storage only
    await write_table_to_storage(documents, "documents", context.storage)
|
||||
|
Loading…
x
Reference in New Issue
Block a user