mirror of
https://github.com/microsoft/graphrag.git
synced 2025-11-09 14:24:07 +00:00
Collapse create final community reports (#1227)
* Remove extraneous param * Add community report mocking assertions * Collapse primary report generation * Collapse embeddings * Format * Semver * Remove extraneous check * Move option set
This commit is contained in:
parent
336e6f9ca1
commit
a44788bfad
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "Collapse create-final-community-reports."
|
||||||
|
}
|
||||||
@ -53,14 +53,37 @@ async def create_community_reports(
|
|||||||
num_threads: int = 4,
|
num_threads: int = 4,
|
||||||
**_kwargs,
|
**_kwargs,
|
||||||
) -> TableContainer:
|
) -> TableContainer:
|
||||||
"""Generate entities for each row, and optionally a graph of those entities."""
|
"""Generate community summaries."""
|
||||||
log.debug("create_community_reports strategy=%s", strategy)
|
log.debug("create_community_reports strategy=%s", strategy)
|
||||||
local_contexts = cast(pd.DataFrame, input.get_input())
|
local_contexts = cast(pd.DataFrame, input.get_input())
|
||||||
nodes_ctr = get_required_input_table(input, "nodes")
|
nodes = get_required_input_table(input, "nodes").table
|
||||||
nodes = cast(pd.DataFrame, nodes_ctr.table)
|
community_hierarchy = get_required_input_table(input, "community_hierarchy").table
|
||||||
community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy")
|
|
||||||
community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table)
|
|
||||||
|
|
||||||
|
output = await create_community_reports_df(
|
||||||
|
local_contexts,
|
||||||
|
nodes,
|
||||||
|
community_hierarchy,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
strategy,
|
||||||
|
async_mode=async_mode,
|
||||||
|
num_threads=num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TableContainer(table=output)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_community_reports_df(
|
||||||
|
local_contexts,
|
||||||
|
nodes,
|
||||||
|
community_hierarchy,
|
||||||
|
callbacks: VerbCallbacks,
|
||||||
|
cache: PipelineCache,
|
||||||
|
strategy: dict,
|
||||||
|
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||||
|
num_threads: int = 4,
|
||||||
|
):
|
||||||
|
"""Generate community summaries."""
|
||||||
levels = get_levels(nodes)
|
levels = get_levels(nodes)
|
||||||
reports: list[CommunityReport | None] = []
|
reports: list[CommunityReport | None] = []
|
||||||
tick = progress_ticker(callbacks.progress, len(local_contexts))
|
tick = progress_ticker(callbacks.progress, len(local_contexts))
|
||||||
@ -99,7 +122,7 @@ async def create_community_reports(
|
|||||||
)
|
)
|
||||||
reports.extend([lr for lr in local_reports if lr is not None])
|
reports.extend([lr for lr in local_reports if lr is not None])
|
||||||
|
|
||||||
return TableContainer(table=pd.DataFrame(reports))
|
return pd.DataFrame(reports)
|
||||||
|
|
||||||
|
|
||||||
async def _generate_report(
|
async def _generate_report(
|
||||||
|
|||||||
@ -37,25 +37,38 @@ def prepare_community_reports(
|
|||||||
max_tokens: int = 16_000,
|
max_tokens: int = 16_000,
|
||||||
**_kwargs,
|
**_kwargs,
|
||||||
) -> TableContainer:
|
) -> TableContainer:
|
||||||
"""Generate entities for each row, and optionally a graph of those entities."""
|
"""Prep communities for report generation."""
|
||||||
# Prepare Community Reports
|
# Prepare Community Reports
|
||||||
node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||||
edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
|
edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
|
||||||
claim_df = get_named_input_table(input, "claims")
|
claims = get_named_input_table(input, "claims")
|
||||||
if claim_df is not None:
|
if claims:
|
||||||
claim_df = cast(pd.DataFrame, claim_df.table)
|
claims = cast(pd.DataFrame, claims.table)
|
||||||
|
|
||||||
levels = get_levels(node_df, schemas.NODE_LEVEL)
|
output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens)
|
||||||
|
|
||||||
|
return TableContainer(table=output)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_community_reports_df(
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
claims,
|
||||||
|
callbacks: VerbCallbacks,
|
||||||
|
max_tokens: int = 16_000,
|
||||||
|
):
|
||||||
|
"""Prep communities for report generation."""
|
||||||
|
levels = get_levels(nodes, schemas.NODE_LEVEL)
|
||||||
dfs = []
|
dfs = []
|
||||||
|
|
||||||
for level in progress_iterable(levels, callbacks.progress, len(levels)):
|
for level in progress_iterable(levels, callbacks.progress, len(levels)):
|
||||||
communities_at_level_df = _prepare_reports_at_level(
|
communities_at_level_df = _prepare_reports_at_level(
|
||||||
node_df, edge_df, claim_df, level, max_tokens
|
nodes, edges, claims, level, max_tokens
|
||||||
)
|
)
|
||||||
dfs.append(communities_at_level_df)
|
dfs.append(communities_at_level_df)
|
||||||
|
|
||||||
# build initial local context for all communities
|
# build initial local context for all communities
|
||||||
return TableContainer(table=pd.concat(dfs))
|
return pd.concat(dfs)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_reports_at_level(
|
def _prepare_reports_at_level(
|
||||||
|
|||||||
@ -24,8 +24,25 @@ def restore_community_hierarchy(
|
|||||||
) -> TableContainer:
|
) -> TableContainer:
|
||||||
"""Restore the community hierarchy from the node data."""
|
"""Restore the community hierarchy from the node data."""
|
||||||
node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
|
node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
|
||||||
|
|
||||||
|
output = restore_community_hierarchy_df(
|
||||||
|
node_df,
|
||||||
|
name_column=name_column,
|
||||||
|
community_column=community_column,
|
||||||
|
level_column=level_column,
|
||||||
|
)
|
||||||
|
return TableContainer(table=output)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_community_hierarchy_df(
|
||||||
|
input: pd.DataFrame,
|
||||||
|
name_column: str = schemas.NODE_NAME,
|
||||||
|
community_column: str = schemas.NODE_COMMUNITY,
|
||||||
|
level_column: str = schemas.NODE_LEVEL,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Restore the community hierarchy from the node data."""
|
||||||
community_df = (
|
community_df = (
|
||||||
node_df.groupby([community_column, level_column])
|
input.groupby([community_column, level_column])
|
||||||
.agg({name_column: list})
|
.agg({name_column: list})
|
||||||
.reset_index()
|
.reset_index()
|
||||||
)
|
)
|
||||||
@ -75,4 +92,4 @@ def restore_community_hierarchy(
|
|||||||
if entities_found == len(current_entities):
|
if entities_found == len(current_entities):
|
||||||
break
|
break
|
||||||
|
|
||||||
return TableContainer(table=pd.DataFrame(community_hierarchy))
|
return pd.DataFrame(community_hierarchy)
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
DEFAULT_CHUNK_SIZE = 3000
|
DEFAULT_CHUNK_SIZE = 3000
|
||||||
|
|
||||||
MOCK_RESPONSES = [
|
MOCK_RESPONSES = [
|
||||||
json.dumps({
|
json.dumps({
|
||||||
"title": "<report_title>",
|
"title": "<report_title>",
|
||||||
@ -18,7 +19,7 @@ MOCK_RESPONSES = [
|
|||||||
"explanation": "<insight_1_explanation",
|
"explanation": "<insight_1_explanation",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"summary": "<farts insight_2_summary>",
|
"summary": "<insight_2_summary>",
|
||||||
"explanation": "<insight_2_explanation",
|
"explanation": "<insight_2_explanation",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|||||||
@ -19,10 +19,6 @@ def build_steps(
|
|||||||
"""
|
"""
|
||||||
covariates_enabled = config.get("covariates_enabled", False)
|
covariates_enabled = config.get("covariates_enabled", False)
|
||||||
create_community_reports_config = config.get("create_community_reports", {})
|
create_community_reports_config = config.get("create_community_reports", {})
|
||||||
community_report_strategy = create_community_reports_config.get("strategy", {})
|
|
||||||
community_report_max_input_length = community_report_strategy.get(
|
|
||||||
"max_input_length", 16_000
|
|
||||||
)
|
|
||||||
base_text_embed = config.get("text_embed", {})
|
base_text_embed = config.get("text_embed", {})
|
||||||
community_report_full_content_embed_config = config.get(
|
community_report_full_content_embed_config = config.get(
|
||||||
"community_report_full_content_embed", base_text_embed
|
"community_report_full_content_embed", base_text_embed
|
||||||
@ -36,103 +32,26 @@ def build_steps(
|
|||||||
skip_title_embedding = config.get("skip_title_embedding", False)
|
skip_title_embedding = config.get("skip_title_embedding", False)
|
||||||
skip_summary_embedding = config.get("skip_summary_embedding", False)
|
skip_summary_embedding = config.get("skip_summary_embedding", False)
|
||||||
skip_full_content_embedding = config.get("skip_full_content_embedding", False)
|
skip_full_content_embedding = config.get("skip_full_content_embedding", False)
|
||||||
|
input = {
|
||||||
|
"source": "workflow:create_final_nodes",
|
||||||
|
"relationships": "workflow:create_final_relationships",
|
||||||
|
}
|
||||||
|
if covariates_enabled:
|
||||||
|
input["covariates"] = "workflow:create_final_covariates"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
#
|
|
||||||
# Subworkflow: Prepare Nodes
|
|
||||||
#
|
|
||||||
{
|
{
|
||||||
"id": "nodes",
|
"verb": "create_final_community_reports",
|
||||||
"verb": "prepare_community_reports_nodes",
|
|
||||||
"input": {"source": "workflow:create_final_nodes"},
|
|
||||||
},
|
|
||||||
#
|
|
||||||
# Subworkflow: Prepare Edges
|
|
||||||
#
|
|
||||||
{
|
|
||||||
"id": "edges",
|
|
||||||
"verb": "prepare_community_reports_edges",
|
|
||||||
"input": {"source": "workflow:create_final_relationships"},
|
|
||||||
},
|
|
||||||
#
|
|
||||||
# Subworkflow: Prepare Claims Table
|
|
||||||
#
|
|
||||||
{
|
|
||||||
"id": "claims",
|
|
||||||
"enabled": covariates_enabled,
|
|
||||||
"verb": "prepare_community_reports_claims",
|
|
||||||
"input": {
|
|
||||||
"source": "workflow:create_final_covariates",
|
|
||||||
}
|
|
||||||
if covariates_enabled
|
|
||||||
else {},
|
|
||||||
},
|
|
||||||
#
|
|
||||||
# Subworkflow: Get Community Hierarchy
|
|
||||||
#
|
|
||||||
{
|
|
||||||
"id": "community_hierarchy",
|
|
||||||
"verb": "restore_community_hierarchy",
|
|
||||||
"input": {"source": "nodes"},
|
|
||||||
},
|
|
||||||
#
|
|
||||||
# Main Workflow: Create Community Reports
|
|
||||||
#
|
|
||||||
{
|
|
||||||
"id": "local_contexts",
|
|
||||||
"verb": "prepare_community_reports",
|
|
||||||
"args": {"max_tokens": community_report_max_input_length},
|
|
||||||
"input": {
|
|
||||||
"source": "nodes",
|
|
||||||
"nodes": "nodes",
|
|
||||||
"edges": "edges",
|
|
||||||
**({"claims": "claims"} if covariates_enabled else {}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"verb": "create_community_reports",
|
|
||||||
"args": {
|
"args": {
|
||||||
|
"covariates_enabled": covariates_enabled,
|
||||||
|
"skip_full_content_embedding": skip_full_content_embedding,
|
||||||
|
"skip_summary_embedding": skip_summary_embedding,
|
||||||
|
"skip_title_embedding": skip_title_embedding,
|
||||||
|
"full_content_text_embed": community_report_full_content_embed_config,
|
||||||
|
"summary_text_embed": community_report_summary_embed_config,
|
||||||
|
"title_text_embed": community_report_title_embed_config,
|
||||||
**create_community_reports_config,
|
**create_community_reports_config,
|
||||||
},
|
},
|
||||||
"input": {
|
"input": input,
|
||||||
"source": "local_contexts",
|
|
||||||
"community_hierarchy": "community_hierarchy",
|
|
||||||
"nodes": "nodes",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
# Generate a unique ID for each community report distinct from the community ID
|
|
||||||
"verb": "window",
|
|
||||||
"args": {"to": "id", "operation": "uuid", "column": "community"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"verb": "text_embed",
|
|
||||||
"enabled": not skip_full_content_embedding,
|
|
||||||
"args": {
|
|
||||||
"embedding_name": "community_report_full_content",
|
|
||||||
"column": "full_content",
|
|
||||||
"to": "full_content_embedding",
|
|
||||||
**community_report_full_content_embed_config,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"verb": "text_embed",
|
|
||||||
"enabled": not skip_summary_embedding,
|
|
||||||
"args": {
|
|
||||||
"embedding_name": "community_report_summary",
|
|
||||||
"column": "summary",
|
|
||||||
"to": "summary_embedding",
|
|
||||||
**community_report_summary_embed_config,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"verb": "text_embed",
|
|
||||||
"enabled": not skip_title_embedding,
|
|
||||||
"args": {
|
|
||||||
"embedding_name": "community_report_title",
|
|
||||||
"column": "title",
|
|
||||||
"to": "title_embedding",
|
|
||||||
**community_report_title_embed_config,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
@ -26,7 +26,6 @@ def build_steps(
|
|||||||
{
|
{
|
||||||
"verb": "create_final_documents",
|
"verb": "create_final_documents",
|
||||||
"args": {
|
"args": {
|
||||||
"columns": {"text_units": "text_unit_ids"},
|
|
||||||
"skip_embedding": skip_raw_content_embedding,
|
"skip_embedding": skip_raw_content_embedding,
|
||||||
"text_embed": document_raw_content_embed_config,
|
"text_embed": document_raw_content_embed_config,
|
||||||
},
|
},
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
from .create_base_documents import create_base_documents
|
from .create_base_documents import create_base_documents
|
||||||
from .create_base_text_units import create_base_text_units
|
from .create_base_text_units import create_base_text_units
|
||||||
from .create_final_communities import create_final_communities
|
from .create_final_communities import create_final_communities
|
||||||
|
from .create_final_community_reports import create_final_community_reports
|
||||||
from .create_final_covariates import create_final_covariates
|
from .create_final_covariates import create_final_covariates
|
||||||
from .create_final_documents import create_final_documents
|
from .create_final_documents import create_final_documents
|
||||||
from .create_final_entities import create_final_entities
|
from .create_final_entities import create_final_entities
|
||||||
@ -19,6 +20,7 @@ __all__ = [
|
|||||||
"create_base_documents",
|
"create_base_documents",
|
||||||
"create_base_text_units",
|
"create_base_text_units",
|
||||||
"create_final_communities",
|
"create_final_communities",
|
||||||
|
"create_final_community_reports",
|
||||||
"create_final_covariates",
|
"create_final_covariates",
|
||||||
"create_final_documents",
|
"create_final_documents",
|
||||||
"create_final_entities",
|
"create_final_entities",
|
||||||
|
|||||||
@ -0,0 +1,188 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
|
||||||
|
"""All the steps to transform community reports."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from datashaper import (
|
||||||
|
AsyncType,
|
||||||
|
Table,
|
||||||
|
VerbCallbacks,
|
||||||
|
VerbInput,
|
||||||
|
verb,
|
||||||
|
)
|
||||||
|
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||||
|
|
||||||
|
from graphrag.index.cache import PipelineCache
|
||||||
|
from graphrag.index.graph.extractors.community_reports.schemas import (
|
||||||
|
CLAIM_DESCRIPTION,
|
||||||
|
CLAIM_DETAILS,
|
||||||
|
CLAIM_ID,
|
||||||
|
CLAIM_STATUS,
|
||||||
|
CLAIM_SUBJECT,
|
||||||
|
CLAIM_TYPE,
|
||||||
|
EDGE_DEGREE,
|
||||||
|
EDGE_DESCRIPTION,
|
||||||
|
EDGE_DETAILS,
|
||||||
|
EDGE_ID,
|
||||||
|
EDGE_SOURCE,
|
||||||
|
EDGE_TARGET,
|
||||||
|
NODE_DEGREE,
|
||||||
|
NODE_DESCRIPTION,
|
||||||
|
NODE_DETAILS,
|
||||||
|
NODE_ID,
|
||||||
|
NODE_NAME,
|
||||||
|
)
|
||||||
|
from graphrag.index.utils.ds_util import get_required_input_table
|
||||||
|
from graphrag.index.verbs.graph.report.create_community_reports import (
|
||||||
|
create_community_reports_df,
|
||||||
|
)
|
||||||
|
from graphrag.index.verbs.graph.report.prepare_community_reports import (
|
||||||
|
prepare_community_reports_df,
|
||||||
|
)
|
||||||
|
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
|
||||||
|
restore_community_hierarchy_df,
|
||||||
|
)
|
||||||
|
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||||
|
|
||||||
|
|
||||||
|
@verb(name="create_final_community_reports", treats_input_tables_as_immutable=True)
|
||||||
|
async def create_final_community_reports(
|
||||||
|
input: VerbInput,
|
||||||
|
callbacks: VerbCallbacks,
|
||||||
|
cache: PipelineCache,
|
||||||
|
strategy: dict,
|
||||||
|
full_content_text_embed: dict,
|
||||||
|
summary_text_embed: dict,
|
||||||
|
title_text_embed: dict,
|
||||||
|
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||||
|
num_threads: int = 4,
|
||||||
|
skip_full_content_embedding: bool = False,
|
||||||
|
skip_summary_embedding: bool = False,
|
||||||
|
skip_title_embedding: bool = False,
|
||||||
|
covariates_enabled: bool = False,
|
||||||
|
**_kwargs: dict,
|
||||||
|
) -> VerbResult:
|
||||||
|
"""All the steps to transform community reports."""
|
||||||
|
nodes = _prep_nodes(cast(pd.DataFrame, input.get_input()))
|
||||||
|
edges = _prep_edges(
|
||||||
|
cast(pd.DataFrame, get_required_input_table(input, "relationships").table)
|
||||||
|
)
|
||||||
|
|
||||||
|
claims = None
|
||||||
|
if covariates_enabled:
|
||||||
|
claims = _prep_claims(
|
||||||
|
cast(pd.DataFrame, get_required_input_table(input, "covariates").table)
|
||||||
|
)
|
||||||
|
|
||||||
|
community_hierarchy = restore_community_hierarchy_df(nodes)
|
||||||
|
|
||||||
|
local_contexts = prepare_community_reports_df(
|
||||||
|
nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000)
|
||||||
|
)
|
||||||
|
|
||||||
|
community_reports = await create_community_reports_df(
|
||||||
|
local_contexts,
|
||||||
|
nodes,
|
||||||
|
community_hierarchy,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
strategy,
|
||||||
|
async_mode=async_mode,
|
||||||
|
num_threads=num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
community_reports["id"] = community_reports["community"].apply(
|
||||||
|
lambda _x: str(uuid4())
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embed full content if not skipped
|
||||||
|
if not skip_full_content_embedding:
|
||||||
|
community_reports = await text_embed_df(
|
||||||
|
community_reports,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
column="full_content",
|
||||||
|
strategy=full_content_text_embed["strategy"],
|
||||||
|
to="full_content_embedding",
|
||||||
|
embedding_name="community_report_full_content",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embed summary if not skipped
|
||||||
|
if not skip_summary_embedding:
|
||||||
|
community_reports = await text_embed_df(
|
||||||
|
community_reports,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
column="summary",
|
||||||
|
strategy=summary_text_embed["strategy"],
|
||||||
|
to="summary_embedding",
|
||||||
|
embedding_name="community_report_summary",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embed title if not skipped
|
||||||
|
if not skip_title_embedding:
|
||||||
|
community_reports = await text_embed_df(
|
||||||
|
community_reports,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
column="title",
|
||||||
|
strategy=title_text_embed["strategy"],
|
||||||
|
to="title_embedding",
|
||||||
|
embedding_name="community_report_title",
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_verb_result(
|
||||||
|
cast(
|
||||||
|
Table,
|
||||||
|
community_reports,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
|
||||||
|
# merge values of four columns into a map column
|
||||||
|
input[NODE_DETAILS] = input.apply(
|
||||||
|
lambda x: {
|
||||||
|
NODE_ID: x[NODE_ID],
|
||||||
|
NODE_NAME: x[NODE_NAME],
|
||||||
|
NODE_DESCRIPTION: x[NODE_DESCRIPTION],
|
||||||
|
NODE_DEGREE: x[NODE_DEGREE],
|
||||||
|
},
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
|
||||||
|
input[EDGE_DETAILS] = input.apply(
|
||||||
|
lambda x: {
|
||||||
|
EDGE_ID: x[EDGE_ID],
|
||||||
|
EDGE_SOURCE: x[EDGE_SOURCE],
|
||||||
|
EDGE_TARGET: x[EDGE_TARGET],
|
||||||
|
EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
|
||||||
|
EDGE_DEGREE: x[EDGE_DEGREE],
|
||||||
|
},
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
|
||||||
|
input[CLAIM_DETAILS] = input.apply(
|
||||||
|
lambda x: {
|
||||||
|
CLAIM_ID: x[CLAIM_ID],
|
||||||
|
CLAIM_SUBJECT: x[CLAIM_SUBJECT],
|
||||||
|
CLAIM_TYPE: x[CLAIM_TYPE],
|
||||||
|
CLAIM_STATUS: x[CLAIM_STATUS],
|
||||||
|
CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
|
||||||
|
},
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
return input
|
||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
"""LLM Static Response method definition."""
|
"""LLM Static Response method definition."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
@ -12,6 +13,7 @@ from graphrag.llm.types import (
|
|||||||
CompletionInput,
|
CompletionInput,
|
||||||
CompletionOutput,
|
CompletionOutput,
|
||||||
LLMInput,
|
LLMInput,
|
||||||
|
LLMOutput,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -35,3 +37,6 @@ class MockCompletionLLM(
|
|||||||
**kwargs: Unpack[LLMInput],
|
**kwargs: Unpack[LLMInput],
|
||||||
) -> CompletionOutput:
|
) -> CompletionOutput:
|
||||||
return self.responses[0]
|
return self.responses[0]
|
||||||
|
|
||||||
|
async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]):
|
||||||
|
return LLMOutput(output=self.responses[0], json=json.loads(self.responses[0]))
|
||||||
|
|||||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -93,7 +93,7 @@
|
|||||||
"rank_explanation",
|
"rank_explanation",
|
||||||
"findings"
|
"findings"
|
||||||
],
|
],
|
||||||
"subworkflows": 6,
|
"subworkflows": 1,
|
||||||
"max_runtime": 300
|
"max_runtime": 300
|
||||||
},
|
},
|
||||||
"create_final_text_units": {
|
"create_final_text_units": {
|
||||||
|
|||||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -110,7 +110,7 @@
|
|||||||
"rank_explanation",
|
"rank_explanation",
|
||||||
"findings"
|
"findings"
|
||||||
],
|
],
|
||||||
"subworkflows": 7,
|
"subworkflows": 1,
|
||||||
"max_runtime": 300
|
"max_runtime": 300
|
||||||
},
|
},
|
||||||
"create_final_text_units": {
|
"create_final_text_units": {
|
||||||
|
|||||||
86
tests/verbs/test_create_final_community_reports.py
Normal file
86
tests/verbs/test_create_final_community_reports.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
|
||||||
|
from graphrag.index.workflows.v1.create_final_community_reports import (
|
||||||
|
build_steps,
|
||||||
|
workflow_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .util import (
|
||||||
|
compare_outputs,
|
||||||
|
get_config_for_workflow,
|
||||||
|
get_workflow_output,
|
||||||
|
load_expected,
|
||||||
|
load_input_tables,
|
||||||
|
remove_disabled_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_final_community_reports():
|
||||||
|
input_tables = load_input_tables([
|
||||||
|
"workflow:create_final_nodes",
|
||||||
|
"workflow:create_final_covariates",
|
||||||
|
"workflow:create_final_relationships",
|
||||||
|
])
|
||||||
|
expected = load_expected(workflow_name)
|
||||||
|
|
||||||
|
config = get_config_for_workflow(workflow_name)
|
||||||
|
|
||||||
|
# deleting the llm config results in a default mock injection in run_graph_intelligence
|
||||||
|
del config["create_community_reports"]["strategy"]["llm"]
|
||||||
|
|
||||||
|
steps = remove_disabled_steps(build_steps(config))
|
||||||
|
|
||||||
|
actual = await get_workflow_output(
|
||||||
|
input_tables,
|
||||||
|
{
|
||||||
|
"steps": steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(actual.columns) == len(expected.columns)
|
||||||
|
|
||||||
|
# only assert a couple of columns that are not mock - most of this table is LLM-generated
|
||||||
|
compare_outputs(actual, expected, columns=["community", "level"])
|
||||||
|
|
||||||
|
# assert a handful of mock data items to confirm they get put in the right spot
|
||||||
|
assert actual["rank"][:1][0] == 2
|
||||||
|
assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_final_community_reports_with_embeddings():
|
||||||
|
input_tables = load_input_tables([
|
||||||
|
"workflow:create_final_nodes",
|
||||||
|
"workflow:create_final_covariates",
|
||||||
|
"workflow:create_final_relationships",
|
||||||
|
])
|
||||||
|
expected = load_expected(workflow_name)
|
||||||
|
|
||||||
|
config = get_config_for_workflow(workflow_name)
|
||||||
|
|
||||||
|
# deleting the llm config results in a default mock injection in run_graph_intelligence
|
||||||
|
del config["create_community_reports"]["strategy"]["llm"]
|
||||||
|
|
||||||
|
config["skip_full_content_embedding"] = False
|
||||||
|
config["community_report_full_content_embed"]["strategy"]["type"] = "mock"
|
||||||
|
config["skip_summary_embedding"] = False
|
||||||
|
config["community_report_summary_embed"]["strategy"]["type"] = "mock"
|
||||||
|
config["skip_title_embedding"] = False
|
||||||
|
config["community_report_title_embed"]["strategy"]["type"] = "mock"
|
||||||
|
|
||||||
|
steps = remove_disabled_steps(build_steps(config))
|
||||||
|
|
||||||
|
actual = await get_workflow_output(
|
||||||
|
input_tables,
|
||||||
|
{
|
||||||
|
"steps": steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(actual.columns) == len(expected.columns) + 3
|
||||||
|
assert "full_content_embedding" in actual.columns
|
||||||
|
assert len(actual["full_content_embedding"][:1][0]) == 3
|
||||||
|
assert "summary_embedding" in actual.columns
|
||||||
|
assert len(actual["summary_embedding"][:1][0]) == 3
|
||||||
|
assert "title_embedding" in actual.columns
|
||||||
|
assert len(actual["title_embedding"][:1][0]) == 3
|
||||||
@ -15,6 +15,8 @@ from graphrag.index import (
|
|||||||
)
|
)
|
||||||
from graphrag.index.run.utils import _create_run_context
|
from graphrag.index.run.utils import _create_run_context
|
||||||
|
|
||||||
|
pd.set_option("display.max_columns", None)
|
||||||
|
|
||||||
|
|
||||||
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
|
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
|
||||||
"""Harvest all the referenced input IDs from the workflow being tested and pass them here."""
|
"""Harvest all the referenced input IDs from the workflow being tested and pass them here."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user