Collapse create final community reports (#1227)

* Remove extraneous param

* Add community report mocking assertions

* Collapse primary report generation

* Collapse embeddings

* Format

* Semver

* Remove extraneous check

* Move option set
Nathan Evans 2024-09-30 10:46:07 -07:00 committed by Alonso Guevara
parent 336e6f9ca1
commit a44788bfad
14 changed files with 376 additions and 117 deletions

View File

@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Collapse create-final-community-reports."
}

View File

@@ -53,14 +53,37 @@ async def create_community_reports(
     num_threads: int = 4,
     **_kwargs,
 ) -> TableContainer:
-    """Generate entities for each row, and optionally a graph of those entities."""
+    """Generate community summaries."""
     log.debug("create_community_reports strategy=%s", strategy)
     local_contexts = cast(pd.DataFrame, input.get_input())
-    nodes_ctr = get_required_input_table(input, "nodes")
-    nodes = cast(pd.DataFrame, nodes_ctr.table)
-    community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy")
-    community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table)
+    nodes = get_required_input_table(input, "nodes").table
+    community_hierarchy = get_required_input_table(input, "community_hierarchy").table
+
+    output = await create_community_reports_df(
+        local_contexts,
+        nodes,
+        community_hierarchy,
+        callbacks,
+        cache,
+        strategy,
+        async_mode=async_mode,
+        num_threads=num_threads,
+    )
+
+    return TableContainer(table=output)
+
+
+async def create_community_reports_df(
+    local_contexts,
+    nodes,
+    community_hierarchy,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    strategy: dict,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    num_threads: int = 4,
+):
+    """Generate community summaries."""
     levels = get_levels(nodes)
     reports: list[CommunityReport | None] = []
     tick = progress_ticker(callbacks.progress, len(local_contexts))

@@ -99,7 +122,7 @@ async def create_community_reports(
     )
     reports.extend([lr for lr in local_reports if lr is not None])

-    return TableContainer(table=pd.DataFrame(reports))
+    return pd.DataFrame(reports)


 async def _generate_report(
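
A note on the recurring pattern in this change: each verb is split into a thin
datashaper wrapper plus a DataFrame-in/DataFrame-out *_df function, so the
collapsed verb can chain the logic directly without TableContainer plumbing.
A minimal sketch of the shape, using stand-in types rather than the real
datashaper classes:

from dataclasses import dataclass

import pandas as pd


@dataclass
class TableContainer:
    table: pd.DataFrame


def summarize_verb(input: TableContainer) -> TableContainer:
    """Thin verb wrapper: unwrap the container, delegate, re-wrap."""
    return TableContainer(table=summarize_df(input.table))


def summarize_df(df: pd.DataFrame) -> pd.DataFrame:
    """DataFrame-in/DataFrame-out core, callable without verb plumbing."""
    return df.groupby("community", as_index=False).agg(size=("title", "count"))


frame = pd.DataFrame({"community": [0, 0, 1], "title": ["a", "b", "c"]})
print(summarize_df(frame))                          # direct call (collapsed path)
print(summarize_verb(TableContainer(frame)).table)  # legacy verb path

Keeping the core logic container-free is what lets the new
create_final_community_reports verb further down call the restore, prepare,
and create functions in sequence as ordinary Python.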

View File

@@ -37,25 +37,38 @@ def prepare_community_reports(
     max_tokens: int = 16_000,
     **_kwargs,
 ) -> TableContainer:
-    """Generate entities for each row, and optionally a graph of those entities."""
+    """Prep communities for report generation."""
     # Prepare Community Reports
-    node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
-    edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
-    claim_df = get_named_input_table(input, "claims")
-    if claim_df is not None:
-        claim_df = cast(pd.DataFrame, claim_df.table)
+    nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
+    edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
+    claims = get_named_input_table(input, "claims")
+    if claims:
+        claims = cast(pd.DataFrame, claims.table)

-    levels = get_levels(node_df, schemas.NODE_LEVEL)
+    output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens)
+    return TableContainer(table=output)
+
+
+def prepare_community_reports_df(
+    nodes,
+    edges,
+    claims,
+    callbacks: VerbCallbacks,
+    max_tokens: int = 16_000,
+):
+    """Prep communities for report generation."""
+    levels = get_levels(nodes, schemas.NODE_LEVEL)
     dfs = []
     for level in progress_iterable(levels, callbacks.progress, len(levels)):
         communities_at_level_df = _prepare_reports_at_level(
-            node_df, edge_df, claim_df, level, max_tokens
+            nodes, edges, claims, level, max_tokens
         )
         dfs.append(communities_at_level_df)

     # build initial local context for all communities
-    return TableContainer(table=pd.concat(dfs))
+    return pd.concat(dfs)


 def _prepare_reports_at_level(

View File

@@ -24,8 +24,25 @@ def restore_community_hierarchy(
 ) -> TableContainer:
     """Restore the community hierarchy from the node data."""
     node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
+    output = restore_community_hierarchy_df(
+        node_df,
+        name_column=name_column,
+        community_column=community_column,
+        level_column=level_column,
+    )
+    return TableContainer(table=output)
+
+
+def restore_community_hierarchy_df(
+    input: pd.DataFrame,
+    name_column: str = schemas.NODE_NAME,
+    community_column: str = schemas.NODE_COMMUNITY,
+    level_column: str = schemas.NODE_LEVEL,
+) -> pd.DataFrame:
+    """Restore the community hierarchy from the node data."""
     community_df = (
-        node_df.groupby([community_column, level_column])
+        input.groupby([community_column, level_column])
         .agg({name_column: list})
         .reset_index()
     )

@@ -75,4 +92,4 @@ def restore_community_hierarchy(
         if entities_found == len(current_entities):
             break

-    return TableContainer(table=pd.DataFrame(community_hierarchy))
+    return pd.DataFrame(community_hierarchy)
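
For orientation, restore_community_hierarchy_df starts from exactly the
groupby shown above. A toy illustration, where "title", "community", and
"level" are assumed stand-ins for the schema defaults (schemas.NODE_NAME,
schemas.NODE_COMMUNITY, schemas.NODE_LEVEL):

import pandas as pd

# Community 0 exists at level 0; communities 1 and 2 partition it at level 1.
nodes = pd.DataFrame({
    "title": ["A", "B", "C", "A", "B", "C"],
    "community": [0, 0, 0, 1, 1, 2],
    "level": [0, 0, 0, 1, 1, 1],
})

# Collect each (community, level) pair's list of entity names.
community_df = (
    nodes.groupby(["community", "level"])
    .agg({"title": list})
    .reset_index()
)
print(community_df)

The function then walks the levels, matching each community's entity list
against the next level's communities until every entity is accounted for,
which is what the entities_found check above terminates: here communities 1
and 2 would be recorded as sub-communities of community 0.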

View File

@@ -6,6 +6,7 @@
 import json

 DEFAULT_CHUNK_SIZE = 3000
+
 MOCK_RESPONSES = [
     json.dumps({
         "title": "<report_title>",
@@ -18,7 +19,7 @@ MOCK_RESPONSES = [
                 "explanation": "<insight_1_explanation",
             },
             {
-                "summary": "<farts insight_2_summary>",
+                "summary": "<insight_2_summary>",
                 "explanation": "<insight_2_explanation",
             },
         ],

View File

@@ -19,10 +19,6 @@ def build_steps(
     """
     covariates_enabled = config.get("covariates_enabled", False)
     create_community_reports_config = config.get("create_community_reports", {})
-    community_report_strategy = create_community_reports_config.get("strategy", {})
-    community_report_max_input_length = community_report_strategy.get(
-        "max_input_length", 16_000
-    )
     base_text_embed = config.get("text_embed", {})
     community_report_full_content_embed_config = config.get(
         "community_report_full_content_embed", base_text_embed

@@ -36,103 +32,26 @@ def build_steps(
     skip_title_embedding = config.get("skip_title_embedding", False)
     skip_summary_embedding = config.get("skip_summary_embedding", False)
     skip_full_content_embedding = config.get("skip_full_content_embedding", False)
+
+    input = {
+        "source": "workflow:create_final_nodes",
+        "relationships": "workflow:create_final_relationships",
+    }
+    if covariates_enabled:
+        input["covariates"] = "workflow:create_final_covariates"
+
     return [
-        #
-        # Subworkflow: Prepare Nodes
-        #
         {
-            "id": "nodes",
-            "verb": "prepare_community_reports_nodes",
-            "input": {"source": "workflow:create_final_nodes"},
-        },
-        #
-        # Subworkflow: Prepare Edges
-        #
-        {
-            "id": "edges",
-            "verb": "prepare_community_reports_edges",
-            "input": {"source": "workflow:create_final_relationships"},
-        },
-        #
-        # Subworkflow: Prepare Claims Table
-        #
-        {
-            "id": "claims",
-            "enabled": covariates_enabled,
-            "verb": "prepare_community_reports_claims",
-            "input": {
-                "source": "workflow:create_final_covariates",
-            }
-            if covariates_enabled
-            else {},
-        },
-        #
-        # Subworkflow: Get Community Hierarchy
-        #
-        {
-            "id": "community_hierarchy",
-            "verb": "restore_community_hierarchy",
-            "input": {"source": "nodes"},
-        },
-        #
-        # Main Workflow: Create Community Reports
-        #
-        {
-            "id": "local_contexts",
-            "verb": "prepare_community_reports",
-            "args": {"max_tokens": community_report_max_input_length},
-            "input": {
-                "source": "nodes",
-                "nodes": "nodes",
-                "edges": "edges",
-                **({"claims": "claims"} if covariates_enabled else {}),
-            },
-        },
-        {
-            "verb": "create_community_reports",
+            "verb": "create_final_community_reports",
             "args": {
+                "covariates_enabled": covariates_enabled,
+                "skip_full_content_embedding": skip_full_content_embedding,
+                "skip_summary_embedding": skip_summary_embedding,
+                "skip_title_embedding": skip_title_embedding,
+                "full_content_text_embed": community_report_full_content_embed_config,
+                "summary_text_embed": community_report_summary_embed_config,
+                "title_text_embed": community_report_title_embed_config,
                 **create_community_reports_config,
             },
-            "input": {
-                "source": "local_contexts",
-                "community_hierarchy": "community_hierarchy",
-                "nodes": "nodes",
-            },
-        },
-        {
-            # Generate a unique ID for each community report distinct from the community ID
-            "verb": "window",
-            "args": {"to": "id", "operation": "uuid", "column": "community"},
-        },
-        {
-            "verb": "text_embed",
-            "enabled": not skip_full_content_embedding,
-            "args": {
-                "embedding_name": "community_report_full_content",
-                "column": "full_content",
-                "to": "full_content_embedding",
-                **community_report_full_content_embed_config,
-            },
-        },
-        {
-            "verb": "text_embed",
-            "enabled": not skip_summary_embedding,
-            "args": {
-                "embedding_name": "community_report_summary",
-                "column": "summary",
-                "to": "summary_embedding",
-                **community_report_summary_embed_config,
-            },
-        },
-        {
-            "verb": "text_embed",
-            "enabled": not skip_title_embedding,
-            "args": {
-                "embedding_name": "community_report_title",
-                "column": "title",
-                "to": "title_embedding",
-                **community_report_title_embed_config,
-            },
+            "input": input,
         },
     ]
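
For a config with covariates enabled and no embeddings skipped, the steps
list that build_steps returns now reduces to roughly the following single
step (an illustrative rendering of the code above, with the embed configs
abbreviated as {...} placeholders):

steps = [
    {
        "verb": "create_final_community_reports",
        "args": {
            "covariates_enabled": True,
            "skip_full_content_embedding": False,
            "skip_summary_embedding": False,
            "skip_title_embedding": False,
            "full_content_text_embed": {...},  # community_report_full_content_embed
            "summary_text_embed": {...},       # community_report_summary_embed
            "title_text_embed": {...},         # community_report_title_embed
            # ...plus anything under config["create_community_reports"]
        },
        "input": {
            "source": "workflow:create_final_nodes",
            "relationships": "workflow:create_final_relationships",
            "covariates": "workflow:create_final_covariates",
        },
    },
]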

View File

@@ -26,7 +26,6 @@ def build_steps(
         {
             "verb": "create_final_documents",
             "args": {
-                "columns": {"text_units": "text_unit_ids"},
                 "skip_embedding": skip_raw_content_embedding,
                 "text_embed": document_raw_content_embed_config,
             },

View File

@@ -6,6 +6,7 @@
 from .create_base_documents import create_base_documents
 from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
+from .create_final_community_reports import create_final_community_reports
 from .create_final_covariates import create_final_covariates
 from .create_final_documents import create_final_documents
 from .create_final_entities import create_final_entities

@@ -19,6 +20,7 @@ __all__ = [
     "create_base_documents",
     "create_base_text_units",
     "create_final_communities",
+    "create_final_community_reports",
     "create_final_covariates",
     "create_final_documents",
     "create_final_entities",

View File

@@ -0,0 +1,188 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform community reports."""

from typing import cast
from uuid import uuid4

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.community_reports.schemas import (
    CLAIM_DESCRIPTION,
    CLAIM_DETAILS,
    CLAIM_ID,
    CLAIM_STATUS,
    CLAIM_SUBJECT,
    CLAIM_TYPE,
    EDGE_DEGREE,
    EDGE_DESCRIPTION,
    EDGE_DETAILS,
    EDGE_ID,
    EDGE_SOURCE,
    EDGE_TARGET,
    NODE_DEGREE,
    NODE_DESCRIPTION,
    NODE_DETAILS,
    NODE_ID,
    NODE_NAME,
)
from graphrag.index.utils.ds_util import get_required_input_table
from graphrag.index.verbs.graph.report.create_community_reports import (
    create_community_reports_df,
)
from graphrag.index.verbs.graph.report.prepare_community_reports import (
    prepare_community_reports_df,
)
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
    restore_community_hierarchy_df,
)
from graphrag.index.verbs.text.embed.text_embed import text_embed_df


@verb(name="create_final_community_reports", treats_input_tables_as_immutable=True)
async def create_final_community_reports(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    strategy: dict,
    full_content_text_embed: dict,
    summary_text_embed: dict,
    title_text_embed: dict,
    async_mode: AsyncType = AsyncType.AsyncIO,
    num_threads: int = 4,
    skip_full_content_embedding: bool = False,
    skip_summary_embedding: bool = False,
    skip_title_embedding: bool = False,
    covariates_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform community reports."""
    nodes = _prep_nodes(cast(pd.DataFrame, input.get_input()))
    edges = _prep_edges(
        cast(pd.DataFrame, get_required_input_table(input, "relationships").table)
    )

    claims = None
    if covariates_enabled:
        claims = _prep_claims(
            cast(pd.DataFrame, get_required_input_table(input, "covariates").table)
        )

    community_hierarchy = restore_community_hierarchy_df(nodes)

    local_contexts = prepare_community_reports_df(
        nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000)
    )

    community_reports = await create_community_reports_df(
        local_contexts,
        nodes,
        community_hierarchy,
        callbacks,
        cache,
        strategy,
        async_mode=async_mode,
        num_threads=num_threads,
    )

    community_reports["id"] = community_reports["community"].apply(
        lambda _x: str(uuid4())
    )

    # Embed full content if not skipped
    if not skip_full_content_embedding:
        community_reports = await text_embed_df(
            community_reports,
            callbacks,
            cache,
            column="full_content",
            strategy=full_content_text_embed["strategy"],
            to="full_content_embedding",
            embedding_name="community_report_full_content",
        )

    # Embed summary if not skipped
    if not skip_summary_embedding:
        community_reports = await text_embed_df(
            community_reports,
            callbacks,
            cache,
            column="summary",
            strategy=summary_text_embed["strategy"],
            to="summary_embedding",
            embedding_name="community_report_summary",
        )

    # Embed title if not skipped
    if not skip_title_embedding:
        community_reports = await text_embed_df(
            community_reports,
            callbacks,
            cache,
            column="title",
            strategy=title_text_embed["strategy"],
            to="title_embedding",
            embedding_name="community_report_title",
        )

    return create_verb_result(
        cast(
            Table,
            community_reports,
        )
    )


def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    # merge values of four columns into a map column
    input[NODE_DETAILS] = input.apply(
        lambda x: {
            NODE_ID: x[NODE_ID],
            NODE_NAME: x[NODE_NAME],
            NODE_DESCRIPTION: x[NODE_DESCRIPTION],
            NODE_DEGREE: x[NODE_DEGREE],
        },
        axis=1,
    )
    return input


def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    input[EDGE_DETAILS] = input.apply(
        lambda x: {
            EDGE_ID: x[EDGE_ID],
            EDGE_SOURCE: x[EDGE_SOURCE],
            EDGE_TARGET: x[EDGE_TARGET],
            EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
            EDGE_DEGREE: x[EDGE_DEGREE],
        },
        axis=1,
    )
    return input


def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    input[CLAIM_DETAILS] = input.apply(
        lambda x: {
            CLAIM_ID: x[CLAIM_ID],
            CLAIM_SUBJECT: x[CLAIM_SUBJECT],
            CLAIM_TYPE: x[CLAIM_TYPE],
            CLAIM_STATUS: x[CLAIM_STATUS],
            CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
        },
        axis=1,
    )
    return input
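
To make the _prep_nodes shaping concrete, here is a toy run; the literal
column names are assumed stand-ins for the schema constants imported above
(NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE, NODE_DETAILS):

import pandas as pd

nodes = pd.DataFrame({
    "human_readable_id": [1, 2],
    "title": ["ALPHA", "BETA"],
    "description": ["first entity", None],
    "degree": [3, 1],
})

# Fill missing descriptions, then merge four columns into one map column.
nodes = nodes.fillna(value={"description": "No Description"})
nodes["node_details"] = nodes.apply(
    lambda x: {
        "human_readable_id": x["human_readable_id"],
        "title": x["title"],
        "description": x["description"],
        "degree": x["degree"],
    },
    axis=1,
)
# BETA's node_details now carries its id, title, and degree, with the missing
# description filled in as "No Description".

_prep_edges and _prep_claims apply the same fill-and-merge to their own
column sets before report generation consumes the detail maps.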

View File

@@ -3,6 +3,7 @@
 """LLM Static Response method definition."""

+import json
 import logging

 from typing_extensions import Unpack

@@ -12,6 +13,7 @@ from graphrag.llm.types import (
     CompletionInput,
     CompletionOutput,
     LLMInput,
+    LLMOutput,
 )

 log = logging.getLogger(__name__)

@@ -35,3 +37,6 @@ class MockCompletionLLM(
         **kwargs: Unpack[LLMInput],
     ) -> CompletionOutput:
         return self.responses[0]
+
+    async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]):
+        return LLMOutput(output=self.responses[0], json=json.loads(self.responses[0]))
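
The new _invoke_json gives the mock the JSON-mode contract that the collapsed
community-report path exercises: the canned response is returned raw in the
output field and also parsed into the json field. A minimal sketch of that
contract (MockLLM and its constructor are stand-ins, not graphrag's
MockCompletionLLM):

import asyncio
import json
from dataclasses import dataclass, field


@dataclass
class LLMOutput:
    output: str
    json: dict = field(default_factory=dict)


class MockLLM:
    def __init__(self, responses: list[str]):
        self.responses = responses

    async def _invoke_json(self, input: str) -> LLMOutput:
        # Return the first canned response both raw and parsed.
        return LLMOutput(output=self.responses[0], json=json.loads(self.responses[0]))


result = asyncio.run(MockLLM(['{"title": "<report_title>"}'])._invoke_json("prompt"))
assert result.json["title"] == "<report_title>"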

View File

@@ -93,7 +93,7 @@
             "rank_explanation",
             "findings"
         ],
-        "subworkflows": 6,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_final_text_units": {

View File

@@ -110,7 +110,7 @@
             "rank_explanation",
             "findings"
         ],
-        "subworkflows": 7,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_final_text_units": {

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_final_community_reports import (
    build_steps,
    workflow_name,
)

from .util import (
    compare_outputs,
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
    remove_disabled_steps,
)


async def test_create_final_community_reports():
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    # deleting the llm config results in a default mock injection in run_graph_intelligence
    del config["create_community_reports"]["strategy"]["llm"]

    steps = remove_disabled_steps(build_steps(config))

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert len(actual.columns) == len(expected.columns)

    # only assert a couple of columns that are not mock - most of this table is LLM-generated
    compare_outputs(actual, expected, columns=["community", "level"])

    # assert a handful of mock data items to confirm they get put in the right spot
    assert actual["rank"][:1][0] == 2
    assert actual["rank_explanation"][:1][0] == "<rating_explanation>"


async def test_create_final_community_reports_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    # deleting the llm config results in a default mock injection in run_graph_intelligence
    del config["create_community_reports"]["strategy"]["llm"]

    config["skip_full_content_embedding"] = False
    config["community_report_full_content_embed"]["strategy"]["type"] = "mock"
    config["skip_summary_embedding"] = False
    config["community_report_summary_embed"]["strategy"]["type"] = "mock"
    config["skip_title_embedding"] = False
    config["community_report_title_embed"]["strategy"]["type"] = "mock"

    steps = remove_disabled_steps(build_steps(config))

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert len(actual.columns) == len(expected.columns) + 3
    assert "full_content_embedding" in actual.columns
    assert len(actual["full_content_embedding"][:1][0]) == 3
    assert "summary_embedding" in actual.columns
    assert len(actual["summary_embedding"][:1][0]) == 3
    assert "title_embedding" in actual.columns
    assert len(actual["title_embedding"][:1][0]) == 3

View File

@ -15,6 +15,8 @@ from graphrag.index import (
) )
from graphrag.index.run.utils import _create_run_context from graphrag.index.run.utils import _create_run_context
pd.set_option("display.max_columns", None)
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]: def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
"""Harvest all the referenced input IDs from the workflow being tested and pass them here.""" """Harvest all the referenced input IDs from the workflow being tested and pass them here."""