Collapse create final community reports (#1227)

* Remove extraneous param

* Add community report mocking assertions

* Collapse primary report generation

* Collapse embeddings

* Format

* Semver

* Remove extraneous check

* Move option set
This commit is contained in:
Nathan Evans 2024-09-30 10:46:07 -07:00 committed by Alonso Guevara
parent 336e6f9ca1
commit a44788bfad
14 changed files with 376 additions and 117 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-final-community-reports."
}

View File

@ -53,14 +53,37 @@ async def create_community_reports(
num_threads: int = 4,
**_kwargs,
) -> TableContainer:
"""Generate entities for each row, and optionally a graph of those entities."""
"""Generate community summaries."""
log.debug("create_community_reports strategy=%s", strategy)
local_contexts = cast(pd.DataFrame, input.get_input())
nodes_ctr = get_required_input_table(input, "nodes")
nodes = cast(pd.DataFrame, nodes_ctr.table)
community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy")
community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table)
nodes = get_required_input_table(input, "nodes").table
community_hierarchy = get_required_input_table(input, "community_hierarchy").table
output = await create_community_reports_df(
local_contexts,
nodes,
community_hierarchy,
callbacks,
cache,
strategy,
async_mode=async_mode,
num_threads=num_threads,
)
return TableContainer(table=output)
async def create_community_reports_df(
local_contexts,
nodes,
community_hierarchy,
callbacks: VerbCallbacks,
cache: PipelineCache,
strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
):
"""Generate community summaries."""
levels = get_levels(nodes)
reports: list[CommunityReport | None] = []
tick = progress_ticker(callbacks.progress, len(local_contexts))
@ -99,7 +122,7 @@ async def create_community_reports(
)
reports.extend([lr for lr in local_reports if lr is not None])
return TableContainer(table=pd.DataFrame(reports))
return pd.DataFrame(reports)
async def _generate_report(

View File

@ -37,25 +37,38 @@ def prepare_community_reports(
max_tokens: int = 16_000,
**_kwargs,
) -> TableContainer:
"""Generate entities for each row, and optionally a graph of those entities."""
"""Prep communities for report generation."""
# Prepare Community Reports
node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
claim_df = get_named_input_table(input, "claims")
if claim_df is not None:
claim_df = cast(pd.DataFrame, claim_df.table)
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
claims = get_named_input_table(input, "claims")
if claims:
claims = cast(pd.DataFrame, claims.table)
levels = get_levels(node_df, schemas.NODE_LEVEL)
output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens)
return TableContainer(table=output)
def prepare_community_reports_df(
    nodes,
    edges,
    claims,
    callbacks: VerbCallbacks,
    max_tokens: int = 16_000,
):
    """Prep communities for report generation.

    Builds the local report context for every community at each level found in
    the node table and returns them concatenated into one dataframe.

    Args:
        nodes: node dataframe (must carry the schemas.NODE_LEVEL column).
        edges: edge dataframe.
        claims: optional claims dataframe (None when covariates are disabled).
        callbacks: progress reporting callbacks.
        max_tokens: token budget per community context.
    """
    levels = get_levels(nodes, schemas.NODE_LEVEL)
    dfs = []
    for level in progress_iterable(levels, callbacks.progress, len(levels)):
        # build initial local context for all communities at this level
        communities_at_level_df = _prepare_reports_at_level(
            nodes, edges, claims, level, max_tokens
        )
        dfs.append(communities_at_level_df)

    return pd.concat(dfs)
def _prepare_reports_at_level(

View File

@ -24,8 +24,25 @@ def restore_community_hierarchy(
) -> TableContainer:
"""Restore the community hierarchy from the node data."""
node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
output = restore_community_hierarchy_df(
node_df,
name_column=name_column,
community_column=community_column,
level_column=level_column,
)
return TableContainer(table=output)
def restore_community_hierarchy_df(
input: pd.DataFrame,
name_column: str = schemas.NODE_NAME,
community_column: str = schemas.NODE_COMMUNITY,
level_column: str = schemas.NODE_LEVEL,
) -> pd.DataFrame:
"""Restore the community hierarchy from the node data."""
community_df = (
node_df.groupby([community_column, level_column])
input.groupby([community_column, level_column])
.agg({name_column: list})
.reset_index()
)
@ -75,4 +92,4 @@ def restore_community_hierarchy(
if entities_found == len(current_entities):
break
return TableContainer(table=pd.DataFrame(community_hierarchy))
return pd.DataFrame(community_hierarchy)

View File

@ -6,6 +6,7 @@
import json
DEFAULT_CHUNK_SIZE = 3000
MOCK_RESPONSES = [
json.dumps({
"title": "<report_title>",
@ -18,7 +19,7 @@ MOCK_RESPONSES = [
"explanation": "<insight_1_explanation",
},
{
"summary": "<farts insight_2_summary>",
"summary": "<insight_2_summary>",
"explanation": "<insight_2_explanation",
},
],

View File

@ -19,10 +19,6 @@ def build_steps(
"""
covariates_enabled = config.get("covariates_enabled", False)
create_community_reports_config = config.get("create_community_reports", {})
community_report_strategy = create_community_reports_config.get("strategy", {})
community_report_max_input_length = community_report_strategy.get(
"max_input_length", 16_000
)
base_text_embed = config.get("text_embed", {})
community_report_full_content_embed_config = config.get(
"community_report_full_content_embed", base_text_embed
@ -36,103 +32,26 @@ def build_steps(
skip_title_embedding = config.get("skip_title_embedding", False)
skip_summary_embedding = config.get("skip_summary_embedding", False)
skip_full_content_embedding = config.get("skip_full_content_embedding", False)
input = {
"source": "workflow:create_final_nodes",
"relationships": "workflow:create_final_relationships",
}
if covariates_enabled:
input["covariates"] = "workflow:create_final_covariates"
return [
#
# Subworkflow: Prepare Nodes
#
{
"id": "nodes",
"verb": "prepare_community_reports_nodes",
"input": {"source": "workflow:create_final_nodes"},
},
#
# Subworkflow: Prepare Edges
#
{
"id": "edges",
"verb": "prepare_community_reports_edges",
"input": {"source": "workflow:create_final_relationships"},
},
#
# Subworkflow: Prepare Claims Table
#
{
"id": "claims",
"enabled": covariates_enabled,
"verb": "prepare_community_reports_claims",
"input": {
"source": "workflow:create_final_covariates",
}
if covariates_enabled
else {},
},
#
# Subworkflow: Get Community Hierarchy
#
{
"id": "community_hierarchy",
"verb": "restore_community_hierarchy",
"input": {"source": "nodes"},
},
#
# Main Workflow: Create Community Reports
#
{
"id": "local_contexts",
"verb": "prepare_community_reports",
"args": {"max_tokens": community_report_max_input_length},
"input": {
"source": "nodes",
"nodes": "nodes",
"edges": "edges",
**({"claims": "claims"} if covariates_enabled else {}),
},
},
{
"verb": "create_community_reports",
"verb": "create_final_community_reports",
"args": {
"covariates_enabled": covariates_enabled,
"skip_full_content_embedding": skip_full_content_embedding,
"skip_summary_embedding": skip_summary_embedding,
"skip_title_embedding": skip_title_embedding,
"full_content_text_embed": community_report_full_content_embed_config,
"summary_text_embed": community_report_summary_embed_config,
"title_text_embed": community_report_title_embed_config,
**create_community_reports_config,
},
"input": {
"source": "local_contexts",
"community_hierarchy": "community_hierarchy",
"nodes": "nodes",
},
},
{
# Generate a unique ID for each community report distinct from the community ID
"verb": "window",
"args": {"to": "id", "operation": "uuid", "column": "community"},
},
{
"verb": "text_embed",
"enabled": not skip_full_content_embedding,
"args": {
"embedding_name": "community_report_full_content",
"column": "full_content",
"to": "full_content_embedding",
**community_report_full_content_embed_config,
},
},
{
"verb": "text_embed",
"enabled": not skip_summary_embedding,
"args": {
"embedding_name": "community_report_summary",
"column": "summary",
"to": "summary_embedding",
**community_report_summary_embed_config,
},
},
{
"verb": "text_embed",
"enabled": not skip_title_embedding,
"args": {
"embedding_name": "community_report_title",
"column": "title",
"to": "title_embedding",
**community_report_title_embed_config,
},
"input": input,
},
]

View File

@ -26,7 +26,6 @@ def build_steps(
{
"verb": "create_final_documents",
"args": {
"columns": {"text_units": "text_unit_ids"},
"skip_embedding": skip_raw_content_embedding,
"text_embed": document_raw_content_embed_config,
},

View File

@ -6,6 +6,7 @@
from .create_base_documents import create_base_documents
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
from .create_final_covariates import create_final_covariates
from .create_final_documents import create_final_documents
from .create_final_entities import create_final_entities
@ -19,6 +20,7 @@ __all__ = [
"create_base_documents",
"create_base_text_units",
"create_final_communities",
"create_final_community_reports",
"create_final_covariates",
"create_final_documents",
"create_final_entities",

View File

@ -0,0 +1,188 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform community reports."""
from typing import cast
from uuid import uuid4
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.community_reports.schemas import (
CLAIM_DESCRIPTION,
CLAIM_DETAILS,
CLAIM_ID,
CLAIM_STATUS,
CLAIM_SUBJECT,
CLAIM_TYPE,
EDGE_DEGREE,
EDGE_DESCRIPTION,
EDGE_DETAILS,
EDGE_ID,
EDGE_SOURCE,
EDGE_TARGET,
NODE_DEGREE,
NODE_DESCRIPTION,
NODE_DETAILS,
NODE_ID,
NODE_NAME,
)
from graphrag.index.utils.ds_util import get_required_input_table
from graphrag.index.verbs.graph.report.create_community_reports import (
create_community_reports_df,
)
from graphrag.index.verbs.graph.report.prepare_community_reports import (
prepare_community_reports_df,
)
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
restore_community_hierarchy_df,
)
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
@verb(name="create_final_community_reports", treats_input_tables_as_immutable=True)
async def create_final_community_reports(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    strategy: dict,
    full_content_text_embed: dict,
    summary_text_embed: dict,
    title_text_embed: dict,
    async_mode: AsyncType = AsyncType.AsyncIO,
    num_threads: int = 4,
    skip_full_content_embedding: bool = False,
    skip_summary_embedding: bool = False,
    skip_title_embedding: bool = False,
    covariates_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform community reports.

    Preps nodes/edges (and claims when covariates are enabled), restores the
    community hierarchy, generates a report per community, assigns each report
    a unique id, and optionally embeds the full content, summary, and title
    columns.

    Args:
        input: verb input; primary table is nodes, plus "relationships" and
            (when covariates_enabled) "covariates" named tables.
        callbacks: progress/error callbacks.
        cache: pipeline LLM cache.
        strategy: report-generation strategy config; "max_input_length"
            (default 16_000) bounds the local context size.
        full_content_text_embed / summary_text_embed / title_text_embed:
            embedding configs, each with a "strategy" entry.
        async_mode: concurrency mode for report generation.
        num_threads: worker count for report generation.
        skip_*_embedding: disable the corresponding embedding pass.
        covariates_enabled: include the claims table in context prep.
    """
    nodes = _prep_nodes(cast(pd.DataFrame, input.get_input()))
    edges = _prep_edges(
        cast(pd.DataFrame, get_required_input_table(input, "relationships").table)
    )

    claims = None
    if covariates_enabled:
        claims = _prep_claims(
            cast(pd.DataFrame, get_required_input_table(input, "covariates").table)
        )

    community_hierarchy = restore_community_hierarchy_df(nodes)

    local_contexts = prepare_community_reports_df(
        nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000)
    )

    community_reports = await create_community_reports_df(
        local_contexts,
        nodes,
        community_hierarchy,
        callbacks,
        cache,
        strategy,
        async_mode=async_mode,
        num_threads=num_threads,
    )

    # Generate a unique ID for each community report, distinct from the community ID
    community_reports["id"] = community_reports["community"].apply(
        lambda _x: str(uuid4())
    )

    # The three embedding passes are identical except for the column and config.
    if not skip_full_content_embedding:
        community_reports = await _embed_report_column(
            community_reports, callbacks, cache, "full_content", full_content_text_embed
        )

    if not skip_summary_embedding:
        community_reports = await _embed_report_column(
            community_reports, callbacks, cache, "summary", summary_text_embed
        )

    if not skip_title_embedding:
        community_reports = await _embed_report_column(
            community_reports, callbacks, cache, "title", title_text_embed
        )

    return create_verb_result(
        cast(
            Table,
            community_reports,
        )
    )


async def _embed_report_column(
    reports: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    column: str,
    embed_config: dict,
) -> pd.DataFrame:
    """Embed one report column into "<column>_embedding" using the given config."""
    return await text_embed_df(
        reports,
        callbacks,
        cache,
        column=column,
        strategy=embed_config["strategy"],
        to=f"{column}_embedding",
        embedding_name=f"community_report_{column}",
    )
def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    """Fill missing node descriptions and pack per-node fields into a details map."""
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})

    def _as_details(row: pd.Series) -> dict:
        # merge values of four columns into a single map column
        return {
            NODE_ID: row[NODE_ID],
            NODE_NAME: row[NODE_NAME],
            NODE_DESCRIPTION: row[NODE_DESCRIPTION],
            NODE_DEGREE: row[NODE_DEGREE],
        }

    input[NODE_DETAILS] = input.apply(_as_details, axis=1)
    return input
def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    """Fill missing edge descriptions and pack per-edge fields into a details map."""
    # Use the edge schema constant here — the original keyed the fillna on
    # NODE_DESCRIPTION, which only works because both constants share the same
    # underlying column name; EDGE_DESCRIPTION states the intent correctly.
    input = input.fillna(value={EDGE_DESCRIPTION: "No Description"})
    # merge values of five columns into a single map column
    input[EDGE_DETAILS] = input.apply(
        lambda x: {
            EDGE_ID: x[EDGE_ID],
            EDGE_SOURCE: x[EDGE_SOURCE],
            EDGE_TARGET: x[EDGE_TARGET],
            EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
            EDGE_DEGREE: x[EDGE_DEGREE],
        },
        axis=1,
    )
    return input
def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
    """Fill missing claim descriptions and pack per-claim fields into a details map."""
    # Use the claim schema constant here — the original keyed the fillna on
    # NODE_DESCRIPTION, which only works because both constants share the same
    # underlying column name; CLAIM_DESCRIPTION states the intent correctly.
    input = input.fillna(value={CLAIM_DESCRIPTION: "No Description"})
    # merge values of five columns into a single map column
    input[CLAIM_DETAILS] = input.apply(
        lambda x: {
            CLAIM_ID: x[CLAIM_ID],
            CLAIM_SUBJECT: x[CLAIM_SUBJECT],
            CLAIM_TYPE: x[CLAIM_TYPE],
            CLAIM_STATUS: x[CLAIM_STATUS],
            CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
        },
        axis=1,
    )
    return input

View File

@ -3,6 +3,7 @@
"""LLM Static Response method definition."""
import json
import logging
from typing_extensions import Unpack
@ -12,6 +13,7 @@ from graphrag.llm.types import (
CompletionInput,
CompletionOutput,
LLMInput,
LLMOutput,
)
log = logging.getLogger(__name__)
@ -35,3 +37,6 @@ class MockCompletionLLM(
**kwargs: Unpack[LLMInput],
) -> CompletionOutput:
return self.responses[0]
async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]):
    """Return the first canned response as both raw output and parsed JSON."""
    canned = self.responses[0]
    return LLMOutput(output=canned, json=json.loads(canned))

View File

@ -93,7 +93,7 @@
"rank_explanation",
"findings"
],
"subworkflows": 6,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_text_units": {

View File

@ -110,7 +110,7 @@
"rank_explanation",
"findings"
],
"subworkflows": 7,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_text_units": {

View File

@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.create_final_community_reports import (
build_steps,
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)
async def test_create_final_community_reports():
    """Run the workflow with the default mock LLM and spot-check the output."""
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    # deleting the llm config results in a default mock injection in run_graph_intelligence
    del config["create_community_reports"]["strategy"]["llm"]

    actual = await get_workflow_output(
        input_tables,
        {"steps": remove_disabled_steps(build_steps(config))},
    )

    assert len(actual.columns) == len(expected.columns)

    # most of this table is LLM-generated, so only compare deterministic columns
    compare_outputs(actual, expected, columns=["community", "level"])

    # spot-check mock data items to confirm they land in the right columns
    assert actual["rank"][:1][0] == 2
    assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
async def test_create_final_community_reports_with_embeddings():
    """Enable all three mock embedding passes and verify the extra columns."""
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    # deleting the llm config results in a default mock injection in run_graph_intelligence
    del config["create_community_reports"]["strategy"]["llm"]

    # turn every embedding on, backed by the mock strategy
    for skip_key, embed_key in [
        ("skip_full_content_embedding", "community_report_full_content_embed"),
        ("skip_summary_embedding", "community_report_summary_embed"),
        ("skip_title_embedding", "community_report_title_embed"),
    ]:
        config[skip_key] = False
        config[embed_key]["strategy"]["type"] = "mock"

    actual = await get_workflow_output(
        input_tables,
        {"steps": remove_disabled_steps(build_steps(config))},
    )

    # one extra embedding column per pass
    assert len(actual.columns) == len(expected.columns) + 3
    for column in ("full_content_embedding", "summary_embedding", "title_embedding"):
        assert column in actual.columns
        assert len(actual[column][:1][0]) == 3

View File

@ -15,6 +15,8 @@ from graphrag.index import (
)
from graphrag.index.run.utils import _create_run_context
pd.set_option("display.max_columns", None)
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
"""Harvest all the referenced input IDs from the workflow being tested and pass them here."""