mirror of https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00

* Move covariate run conditional
* All pipeline registration
* Fix method name construction
* Rename context storage -> output_storage
* Rename OutputConfig as generic StorageConfig
* Reuse Storage model under InputConfig
* Move input storage creation out of document loading
* Move document loading into workflows
* Semver
* Fix smoke test config for new workflows
* Fix unit tests
---------
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
152 lines
5.3 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run_workflow method definition."""
from datetime import datetime, timezone
from typing import cast
from uuid import uuid4

import numpy as np
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform final communities."""
    # Load the entity and relationship tables produced by upstream workflows
    # from the pipeline's output storage.
    entities = await load_table_from_storage("entities", context.output_storage)
    relationships = await load_table_from_storage(
        "relationships", context.output_storage
    )

    max_cluster_size = config.cluster_graph.max_cluster_size
    use_lcc = config.cluster_graph.use_lcc
    seed = config.cluster_graph.seed

    output = create_communities(
        entities,
        relationships,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=seed,
    )

    await write_table_to_storage(output, "communities", context.output_storage)

    return WorkflowFunctionOutput(result=output)

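# The "communities" table persisted above is typically read back by later
# workflows in the pipeline (community report generation, for example) from
# the same output storage.
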
def create_communities(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    max_cluster_size: int,
    use_lcc: bool,
    seed: int | None = None,
) -> pd.DataFrame:
    """Transform the entity and relationship tables into the final communities table."""
    # Build a weighted graph from the relationship edge list.
    graph = create_graph(relationships, edge_attr=["weight"])

    clusters = cluster_graph(
        graph,
        max_cluster_size,
        use_lcc,
        seed=seed,
    )
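    # cluster_graph is expected to return rows of (level, community id, parent
    # community id, node titles); e.g. a top-level community might look like
    # (0, 4, -1, ["ENTITY A", "ENTITY B"]). The DataFrame construction below
    # relies on exactly that shape.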

    communities = pd.DataFrame(
        clusters, columns=pd.Index(["level", "community", "parent", "title"])
    ).explode("title")
    communities["community"] = communities["community"].astype(int)
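    # .explode("title") unnests the title lists, so each row above is a single
    # (level, community, parent, title) membership record.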

    # aggregate entity ids for each community
    entity_ids = communities.merge(entities, on="title", how="inner")
    entity_ids = (
        entity_ids.groupby("community").agg(entity_ids=("id", list)).reset_index()
    )
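    # The named aggregation yields one row per community with an "entity_ids"
    # column holding the list of member entity ids.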

    # aggregate relationship ids for each community
    # these are limited to only those where the source and target are in the same community
    max_level = communities["level"].max()
    all_grouped = pd.DataFrame(
        columns=["community", "level", "relationship_ids", "text_unit_ids"]  # type: ignore
    )
    for level in range(max_level + 1):
        communities_at_level = communities.loc[communities["level"] == level]
        sources = relationships.merge(
            communities_at_level, left_on="source", right_on="title", how="inner"
        )
        targets = sources.merge(
            communities_at_level, left_on="target", right_on="title", how="inner"
        )
        matched = targets.loc[targets["community_x"] == targets["community_y"]]
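        # The second merge suffixes the duplicated community/level/parent
        # columns with _x (source side) and _y (target side), so keeping rows
        # where community_x == community_y retains only intra-community edges.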
        text_units = matched.explode("text_unit_ids")
        grouped = (
            text_units.groupby(["community_x", "level_x", "parent_x"])
            .agg(relationship_ids=("id", list), text_unit_ids=("text_unit_ids", list))
            .reset_index()
        )
        grouped.rename(
            columns={
                "community_x": "community",
                "level_x": "level",
                "parent_x": "parent",
            },
            inplace=True,
        )
        all_grouped = pd.concat([
            all_grouped,
            grouped.loc[
                :, ["community", "level", "parent", "relationship_ids", "text_unit_ids"]
            ],
        ])

    # deduplicate the lists
    all_grouped["relationship_ids"] = all_grouped["relationship_ids"].apply(
        lambda x: sorted(set(x))
    )
    all_grouped["text_unit_ids"] = all_grouped["text_unit_ids"].apply(
        lambda x: sorted(set(x))
    )
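    # Exploding text_unit_ids repeats a text unit for every relationship it
    # supports, so the sorted-set pass both deduplicates and keeps the lists
    # deterministic.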

    # join it all up and add some new fields
    final_communities = all_grouped.merge(entity_ids, on="community", how="inner")
    final_communities["id"] = [str(uuid4()) for _ in range(len(final_communities))]
    final_communities["human_readable_id"] = final_communities["community"]
    final_communities["title"] = "Community " + final_communities["community"].astype(
        str
    )
    final_communities["parent"] = final_communities["parent"].astype(int)
    # collect the children so we have a tree going both ways
    parent_grouped = cast(
        "pd.DataFrame",
        final_communities.groupby("parent").agg(children=("community", "unique")),
    )
    final_communities = final_communities.merge(
        parent_grouped,
        left_on="community",
        right_on="parent",
        how="left",
    )
    # replace NaN children with empty list
    final_communities["children"] = final_communities["children"].apply(
        lambda x: x if isinstance(x, np.ndarray) else []  # type: ignore
    )
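    # Leaf communities never appear as a parent, so the left merge leaves
    # their "children" cell as NaN; the isinstance check above swaps in an
    # empty list instead.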
    # add fields for incremental update tracking
    final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
    final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)

    return final_communities.loc[
        :,
        COMMUNITIES_FINAL_COLUMNS,
    ]
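
# Hypothetical usage sketch: a minimal run of create_communities on toy
# tables, assuming a working graphrag install. Column values are illustrative
# only; real inputs come from the pipeline's output storage.
if __name__ == "__main__":
    toy_entities = pd.DataFrame({
        "id": ["e1", "e2", "e3", "e4"],
        "title": ["A", "B", "C", "D"],
    })
    toy_relationships = pd.DataFrame({
        "id": ["r1", "r2", "r3"],
        "source": ["A", "A", "C"],
        "target": ["B", "C", "D"],
        "weight": [1.0, 1.0, 1.0],
        "text_unit_ids": [["t1"], ["t1", "t2"], ["t2"]],
    })
    result = create_communities(
        toy_entities,
        toy_relationships,
        max_cluster_size=10,
        use_lcc=False,
        seed=42,
    )
    print(result[["community", "level", "parent", "title", "size"]])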