mirror of
https://github.com/microsoft/graphrag.git
synced 2025-11-03 11:20:38 +00:00
Collapse create final community reports (#1227)
* Remove extraneous param * Add community report mocking assertions * Collapse primary report generation * Collapse embeddings * Format * Semver * Remove extraneous check * Move option set
This commit is contained in:
parent
336e6f9ca1
commit
a44788bfad
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Collapse create-final-community-reports."
|
||||
}
|
||||
@ -53,14 +53,37 @@ async def create_community_reports(
|
||||
num_threads: int = 4,
|
||||
**_kwargs,
|
||||
) -> TableContainer:
|
||||
"""Generate entities for each row, and optionally a graph of those entities."""
|
||||
"""Generate community summaries."""
|
||||
log.debug("create_community_reports strategy=%s", strategy)
|
||||
local_contexts = cast(pd.DataFrame, input.get_input())
|
||||
nodes_ctr = get_required_input_table(input, "nodes")
|
||||
nodes = cast(pd.DataFrame, nodes_ctr.table)
|
||||
community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy")
|
||||
community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table)
|
||||
nodes = get_required_input_table(input, "nodes").table
|
||||
community_hierarchy = get_required_input_table(input, "community_hierarchy").table
|
||||
|
||||
output = await create_community_reports_df(
|
||||
local_contexts,
|
||||
nodes,
|
||||
community_hierarchy,
|
||||
callbacks,
|
||||
cache,
|
||||
strategy,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
return TableContainer(table=output)
|
||||
|
||||
|
||||
async def create_community_reports_df(
|
||||
local_contexts,
|
||||
nodes,
|
||||
community_hierarchy,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict,
|
||||
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||
num_threads: int = 4,
|
||||
):
|
||||
"""Generate community summaries."""
|
||||
levels = get_levels(nodes)
|
||||
reports: list[CommunityReport | None] = []
|
||||
tick = progress_ticker(callbacks.progress, len(local_contexts))
|
||||
@ -99,7 +122,7 @@ async def create_community_reports(
|
||||
)
|
||||
reports.extend([lr for lr in local_reports if lr is not None])
|
||||
|
||||
return TableContainer(table=pd.DataFrame(reports))
|
||||
return pd.DataFrame(reports)
|
||||
|
||||
|
||||
async def _generate_report(
|
||||
|
||||
@ -37,25 +37,38 @@ def prepare_community_reports(
|
||||
max_tokens: int = 16_000,
|
||||
**_kwargs,
|
||||
) -> TableContainer:
|
||||
"""Generate entities for each row, and optionally a graph of those entities."""
|
||||
"""Prep communities for report generation."""
|
||||
# Prepare Community Reports
|
||||
node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||
edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
|
||||
claim_df = get_named_input_table(input, "claims")
|
||||
if claim_df is not None:
|
||||
claim_df = cast(pd.DataFrame, claim_df.table)
|
||||
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||
edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
|
||||
claims = get_named_input_table(input, "claims")
|
||||
if claims:
|
||||
claims = cast(pd.DataFrame, claims.table)
|
||||
|
||||
levels = get_levels(node_df, schemas.NODE_LEVEL)
|
||||
output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens)
|
||||
|
||||
return TableContainer(table=output)
|
||||
|
||||
|
||||
def prepare_community_reports_df(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
callbacks: VerbCallbacks,
|
||||
max_tokens: int = 16_000,
|
||||
):
|
||||
"""Prep communities for report generation."""
|
||||
levels = get_levels(nodes, schemas.NODE_LEVEL)
|
||||
dfs = []
|
||||
|
||||
for level in progress_iterable(levels, callbacks.progress, len(levels)):
|
||||
communities_at_level_df = _prepare_reports_at_level(
|
||||
node_df, edge_df, claim_df, level, max_tokens
|
||||
nodes, edges, claims, level, max_tokens
|
||||
)
|
||||
dfs.append(communities_at_level_df)
|
||||
|
||||
# build initial local context for all communities
|
||||
return TableContainer(table=pd.concat(dfs))
|
||||
return pd.concat(dfs)
|
||||
|
||||
|
||||
def _prepare_reports_at_level(
|
||||
|
||||
@ -24,8 +24,25 @@ def restore_community_hierarchy(
|
||||
) -> TableContainer:
|
||||
"""Restore the community hierarchy from the node data."""
|
||||
node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
|
||||
|
||||
output = restore_community_hierarchy_df(
|
||||
node_df,
|
||||
name_column=name_column,
|
||||
community_column=community_column,
|
||||
level_column=level_column,
|
||||
)
|
||||
return TableContainer(table=output)
|
||||
|
||||
|
||||
def restore_community_hierarchy_df(
|
||||
input: pd.DataFrame,
|
||||
name_column: str = schemas.NODE_NAME,
|
||||
community_column: str = schemas.NODE_COMMUNITY,
|
||||
level_column: str = schemas.NODE_LEVEL,
|
||||
) -> pd.DataFrame:
|
||||
"""Restore the community hierarchy from the node data."""
|
||||
community_df = (
|
||||
node_df.groupby([community_column, level_column])
|
||||
input.groupby([community_column, level_column])
|
||||
.agg({name_column: list})
|
||||
.reset_index()
|
||||
)
|
||||
@ -75,4 +92,4 @@ def restore_community_hierarchy(
|
||||
if entities_found == len(current_entities):
|
||||
break
|
||||
|
||||
return TableContainer(table=pd.DataFrame(community_hierarchy))
|
||||
return pd.DataFrame(community_hierarchy)
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
import json
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 3000
|
||||
|
||||
MOCK_RESPONSES = [
|
||||
json.dumps({
|
||||
"title": "<report_title>",
|
||||
@ -18,7 +19,7 @@ MOCK_RESPONSES = [
|
||||
"explanation": "<insight_1_explanation",
|
||||
},
|
||||
{
|
||||
"summary": "<farts insight_2_summary>",
|
||||
"summary": "<insight_2_summary>",
|
||||
"explanation": "<insight_2_explanation",
|
||||
},
|
||||
],
|
||||
|
||||
@ -19,10 +19,6 @@ def build_steps(
|
||||
"""
|
||||
covariates_enabled = config.get("covariates_enabled", False)
|
||||
create_community_reports_config = config.get("create_community_reports", {})
|
||||
community_report_strategy = create_community_reports_config.get("strategy", {})
|
||||
community_report_max_input_length = community_report_strategy.get(
|
||||
"max_input_length", 16_000
|
||||
)
|
||||
base_text_embed = config.get("text_embed", {})
|
||||
community_report_full_content_embed_config = config.get(
|
||||
"community_report_full_content_embed", base_text_embed
|
||||
@ -36,103 +32,26 @@ def build_steps(
|
||||
skip_title_embedding = config.get("skip_title_embedding", False)
|
||||
skip_summary_embedding = config.get("skip_summary_embedding", False)
|
||||
skip_full_content_embedding = config.get("skip_full_content_embedding", False)
|
||||
input = {
|
||||
"source": "workflow:create_final_nodes",
|
||||
"relationships": "workflow:create_final_relationships",
|
||||
}
|
||||
if covariates_enabled:
|
||||
input["covariates"] = "workflow:create_final_covariates"
|
||||
|
||||
return [
|
||||
#
|
||||
# Subworkflow: Prepare Nodes
|
||||
#
|
||||
{
|
||||
"id": "nodes",
|
||||
"verb": "prepare_community_reports_nodes",
|
||||
"input": {"source": "workflow:create_final_nodes"},
|
||||
},
|
||||
#
|
||||
# Subworkflow: Prepare Edges
|
||||
#
|
||||
{
|
||||
"id": "edges",
|
||||
"verb": "prepare_community_reports_edges",
|
||||
"input": {"source": "workflow:create_final_relationships"},
|
||||
},
|
||||
#
|
||||
# Subworkflow: Prepare Claims Table
|
||||
#
|
||||
{
|
||||
"id": "claims",
|
||||
"enabled": covariates_enabled,
|
||||
"verb": "prepare_community_reports_claims",
|
||||
"input": {
|
||||
"source": "workflow:create_final_covariates",
|
||||
}
|
||||
if covariates_enabled
|
||||
else {},
|
||||
},
|
||||
#
|
||||
# Subworkflow: Get Community Hierarchy
|
||||
#
|
||||
{
|
||||
"id": "community_hierarchy",
|
||||
"verb": "restore_community_hierarchy",
|
||||
"input": {"source": "nodes"},
|
||||
},
|
||||
#
|
||||
# Main Workflow: Create Community Reports
|
||||
#
|
||||
{
|
||||
"id": "local_contexts",
|
||||
"verb": "prepare_community_reports",
|
||||
"args": {"max_tokens": community_report_max_input_length},
|
||||
"input": {
|
||||
"source": "nodes",
|
||||
"nodes": "nodes",
|
||||
"edges": "edges",
|
||||
**({"claims": "claims"} if covariates_enabled else {}),
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "create_community_reports",
|
||||
"verb": "create_final_community_reports",
|
||||
"args": {
|
||||
"covariates_enabled": covariates_enabled,
|
||||
"skip_full_content_embedding": skip_full_content_embedding,
|
||||
"skip_summary_embedding": skip_summary_embedding,
|
||||
"skip_title_embedding": skip_title_embedding,
|
||||
"full_content_text_embed": community_report_full_content_embed_config,
|
||||
"summary_text_embed": community_report_summary_embed_config,
|
||||
"title_text_embed": community_report_title_embed_config,
|
||||
**create_community_reports_config,
|
||||
},
|
||||
"input": {
|
||||
"source": "local_contexts",
|
||||
"community_hierarchy": "community_hierarchy",
|
||||
"nodes": "nodes",
|
||||
},
|
||||
},
|
||||
{
|
||||
# Generate a unique ID for each community report distinct from the community ID
|
||||
"verb": "window",
|
||||
"args": {"to": "id", "operation": "uuid", "column": "community"},
|
||||
},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_full_content_embedding,
|
||||
"args": {
|
||||
"embedding_name": "community_report_full_content",
|
||||
"column": "full_content",
|
||||
"to": "full_content_embedding",
|
||||
**community_report_full_content_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_summary_embedding,
|
||||
"args": {
|
||||
"embedding_name": "community_report_summary",
|
||||
"column": "summary",
|
||||
"to": "summary_embedding",
|
||||
**community_report_summary_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_title_embedding,
|
||||
"args": {
|
||||
"embedding_name": "community_report_title",
|
||||
"column": "title",
|
||||
"to": "title_embedding",
|
||||
**community_report_title_embed_config,
|
||||
},
|
||||
"input": input,
|
||||
},
|
||||
]
|
||||
|
||||
@ -26,7 +26,6 @@ def build_steps(
|
||||
{
|
||||
"verb": "create_final_documents",
|
||||
"args": {
|
||||
"columns": {"text_units": "text_unit_ids"},
|
||||
"skip_embedding": skip_raw_content_embedding,
|
||||
"text_embed": document_raw_content_embed_config,
|
||||
},
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
from .create_base_documents import create_base_documents
|
||||
from .create_base_text_units import create_base_text_units
|
||||
from .create_final_communities import create_final_communities
|
||||
from .create_final_community_reports import create_final_community_reports
|
||||
from .create_final_covariates import create_final_covariates
|
||||
from .create_final_documents import create_final_documents
|
||||
from .create_final_entities import create_final_entities
|
||||
@ -19,6 +20,7 @@ __all__ = [
|
||||
"create_base_documents",
|
||||
"create_base_text_units",
|
||||
"create_final_communities",
|
||||
"create_final_community_reports",
|
||||
"create_final_covariates",
|
||||
"create_final_documents",
|
||||
"create_final_entities",
|
||||
|
||||
@ -0,0 +1,188 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform community reports."""
|
||||
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
AsyncType,
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.graph.extractors.community_reports.schemas import (
|
||||
CLAIM_DESCRIPTION,
|
||||
CLAIM_DETAILS,
|
||||
CLAIM_ID,
|
||||
CLAIM_STATUS,
|
||||
CLAIM_SUBJECT,
|
||||
CLAIM_TYPE,
|
||||
EDGE_DEGREE,
|
||||
EDGE_DESCRIPTION,
|
||||
EDGE_DETAILS,
|
||||
EDGE_ID,
|
||||
EDGE_SOURCE,
|
||||
EDGE_TARGET,
|
||||
NODE_DEGREE,
|
||||
NODE_DESCRIPTION,
|
||||
NODE_DETAILS,
|
||||
NODE_ID,
|
||||
NODE_NAME,
|
||||
)
|
||||
from graphrag.index.utils.ds_util import get_required_input_table
|
||||
from graphrag.index.verbs.graph.report.create_community_reports import (
|
||||
create_community_reports_df,
|
||||
)
|
||||
from graphrag.index.verbs.graph.report.prepare_community_reports import (
|
||||
prepare_community_reports_df,
|
||||
)
|
||||
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
|
||||
restore_community_hierarchy_df,
|
||||
)
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
@verb(name="create_final_community_reports", treats_input_tables_as_immutable=True)
async def create_final_community_reports(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    strategy: dict,
    full_content_text_embed: dict,
    summary_text_embed: dict,
    title_text_embed: dict,
    async_mode: AsyncType = AsyncType.AsyncIO,
    num_threads: int = 4,
    skip_full_content_embedding: bool = False,
    skip_summary_embedding: bool = False,
    skip_title_embedding: bool = False,
    covariates_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Build the final community reports table in a single verb.

    The default input supplies the nodes table; the named inputs
    "relationships" and (only when ``covariates_enabled``) "covariates" supply
    edges and claims. The verb restores the community hierarchy from the
    prepared nodes, prepares local contexts, generates the reports, assigns
    each report row a fresh UUID4 string in the "id" column, and then embeds
    the "full_content", "summary" and "title" columns unless the matching
    ``skip_*`` flag is set.
    """
    source_nodes = cast(pd.DataFrame, input.get_input())
    nodes = _prep_nodes(source_nodes)

    relationships = get_required_input_table(input, "relationships").table
    edges = _prep_edges(cast(pd.DataFrame, relationships))

    # claims are optional: the "covariates" input is only required when enabled
    if covariates_enabled:
        covariates = get_required_input_table(input, "covariates").table
        claims = _prep_claims(cast(pd.DataFrame, covariates))
    else:
        claims = None

    community_hierarchy = restore_community_hierarchy_df(nodes)

    max_input_length = strategy.get("max_input_length", 16_000)
    local_contexts = prepare_community_reports_df(
        nodes, edges, claims, callbacks, max_input_length
    )

    reports = await create_community_reports_df(
        local_contexts,
        nodes,
        community_hierarchy,
        callbacks,
        cache,
        strategy,
        async_mode=async_mode,
        num_threads=num_threads,
    )

    # one new identifier per report row, distinct from the community value
    reports["id"] = reports["community"].apply(lambda _x: str(uuid4()))

    # (skip flag, embed config, source column, target column, embedding name),
    # processed in the order: full content, summary, title
    embedding_plan = (
        (
            skip_full_content_embedding,
            full_content_text_embed,
            "full_content",
            "full_content_embedding",
            "community_report_full_content",
        ),
        (
            skip_summary_embedding,
            summary_text_embed,
            "summary",
            "summary_embedding",
            "community_report_summary",
        ),
        (
            skip_title_embedding,
            title_text_embed,
            "title",
            "title_embedding",
            "community_report_title",
        ),
    )
    for skipped, embed_config, column, target, embedding_name in embedding_plan:
        if skipped:
            continue
        # the "strategy" key is only read for embeddings that actually run
        reports = await text_embed_df(
            reports,
            callbacks,
            cache,
            column=column,
            strategy=embed_config["strategy"],
            to=target,
            embedding_name=embedding_name,
        )

    return create_verb_result(cast(Table, reports))
|
||||
|
||||
|
||||
def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the nodes table with a per-row details mapping added.

    Missing values in the node description column are replaced with
    "No Description" before the NODE_DETAILS column is built.
    """
    detail_columns = (NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE)
    prepared = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    # collapse the four source columns into a single mapping per row
    prepared[NODE_DETAILS] = prepared.apply(
        lambda row: {column: row[column] for column in detail_columns},
        axis=1,
    )
    return prepared
|
||||
|
||||
|
||||
def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the edges table with a per-row details mapping added.

    Missing values in the edge description column are replaced with
    "No Description" before the EDGE_DETAILS column is built. The fill is keyed
    on EDGE_DESCRIPTION (the column this function actually reads) rather than
    the node description constant.
    """
    input = input.fillna(value={EDGE_DESCRIPTION: "No Description"})
    # merge the identifying and descriptive edge columns into one map column
    input[EDGE_DETAILS] = input.apply(
        lambda x: {
            EDGE_ID: x[EDGE_ID],
            EDGE_SOURCE: x[EDGE_SOURCE],
            EDGE_TARGET: x[EDGE_TARGET],
            EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
            EDGE_DEGREE: x[EDGE_DEGREE],
        },
        axis=1,
    )
    return input
|
||||
|
||||
|
||||
def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the claims table with a per-row details mapping added.

    Missing values in the claim description column are replaced with
    "No Description" before the CLAIM_DETAILS column is built. The fill is
    keyed on CLAIM_DESCRIPTION (the column this function actually reads) rather
    than the node description constant.
    """
    input = input.fillna(value={CLAIM_DESCRIPTION: "No Description"})
    # merge the identifying and descriptive claim columns into one map column
    input[CLAIM_DETAILS] = input.apply(
        lambda x: {
            CLAIM_ID: x[CLAIM_ID],
            CLAIM_SUBJECT: x[CLAIM_SUBJECT],
            CLAIM_TYPE: x[CLAIM_TYPE],
            CLAIM_STATUS: x[CLAIM_STATUS],
            CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
        },
        axis=1,
    )
    return input
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""LLM Static Response method definition."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing_extensions import Unpack
|
||||
@ -12,6 +13,7 @@ from graphrag.llm.types import (
|
||||
CompletionInput,
|
||||
CompletionOutput,
|
||||
LLMInput,
|
||||
LLMOutput,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -35,3 +37,6 @@ class MockCompletionLLM(
|
||||
**kwargs: Unpack[LLMInput],
|
||||
) -> CompletionOutput:
|
||||
return self.responses[0]
|
||||
|
||||
async def _invoke_json(self, input: CompletionInput, **kwargs: Unpack[LLMInput]):
    """Return the first canned response together with its parsed JSON form."""
    canned = self.responses[0]
    return LLMOutput(output=canned, json=json.loads(canned))
|
||||
|
||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -93,7 +93,7 @@
|
||||
"rank_explanation",
|
||||
"findings"
|
||||
],
|
||||
"subworkflows": 6,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300
|
||||
},
|
||||
"create_final_text_units": {
|
||||
|
||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -110,7 +110,7 @@
|
||||
"rank_explanation",
|
||||
"findings"
|
||||
],
|
||||
"subworkflows": 7,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300
|
||||
},
|
||||
"create_final_text_units": {
|
||||
|
||||
86
tests/verbs/test_create_final_community_reports.py
Normal file
86
tests/verbs/test_create_final_community_reports.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.workflows.v1.create_final_community_reports import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
)
|
||||
|
||||
from .util import (
|
||||
compare_outputs,
|
||||
get_config_for_workflow,
|
||||
get_workflow_output,
|
||||
load_expected,
|
||||
load_input_tables,
|
||||
remove_disabled_steps,
|
||||
)
|
||||
|
||||
|
||||
async def test_create_final_community_reports():
    """Run the workflow with the default mock LLM and check the output shape."""
    required_inputs = [
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ]
    input_tables = load_input_tables(required_inputs)
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    # deleting the llm config results in a default mock injection in run_graph_intelligence
    config["create_community_reports"]["strategy"].pop("llm")

    workflow = {"steps": remove_disabled_steps(build_steps(config))}
    actual = await get_workflow_output(input_tables, workflow)

    assert len(actual.columns) == len(expected.columns)

    # only assert a couple of columns that are not mock - most of this table is LLM-generated
    compare_outputs(actual, expected, columns=["community", "level"])

    # assert a handful of mock data items to confirm they get put in the right spot
    first_rank = actual["rank"][:1][0]
    first_rank_explanation = actual["rank_explanation"][:1][0]
    assert first_rank == 2
    assert first_rank_explanation == "<rating_explanation>"
|
||||
|
||||
|
||||
async def test_create_final_community_reports_with_embeddings():
    """Run the workflow with all three embeddings enabled using mock strategies."""
    required_inputs = [
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
    ]
    embedded_fields = ("full_content", "summary", "title")

    input_tables = load_input_tables(required_inputs)
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    # deleting the llm config results in a default mock injection in run_graph_intelligence
    config["create_community_reports"]["strategy"].pop("llm")

    # switch every embedding on and point it at the mock embedding strategy
    for field in embedded_fields:
        config[f"skip_{field}_embedding"] = False
        config[f"community_report_{field}_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": remove_disabled_steps(build_steps(config))}
    actual = await get_workflow_output(input_tables, workflow)

    # one extra column per enabled embedding
    assert len(actual.columns) == len(expected.columns) + 3
    for field in embedded_fields:
        embedding_column = f"{field}_embedding"
        assert embedding_column in actual.columns
        assert len(actual[embedding_column][:1][0]) == 3
|
||||
@ -15,6 +15,8 @@ from graphrag.index import (
|
||||
)
|
||||
from graphrag.index.run.utils import _create_run_context
|
||||
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
|
||||
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
|
||||
"""Harvest all the referenced input IDs from the workflow being tested and pass them here."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user