Flow cleanup (#1510)

* Move snapshots out of flows into verbs

* Move degree compute out of extract_graph

* Move entity/relationship df merging into extract

* Move "title" to extraction source

* Move text_unit_ids agg closer to extraction

* Move data definition

* Update test data

* Semver

* Update smoke tests

* Fix empty degree field and update smoke tests and verb data

* Move extractors (#1516)

* Consolidate graph embedding and umap

* Consolidate claim extraction

* Consolidate graph extractor

* Move graph utils

* Move summarizers

* Semver

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>

* Fix syntax typo

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Nathan Evans authored 2024-12-18 18:07:44 -08:00, committed by GitHub
commit c1c09bab80 (parent d0543d1fd6)
33 changed files with 143 additions and 173 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}

View File

@@ -9,15 +9,11 @@ import pandas as pd
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
     return base_communities

View File

@@ -15,18 +15,14 @@ from datashaper import (
 )
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 # TODO: would be nice to inline this completely in the main method with pandas

View File

@@ -10,6 +10,7 @@ from datashaper import (
     VerbCallbacks,
 )
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
+    degrees = compute_degree(graph)
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
     )
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
-    return joined.loc[
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",

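The fillna(0) on degree matters because the merges against the degree and community tables are left joins: a title that never appears as a relationship endpoint (or has no community even at level 0) comes back with NaN. A minimal, self-contained pandas sketch of that behavior on toy data (not the real pipeline frames):

import pandas as pd

nodes = pd.DataFrame({"title": ["a", "isolated"]})
degrees = pd.DataFrame({"title": ["a"], "degree": [2]})

# the left join keeps every node; "isolated" has no degree row, so it comes back NaN
merged = nodes.merge(degrees, on="title", how="left")
merged["degree"] = merged["degree"].fillna(0).astype(int)
print(merged)  # "a" -> 2, "isolated" -> 0
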
View File

@@ -5,20 +5,25 @@
 import pandas as pd
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",

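The relationship flow no longer needs base_entity_nodes for degrees; it rebuilds a graph from the edge list and computes degrees directly. A rough, self-contained sketch of the idea, assuming combined_degree is the sum of the source and target node degrees (the real logic lives in compute_edge_combined_degree):

import networkx as nx
import pandas as pd

edges = pd.DataFrame([
    {"source": "a", "target": "b"},
    {"source": "a", "target": "c"},
])

# rebuild the graph from the edge list and compute per-node degree
graph = nx.from_pandas_edgelist(edges, source="source", target="target")
degrees = pd.Series({node: deg for node, deg in graph.degree()})

# combined degree of an edge: degree(source) + degree(target)
edges["combined_degree"] = edges["source"].map(degrees) + edges["target"].map(degrees)
print(edges)  # a-b -> 3, a-c -> 3
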
View File

@ -6,7 +6,6 @@
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
import networkx as nx
import pandas as pd import pandas as pd
from datashaper import ( from datashaper import (
AsyncType, AsyncType,
@ -14,33 +13,26 @@ from datashaper import (
) )
from graphrag.cache.pipeline_cache import PipelineCache from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import ( from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions, summarize_descriptions,
) )
from graphrag.storage.pipeline_storage import PipelineStorage
async def extract_graph( async def extract_graph(
text_units: pd.DataFrame, text_units: pd.DataFrame,
callbacks: VerbCallbacks, callbacks: VerbCallbacks,
cache: PipelineCache, cache: PipelineCache,
storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None = None, extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4, extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO, extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None, entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None, summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4, summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]: ) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph.""" """All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later # this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities( entities, relationships = await extract_entities(
text_units, text_units,
callbacks, callbacks,
cache, cache,
@ -52,87 +44,38 @@ async def extract_graph(
num_threads=extraction_num_threads, num_threads=extraction_num_threads,
) )
if not _validate_data(entity_dfs): if not _validate_data(entities):
error_msg = "Entity Extraction failed. No entities detected during extraction." error_msg = "Entity Extraction failed. No entities detected during extraction."
callbacks.error(error_msg) callbacks.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
if not _validate_data(relationship_dfs): if not _validate_data(relationships):
error_msg = ( error_msg = (
"Entity Extraction failed. No relationships detected during extraction." "Entity Extraction failed. No relationships detected during extraction."
) )
callbacks.error(error_msg) callbacks.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
merged_entities = _merge_entities(entity_dfs)
merged_relationships = _merge_relationships(relationship_dfs)
entity_summaries, relationship_summaries = await summarize_descriptions( entity_summaries, relationship_summaries = await summarize_descriptions(
merged_entities, entities,
merged_relationships, relationships,
callbacks, callbacks,
cache, cache,
strategy=summarization_strategy, strategy=summarization_strategy,
num_threads=summarization_num_threads, num_threads=summarization_num_threads,
) )
base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries) base_relationship_edges = _prep_edges(relationships, relationship_summaries)
graph = create_graph(base_relationship_edges) base_entity_nodes = _prep_nodes(entities, entity_summaries)
base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
return (base_entity_nodes, base_relationship_edges) return (base_entity_nodes, base_relationship_edges)
def _merge_entities(entity_dfs) -> pd.DataFrame: def _prep_nodes(entities, summaries) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["name", "type"], sort=False)
.agg({"description": list, "source_id": list})
.reset_index()
)
def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg({"description": list, "source_id": list, "weight": "sum"})
.reset_index()
)
def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
degrees_df = _compute_degree(graph)
entities.drop(columns=["description"], inplace=True) entities.drop(columns=["description"], inplace=True)
nodes = ( nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
entities.merge(summaries, on="name", how="left") subset="title"
.merge(degrees_df, on="name")
.drop_duplicates(subset="name")
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
) )
nodes = nodes.loc[nodes["title"].notna()].reset_index() nodes = nodes.loc[nodes["title"].notna()].reset_index()
nodes["human_readable_id"] = nodes.index nodes["human_readable_id"] = nodes.index
@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
relationships.drop(columns=["description"]) relationships.drop(columns=["description"])
.drop_duplicates(subset=["source", "target"]) .drop_duplicates(subset=["source", "target"])
.merge(summaries, on=["source", "target"], how="left") .merge(summaries, on=["source", "target"], how="left")
.rename(columns={"source_id": "text_unit_ids"})
) )
edges["human_readable_id"] = edges.index edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4())) edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
return edges return edges
def _compute_degree(graph: nx.Graph) -> pd.DataFrame: def _validate_data(df: pd.DataFrame) -> bool:
return pd.DataFrame([ """Validate that the dataframe has data."""
{"name": node, "degree": int(degree)} return len(df) > 0
for node, degree in graph.degree # type: ignore
])
def _validate_data(df_list: list[pd.DataFrame]) -> bool:
"""Validate that the dataframe list is valid. At least one dataframe must contain data."""
return any(
len(df) > 0 for df in df_list
) # Check for len, not .empty, as the dfs have schemas in some cases

View File

@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
-    data = data.loc[:, ["id", "embedding"]]
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",

View File

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing create_graph definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])

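A quick usage sketch for the new operation (assumes graphrag is installed; the import path matches the one used by create_final_nodes above):

import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph()
graph.add_edges_from([("a", "b"), ("a", "c")])

degrees = compute_degree(graph)
# DataFrame with columns ["title", "degree"]: a -> 2, b -> 1, c -> 1
print(degrees)
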
View File

@@ -37,7 +37,7 @@ async def extract_entities(
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.
@@ -138,7 +138,10 @@ async def extract_entities(
         entity_dfs.append(pd.DataFrame(result[0]))
         relationship_dfs.append(pd.DataFrame(result[1]))
-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+    return (entities, relationships)
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )

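The per-text-unit frames are now merged inside extract_entities rather than in extract_graph, and the aggregation renames source_id to text_unit_ids up front. A minimal, self-contained sketch of what _merge_entities does, using toy frames in place of real extraction output:

import pandas as pd

# one frame per text unit, as produced by the extraction strategy
entity_dfs = [
    pd.DataFrame([{"title": "A", "type": "person", "description": "d1", "source_id": "t1"}]),
    pd.DataFrame([{"title": "A", "type": "person", "description": "d2", "source_id": "t2"}]),
]

entities = (
    pd.concat(entity_dfs, ignore_index=True)
    .groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
# one row: title=A, description=["d1", "d2"], text_unit_ids=["t1", "t2"]
print(entities)
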
View File

@@ -106,7 +106,7 @@ async def run_extract_entities(
     )
     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]

View File

@@ -58,7 +58,7 @@ async def run(  # noqa RUF029 async is required for interface
     return EntityExtractionResult(
         entities=[
-            {"type": entity_type, "name": name}
+            {"type": entity_type, "title": name}
             for name, entity_type in entity_map.items()
         ],
         relationships=[],

View File

@@ -88,7 +88,7 @@ async def summarize_descriptions(
     node_futures = [
         do_summarize_descriptions(
-            str(row[1]["name"]),
+            str(row[1]["title"]),
             sorted(set(row[1]["description"])),
             ticker,
             semaphore,
@@ -100,7 +100,7 @@ async def summarize_descriptions(
     node_descriptions = [
         {
-            "name": result.id,
+            "title": result.id,
             "description": result.description,
         }
         for result in node_results

View File

@@ -14,6 +14,7 @@ from datashaper.table_store.types import VerbResult, create_verb_result
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 workflow_name = "compute_communities"
@@ -62,13 +63,19 @@ async def workflow(
     """All the steps to create the base entity graph."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_communities = await compute_communities(
+    base_communities = compute_communities(
         base_relationship_edges,
-        storage,
         clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
     await runtime_storage.set("base_communities", base_communities)
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
     return create_verb_result(cast("Table", pd.DataFrame()))

View File

@@ -19,6 +19,7 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units,
 )
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@ async def workflow(
     """All the steps to transform base text_units."""
     source = cast("pd.DataFrame", input.get_input())
-    output = await create_base_text_units(
+    output = create_base_text_units(
         source,
         callbacks,
-        storage,
         chunk_by_columns,
         chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
     await runtime_storage.set("base_text_units", output)
+    if snapshot_transient_enabled:
+        await snapshot(
+            output,
+            name="create_base_text_units",
+            storage=storage,
+            formats=["parquet"],
+        )
     return create_verb_result(
         cast(
             "Table",

View File

@@ -53,8 +53,7 @@ async def workflow(
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
-    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+    output = create_final_relationships(base_relationship_edges)
     return create_verb_result(cast("Table", output))

View File

@@ -19,6 +19,9 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.extract_graph import (
     extract_graph,
 )
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.storage.pipeline_storage import PipelineStorage
 workflow_name = "extract_graph"
@@ -90,18 +93,38 @@ async def workflow(
         text_units,
         callbacks,
         cache,
-        storage,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extraction_num_threads,
         extraction_async_mode=extraction_async_mode,
         entity_types=entity_types,
         summarization_strategy=summarization_strategy,
         summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
     await runtime_storage.set("base_entity_nodes", base_entity_nodes)
     await runtime_storage.set("base_relationship_edges", base_relationship_edges)
+    if snapshot_graphml_enabled:
+        # todo: extract graphs at each level, and add in meta like descriptions
+        graph = create_graph(base_relationship_edges)
+        await snapshot_graphml(
+            graph,
+            name="graph",
+            storage=storage,
+        )
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_entity_nodes,
+            name="base_entity_nodes",
+            storage=storage,
+            formats=["parquet"],
+        )
+        await snapshot(
+            base_relationship_edges,
+            name="base_relationship_edges",
+            storage=storage,
+            formats=["parquet"],
+        )
     return create_verb_result(cast("Table", pd.DataFrame()))

View File

@@ -58,8 +58,8 @@
     ],
     "nan_allowed_columns": [
         "description",
-        "community",
-        "level"
+        "x",
+        "y"
     ],
     "subworkflows": 1,
     "max_runtime": 150,

View File

@@ -76,8 +76,8 @@
     ],
     "nan_allowed_columns": [
         "description",
-        "community",
-        "level"
+        "x",
+        "y"
    ],
    "subworkflows": 1,
    "max_runtime": 150,

View File

@@ -42,7 +42,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
     async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
     async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):
async def test_run_extract_entities_multiple_documents_correct_edges_returned(self): async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):

View File

@ -4,7 +4,6 @@
from graphrag.index.flows.compute_communities import ( from graphrag.index.flows.compute_communities import (
compute_communities, compute_communities,
) )
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.compute_communities import ( from graphrag.index.workflows.v1.compute_communities import (
workflow_name, workflow_name,
) )
@ -16,37 +15,15 @@ from .util import (
) )
async def test_compute_communities(): def test_compute_communities():
edges = load_test_table("base_relationship_edges") edges = load_test_table("base_relationship_edges")
expected = load_test_table("base_communities") expected = load_test_table("base_communities")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name) config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"] clustering_strategy = config["cluster_graph"]["strategy"]
actual = await compute_communities( actual = compute_communities(edges, clustering_strategy=clustering_strategy)
edges, storage=context.storage, clustering_strategy=clustering_strategy
)
columns = list(expected.columns.values) columns = list(expected.columns.values)
compare_outputs(actual, expected, columns) compare_outputs(actual, expected, columns)
assert len(actual.columns) == len(expected.columns) assert len(actual.columns) == len(expected.columns)
async def test_compute_communities_with_snapshots():
edges = load_test_table("base_relationship_edges")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"]
await compute_communities(
edges,
storage=context.storage,
clustering_strategy=clustering_strategy,
snapshot_transient_enabled=True,
)
assert context.storage.keys() == [
"base_communities.parquet",
], "Community snapshot keys differ"

View File

@@ -16,10 +16,9 @@ from .util import (
 def test_create_final_relationships():
     edges = load_test_table("base_relationship_edges")
-    nodes = load_test_table("base_entity_nodes")
     expected = load_test_table(workflow_name)
-    actual = create_final_relationships(edges, nodes)
+    actual = create_final_relationships(edges)
     assert "id" in expected.columns
     columns = list(expected.columns.values)