From a44788bfad6543ab2fe2ad75ecdc43bbc59084a7 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Mon, 30 Sep 2024 10:46:07 -0700 Subject: [PATCH] Collapse create final community reports (#1227) * Remove extraneous param * Add community report mocking assertions * Collapse primary report generation * Collapse embeddings * Format * Semver * Remove extraneous check * Move option set --- .../patch-20240927182543835770.json | 4 + .../graph/report/create_community_reports.py | 35 +++- .../graph/report/prepare_community_reports.py | 31 ++- .../report/restore_community_hierarchy.py | 21 +- .../strategies/graph_intelligence/defaults.py | 3 +- .../v1/create_final_community_reports.py | 111 ++--------- .../workflows/v1/create_final_documents.py | 1 - .../index/workflows/v1/subflows/__init__.py | 2 + .../create_final_community_reports.py | 188 ++++++++++++++++++ graphrag/llm/mock/mock_completion_llm.py | 5 + tests/fixtures/min-csv/config.json | 2 +- tests/fixtures/text/config.json | 2 +- .../test_create_final_community_reports.py | 86 ++++++++ tests/verbs/util.py | 2 + 14 files changed, 376 insertions(+), 117 deletions(-) create mode 100644 .semversioner/next-release/patch-20240927182543835770.json create mode 100644 graphrag/index/workflows/v1/subflows/create_final_community_reports.py create mode 100644 tests/verbs/test_create_final_community_reports.py diff --git a/.semversioner/next-release/patch-20240927182543835770.json b/.semversioner/next-release/patch-20240927182543835770.json new file mode 100644 index 00000000..90e686a1 --- /dev/null +++ b/.semversioner/next-release/patch-20240927182543835770.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Collapse create-final-community-reports." +} diff --git a/graphrag/index/verbs/graph/report/create_community_reports.py b/graphrag/index/verbs/graph/report/create_community_reports.py index c67d5107..1764362c 100644 --- a/graphrag/index/verbs/graph/report/create_community_reports.py +++ b/graphrag/index/verbs/graph/report/create_community_reports.py @@ -53,14 +53,37 @@ async def create_community_reports( num_threads: int = 4, **_kwargs, ) -> TableContainer: - """Generate entities for each row, and optionally a graph of those entities.""" + """Generate community summaries.""" log.debug("create_community_reports strategy=%s", strategy) local_contexts = cast(pd.DataFrame, input.get_input()) - nodes_ctr = get_required_input_table(input, "nodes") - nodes = cast(pd.DataFrame, nodes_ctr.table) - community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy") - community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table) + nodes = get_required_input_table(input, "nodes").table + community_hierarchy = get_required_input_table(input, "community_hierarchy").table + output = await create_community_reports_df( + local_contexts, + nodes, + community_hierarchy, + callbacks, + cache, + strategy, + async_mode=async_mode, + num_threads=num_threads, + ) + + return TableContainer(table=output) + + +async def create_community_reports_df( + local_contexts, + nodes, + community_hierarchy, + callbacks: VerbCallbacks, + cache: PipelineCache, + strategy: dict, + async_mode: AsyncType = AsyncType.AsyncIO, + num_threads: int = 4, +): + """Generate community summaries.""" levels = get_levels(nodes) reports: list[CommunityReport | None] = [] tick = progress_ticker(callbacks.progress, len(local_contexts)) @@ -99,7 +122,7 @@ async def create_community_reports( ) reports.extend([lr for lr in local_reports if lr is not None]) - return TableContainer(table=pd.DataFrame(reports)) + return pd.DataFrame(reports) async def _generate_report( diff --git a/graphrag/index/verbs/graph/report/prepare_community_reports.py b/graphrag/index/verbs/graph/report/prepare_community_reports.py index 3c9ebd45..a6a3a24f 100644 --- a/graphrag/index/verbs/graph/report/prepare_community_reports.py +++ b/graphrag/index/verbs/graph/report/prepare_community_reports.py @@ -37,25 +37,38 @@ def prepare_community_reports( max_tokens: int = 16_000, **_kwargs, ) -> TableContainer: - """Generate entities for each row, and optionally a graph of those entities.""" + """Prep communities for report generation.""" # Prepare Community Reports - node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) - edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table) - claim_df = get_named_input_table(input, "claims") - if claim_df is not None: - claim_df = cast(pd.DataFrame, claim_df.table) + nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) + edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table) + claims = get_named_input_table(input, "claims") + if claims: + claims = cast(pd.DataFrame, claims.table) - levels = get_levels(node_df, schemas.NODE_LEVEL) + output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens) + + return TableContainer(table=output) + + +def prepare_community_reports_df( + nodes, + edges, + claims, + callbacks: VerbCallbacks, + max_tokens: int = 16_000, +): + """Prep communities for report generation.""" + levels = get_levels(nodes, schemas.NODE_LEVEL) dfs = [] for level in progress_iterable(levels, callbacks.progress, len(levels)): communities_at_level_df = _prepare_reports_at_level( - node_df, edge_df, claim_df, level, max_tokens + nodes, edges, claims, level, max_tokens ) dfs.append(communities_at_level_df) # build initial local context for all communities - return TableContainer(table=pd.concat(dfs)) + return pd.concat(dfs) def _prepare_reports_at_level( diff --git a/graphrag/index/verbs/graph/report/restore_community_hierarchy.py b/graphrag/index/verbs/graph/report/restore_community_hierarchy.py index 437369f0..2bf5fd00 100644 --- a/graphrag/index/verbs/graph/report/restore_community_hierarchy.py +++ b/graphrag/index/verbs/graph/report/restore_community_hierarchy.py @@ -24,8 +24,25 @@ def restore_community_hierarchy( ) -> TableContainer: """Restore the community hierarchy from the node data.""" node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()) + + output = restore_community_hierarchy_df( + node_df, + name_column=name_column, + community_column=community_column, + level_column=level_column, + ) + return TableContainer(table=output) + + +def restore_community_hierarchy_df( + input: pd.DataFrame, + name_column: str = schemas.NODE_NAME, + community_column: str = schemas.NODE_COMMUNITY, + level_column: str = schemas.NODE_LEVEL, +) -> pd.DataFrame: + """Restore the community hierarchy from the node data.""" community_df = ( - node_df.groupby([community_column, level_column]) + input.groupby([community_column, level_column]) .agg({name_column: list}) .reset_index() ) @@ -75,4 +92,4 @@ def restore_community_hierarchy( if entities_found == len(current_entities): break - return TableContainer(table=pd.DataFrame(community_hierarchy)) + return pd.DataFrame(community_hierarchy) diff --git a/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py b/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py index 708d48d2..c184fb8e 100644 --- a/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py +++ b/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py @@ -6,6 +6,7 @@ import json DEFAULT_CHUNK_SIZE = 3000 + MOCK_RESPONSES = [ json.dumps({ "title": "", @@ -18,7 +19,7 @@ MOCK_RESPONSES = [ "explanation": "", + "summary": "", "explanation": " VerbResult: + """All the steps to transform community reports.""" + nodes = _prep_nodes(cast(pd.DataFrame, input.get_input())) + edges = _prep_edges( + cast(pd.DataFrame, get_required_input_table(input, "relationships").table) + ) + + claims = None + if covariates_enabled: + claims = _prep_claims( + cast(pd.DataFrame, get_required_input_table(input, "covariates").table) + ) + + community_hierarchy = restore_community_hierarchy_df(nodes) + + local_contexts = prepare_community_reports_df( + nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000) + ) + + community_reports = await create_community_reports_df( + local_contexts, + nodes, + community_hierarchy, + callbacks, + cache, + strategy, + async_mode=async_mode, + num_threads=num_threads, + ) + + community_reports["id"] = community_reports["community"].apply( + lambda _x: str(uuid4()) + ) + + # Embed full content if not skipped + if not skip_full_content_embedding: + community_reports = await text_embed_df( + community_reports, + callbacks, + cache, + column="full_content", + strategy=full_content_text_embed["strategy"], + to="full_content_embedding", + embedding_name="community_report_full_content", + ) + + # Embed summary if not skipped + if not skip_summary_embedding: + community_reports = await text_embed_df( + community_reports, + callbacks, + cache, + column="summary", + strategy=summary_text_embed["strategy"], + to="summary_embedding", + embedding_name="community_report_summary", + ) + + # Embed title if not skipped + if not skip_title_embedding: + community_reports = await text_embed_df( + community_reports, + callbacks, + cache, + column="title", + strategy=title_text_embed["strategy"], + to="title_embedding", + embedding_name="community_report_title", + ) + + return create_verb_result( + cast( + Table, + community_reports, + ) + ) + + +def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + # merge values of four columns into a map column + input[NODE_DETAILS] = input.apply( + lambda x: { + NODE_ID: x[NODE_ID], + NODE_NAME: x[NODE_NAME], + NODE_DESCRIPTION: x[NODE_DESCRIPTION], + NODE_DEGREE: x[NODE_DEGREE], + }, + axis=1, + ) + return input + + +def _prep_edges(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + input[EDGE_DETAILS] = input.apply( + lambda x: { + EDGE_ID: x[EDGE_ID], + EDGE_SOURCE: x[EDGE_SOURCE], + EDGE_TARGET: x[EDGE_TARGET], + EDGE_DESCRIPTION: x[EDGE_DESCRIPTION], + EDGE_DEGREE: x[EDGE_DEGREE], + }, + axis=1, + ) + return input + + +def _prep_claims(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + input[CLAIM_DETAILS] = input.apply( + lambda x: { + CLAIM_ID: x[CLAIM_ID], + CLAIM_SUBJECT: x[CLAIM_SUBJECT], + CLAIM_TYPE: x[CLAIM_TYPE], + CLAIM_STATUS: x[CLAIM_STATUS], + CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION], + }, + axis=1, + ) + return input diff --git a/graphrag/llm/mock/mock_completion_llm.py b/graphrag/llm/mock/mock_completion_llm.py index 8cb8e950..7eba4ca7 100644 --- a/graphrag/llm/mock/mock_completion_llm.py +++ b/graphrag/llm/mock/mock_completion_llm.py @@ -3,6 +3,7 @@ """LLM Static Response method definition.""" +import json import logging from typing_extensions import Unpack @@ -12,6 +13,7 @@ from graphrag.llm.types import ( CompletionInput, CompletionOutput, LLMInput, + LLMOutput, ) log = logging.getLogger(__name__) @@ -35,3 +37,6 @@ class MockCompletionLLM( **kwargs: Unpack[LLMInput], ) -> CompletionOutput: return self.responses[0] + + async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]): + return LLMOutput(output=self.responses[0], json=json.loads(self.responses[0])) diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 09d6e418..cfcbc5f4 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -93,7 +93,7 @@ "rank_explanation", "findings" ], - "subworkflows": 6, + "subworkflows": 1, "max_runtime": 300 }, "create_final_text_units": { diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 204937f9..d05d08cd 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -110,7 +110,7 @@ "rank_explanation", "findings" ], - "subworkflows": 7, + "subworkflows": 1, "max_runtime": 300 }, "create_final_text_units": { diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py new file mode 100644 index 00000000..2a5004cb --- /dev/null +++ b/tests/verbs/test_create_final_community_reports.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from graphrag.index.workflows.v1.create_final_community_reports import ( + build_steps, + workflow_name, +) + +from .util import ( + compare_outputs, + get_config_for_workflow, + get_workflow_output, + load_expected, + load_input_tables, + remove_disabled_steps, +) + + +async def test_create_final_community_reports(): + input_tables = load_input_tables([ + "workflow:create_final_nodes", + "workflow:create_final_covariates", + "workflow:create_final_relationships", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + # deleting the llm config results in a default mock injection in run_graph_intelligence + del config["create_community_reports"]["strategy"]["llm"] + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert len(actual.columns) == len(expected.columns) + + # only assert a couple of columns that are not mock - most of this table is LLM-generated + compare_outputs(actual, expected, columns=["community", "level"]) + + # assert a handful of mock data items to confirm they get put in the right spot + assert actual["rank"][:1][0] == 2 + assert actual["rank_explanation"][:1][0] == "" + + +async def test_create_final_community_reports_with_embeddings(): + input_tables = load_input_tables([ + "workflow:create_final_nodes", + "workflow:create_final_covariates", + "workflow:create_final_relationships", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + # deleting the llm config results in a default mock injection in run_graph_intelligence + del config["create_community_reports"]["strategy"]["llm"] + + config["skip_full_content_embedding"] = False + config["community_report_full_content_embed"]["strategy"]["type"] = "mock" + config["skip_summary_embedding"] = False + config["community_report_summary_embed"]["strategy"]["type"] = "mock" + config["skip_title_embedding"] = False + config["community_report_title_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert len(actual.columns) == len(expected.columns) + 3 + assert "full_content_embedding" in actual.columns + assert len(actual["full_content_embedding"][:1][0]) == 3 + assert "summary_embedding" in actual.columns + assert len(actual["summary_embedding"][:1][0]) == 3 + assert "title_embedding" in actual.columns + assert len(actual["title_embedding"][:1][0]) == 3 diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 7dd948f8..70b5714d 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -15,6 +15,8 @@ from graphrag.index import ( ) from graphrag.index.run.utils import _create_run_context +pd.set_option("display.max_columns", None) + def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]: """Harvest all the referenced input IDs from the workflow being tested and pass them here."""