Flow cleanup (#1510)

* Move snapshots out of flows into verbs

* Move degree compute out of extract_graph

* Move entity/relationship df merging into extract

* Move "title" to extraction source

* Move text_unit_ids agg closer to extraction

* Move data definition

* Update test data

* Semver

* Update smoke tests

* Fix empty degree field and update smoke tests and verb data

* Move extractors (#1516)

* Consolidate graph embedding and umap

* Consolidate claim extraction

* Consolidate graph extractor

* Move graph utils

* Move summarizers

* Semver

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>

* Fix syntax typo

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Nathan Evans 2024-12-18 18:07:44 -08:00 committed by GitHub
parent d0543d1fd6
commit c1c09bab80
33 changed files with 143 additions and 173 deletions
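
The recurring pattern across these diffs is that flows become pure DataFrame-in/DataFrame-out transforms while the workflow verbs own storage side effects such as transient snapshots. A minimal sketch of that split, using made-up example_* names (snapshot, PipelineStorage, and the snapshot_transient_enabled flag are the real pieces used below):

import pandas as pd

from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


def example_flow(edges: pd.DataFrame, strategy: dict) -> pd.DataFrame:
    """Pure transform: no storage handle, no snapshot flag, no async required."""
    return edges  # cluster/merge/derive columns here


async def example_workflow(
    edges: pd.DataFrame,
    storage: PipelineStorage,
    strategy: dict,
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    output = example_flow(edges, strategy)
    if snapshot_transient_enabled:
        # IO stays at the verb layer
        await snapshot(output, name="example", storage=storage, formats=["parquet"])
    return output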

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Streamline flows."
}

View File

@@ -9,15 +9,11 @@ import pandas as pd
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage
async def compute_communities(
def compute_communities(
base_relationship_edges: pd.DataFrame,
storage: PipelineStorage,
clustering_strategy: dict[str, Any],
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
).explode("title")
base_communities["community"] = base_communities["community"].astype(int)
if snapshot_transient_enabled:
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)
return base_communities

View File

@@ -15,18 +15,14 @@ from datashaper import (
)
from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.storage.pipeline_storage import PipelineStorage
async def create_base_text_units(
def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
storage: PipelineStorage,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)
output = cast(
"pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
)
if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_text_units",
storage=storage,
formats=["parquet"],
)
return output
return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
# TODO: would be nice to inline this completely in the main method with pandas

View File

@@ -10,6 +10,7 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
layout_strategy,
embeddings=graph_embeddings,
)
nodes = base_entity_nodes.merge(
layout, left_on="title", right_on="label", how="left"
degrees = compute_degree(graph)
nodes = (
base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
.merge(degrees, on="title", how="left")
.merge(base_communities, on="title", how="left")
)
joined = nodes.merge(base_communities, on="title", how="left")
joined["level"] = joined["level"].fillna(0).astype(int)
joined["community"] = joined["community"].fillna(-1).astype(int)
return joined.loc[
nodes["level"] = nodes["level"].fillna(0).astype(int)
nodes["community"] = nodes["community"].fillna(-1).astype(int)
# disconnected nodes and those with no community even at level 0 can be missing degree
nodes["degree"] = nodes["degree"].fillna(0).astype(int)
return nodes.loc[
:,
[
"id",

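The new fillna(0) is what the inline comment refers to: disconnected nodes (and nodes absent from base_communities) come out of the left merge with NaN, so they are backfilled and the column stays integer-typed. A minimal illustration with made-up rows:

import pandas as pd

nodes = pd.DataFrame({"title": ["A", "B", "ORPHAN"]})
degrees = pd.DataFrame({"title": ["A", "B"], "degree": [2, 1]})

merged = nodes.merge(degrees, on="title", how="left")      # ORPHAN gets NaN degree
merged["degree"] = merged["degree"].fillna(0).astype(int)  # ORPHAN -> 0, dtype stays int
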
View File

@@ -5,20 +5,25 @@
import pandas as pd
from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
)
from graphrag.index.operations.create_graph import create_graph
def create_final_relationships(
base_relationship_edges: pd.DataFrame,
base_entity_nodes: pd.DataFrame,
) -> pd.DataFrame:
"""All the steps to transform final relationships."""
relationships = base_relationship_edges
graph = create_graph(base_relationship_edges)
degrees = compute_degree(graph)
relationships["combined_degree"] = compute_edge_combined_degree(
relationships,
base_entity_nodes,
degrees,
node_name_column="title",
node_degree_column="degree",
edge_source_column="source",

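For context, compute_edge_combined_degree ranks each edge by its endpoint degrees; this change only swaps where those degrees come from (a fresh compute_degree over the rebuilt graph rather than base_entity_nodes). A rough plain-pandas equivalent, assuming combined degree means source degree plus target degree:

import pandas as pd

degrees = pd.DataFrame({"title": ["A", "B", "C"], "degree": [2, 1, 1]})
edges = pd.DataFrame({"source": ["A", "A"], "target": ["B", "C"]})

lookup = degrees.set_index("title")["degree"]
edges["combined_degree"] = edges["source"].map(lookup) + edges["target"].map(lookup)
# A-B -> 3, A-C -> 3
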
View File

@@ -6,7 +6,6 @@
from typing import Any
from uuid import uuid4
import networkx as nx
import pandas as pd
from datashaper import (
AsyncType,
@@ -14,33 +13,26 @@ from datashaper import (
)
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.storage.pipeline_storage import PipelineStorage
async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities(
entities, relationships = await extract_entities(
text_units,
callbacks,
cache,
@@ -52,87 +44,38 @@ async def extract_graph(
num_threads=extraction_num_threads,
)
if not _validate_data(entity_dfs):
if not _validate_data(entities):
error_msg = "Entity Extraction failed. No entities detected during extraction."
callbacks.error(error_msg)
raise ValueError(error_msg)
if not _validate_data(relationship_dfs):
if not _validate_data(relationships):
error_msg = (
"Entity Extraction failed. No relationships detected during extraction."
)
callbacks.error(error_msg)
raise ValueError(error_msg)
merged_entities = _merge_entities(entity_dfs)
merged_relationships = _merge_relationships(relationship_dfs)
entity_summaries, relationship_summaries = await summarize_descriptions(
merged_entities,
merged_relationships,
entities,
relationships,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)
base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
base_relationship_edges = _prep_edges(relationships, relationship_summaries)
graph = create_graph(base_relationship_edges)
base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
base_entity_nodes = _prep_nodes(entities, entity_summaries)
return (base_entity_nodes, base_relationship_edges)
def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["name", "type"], sort=False)
.agg({"description": list, "source_id": list})
.reset_index()
)
def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg({"description": list, "source_id": list, "weight": "sum"})
.reset_index()
)
def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
degrees_df = _compute_degree(graph)
def _prep_nodes(entities, summaries) -> pd.DataFrame:
entities.drop(columns=["description"], inplace=True)
nodes = (
entities.merge(summaries, on="name", how="left")
.merge(degrees_df, on="name")
.drop_duplicates(subset="name")
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
subset="title"
)
nodes = nodes.loc[nodes["title"].notna()].reset_index()
nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
relationships.drop(columns=["description"])
.drop_duplicates(subset=["source", "target"])
.merge(summaries, on=["source", "target"], how="left")
.rename(columns={"source_id": "text_unit_ids"})
)
edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
return edges
def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
{"name": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])
def _validate_data(df_list: list[pd.DataFrame]) -> bool:
"""Validate that the dataframe list is valid. At least one dataframe must contain data."""
return any(
len(df) > 0 for df in df_list
) # Check for len, not .empty, as the dfs have schemas in some cases
def _validate_data(df: pd.DataFrame) -> bool:
"""Validate that the dataframe has data."""
return len(df) > 0

View File

@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
strategy=text_embed_config["strategy"],
)
data = data.loc[:, ["id", "embedding"]]
if snapshot_embeddings_enabled is True:
data = data.loc[:, ["id", "embedding"]]
await snapshot(
data,
name=f"embeddings.{name}",

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph definition."""
import networkx as nx
import pandas as pd
def compute_degree(graph: nx.Graph) -> pd.DataFrame:
"""Create a new DataFrame with the degree of each node in the graph."""
return pd.DataFrame([
{"title": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])
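
Usage of the new helper is straightforward; for example, with a toy graph:

import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

g = nx.Graph()
g.add_edges_from([("A", "B"), ("A", "C")])

compute_degree(g)
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1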

View File

@@ -37,7 +37,7 @@ async def extract_entities(
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types=DEFAULT_ENTITY_TYPES,
num_threads: int = 4,
) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Extract entities from a piece of text.
@@ -138,7 +138,10 @@ async def extract_entities(
entity_dfs.append(pd.DataFrame(result[0]))
relationship_dfs.append(pd.DataFrame(result[1]))
return (entity_dfs, relationship_dfs)
entities = _merge_entities(entity_dfs)
relationships = _merge_relationships(relationship_dfs)
return (entities, relationships)
def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr
case _:
msg = f"Unknown strategy: {strategy_type}"
raise ValueError(msg)
def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["title", "type"], sort=False)
.agg(description=("description", list), text_unit_ids=("source_id", list))
.reset_index()
)
def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg(
description=("description", list),
text_unit_ids=("source_id", list),
weight=("weight", "sum"),
)
.reset_index()
)
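
Because merging now happens inside extract_entities, the per-text-unit frames are folded into one row per (title, type) for entities and per (source, target) for relationships, with descriptions and text_unit_ids collected into lists and relationship weights summed. A small worked example mirroring _merge_entities, with made-up rows:

import pandas as pd

entity_dfs = [
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["d1"], "source_id": ["t1"]}),
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["d2"], "source_id": ["t2"]}),
]

merged = (
    pd.concat(entity_dfs, ignore_index=True)
    .groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
# one row: title=A, type=PERSON, description=['d1', 'd2'], text_unit_ids=['t1', 't2']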

View File

@@ -106,7 +106,7 @@ async def run_extract_entities(
)
entities = [
({"name": item[0], **(item[1] or {})})
({"title": item[0], **(item[1] or {})})
for item in graph.nodes(data=True)
if item is not None
]

View File

@@ -58,7 +58,7 @@ async def run( # noqa RUF029 async is required for interface
return EntityExtractionResult(
entities=[
{"type": entity_type, "name": name}
{"type": entity_type, "title": name}
for name, entity_type in entity_map.items()
],
relationships=[],

View File

@@ -88,7 +88,7 @@ async def summarize_descriptions(
node_futures = [
do_summarize_descriptions(
str(row[1]["name"]),
str(row[1]["title"]),
sorted(set(row[1]["description"])),
ticker,
semaphore,
@@ -100,7 +100,7 @@
node_descriptions = [
{
"name": result.id,
"title": result.id,
"description": result.description,
}
for result in node_results

View File

@@ -14,6 +14,7 @@ from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.compute_communities import compute_communities
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage
workflow_name = "compute_communities"
@@ -62,13 +63,19 @@ async def workflow(
"""All the steps to create the base entity graph."""
base_relationship_edges = await runtime_storage.get("base_relationship_edges")
base_communities = await compute_communities(
base_communities = compute_communities(
base_relationship_edges,
storage,
clustering_strategy=clustering_strategy,
snapshot_transient_enabled=snapshot_transient_enabled,
)
await runtime_storage.set("base_communities", base_communities)
if snapshot_transient_enabled:
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)
return create_verb_result(cast("Table", pd.DataFrame()))

View File

@@ -19,6 +19,7 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkf
from graphrag.index.flows.create_base_text_units import (
create_base_text_units,
)
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage
workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@ async def workflow(
"""All the steps to transform base text_units."""
source = cast("pd.DataFrame", input.get_input())
output = await create_base_text_units(
output = create_base_text_units(
source,
callbacks,
storage,
chunk_by_columns,
chunk_strategy=chunk_strategy,
snapshot_transient_enabled=snapshot_transient_enabled,
)
await runtime_storage.set("base_text_units", output)
if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_text_units",
storage=storage,
formats=["parquet"],
)
return create_verb_result(
cast(
"Table",

View File

@@ -53,8 +53,7 @@ async def workflow(
) -> VerbResult:
"""All the steps to transform final relationships."""
base_relationship_edges = await runtime_storage.get("base_relationship_edges")
base_entity_nodes = await runtime_storage.get("base_entity_nodes")
output = create_final_relationships(base_relationship_edges, base_entity_nodes)
output = create_final_relationships(base_relationship_edges)
return create_verb_result(cast("Table", output))

View File

@@ -19,6 +19,9 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkf
from graphrag.index.flows.extract_graph import (
extract_graph,
)
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.storage.pipeline_storage import PipelineStorage
workflow_name = "extract_graph"
@@ -90,18 +93,38 @@ async def workflow(
text_units,
callbacks,
cache,
storage,
extraction_strategy=extraction_strategy,
extraction_num_threads=extraction_num_threads,
extraction_async_mode=extraction_async_mode,
entity_types=entity_types,
summarization_strategy=summarization_strategy,
summarization_num_threads=summarization_num_threads,
snapshot_graphml_enabled=snapshot_graphml_enabled,
snapshot_transient_enabled=snapshot_transient_enabled,
)
await runtime_storage.set("base_entity_nodes", base_entity_nodes)
await runtime_storage.set("base_relationship_edges", base_relationship_edges)
if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
graph = create_graph(base_relationship_edges)
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
return create_verb_result(cast("Table", pd.DataFrame()))

View File

@@ -58,8 +58,8 @@
],
"nan_allowed_columns": [
"description",
"community",
"level"
"x",
"y"
],
"subworkflows": 1,
"max_runtime": 150,

View File

@@ -76,8 +76,8 @@
],
"nan_allowed_columns": [
"description",
"community",
"level"
"x",
"y"
],
"subworkflows": 1,
"max_runtime": 150,

View File

@@ -42,7 +42,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
# self.assertItemsEqual isn't available yet, or I am just silly
# so we sort the lists and compare them
assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
entity["name"] for entity in results.entities
entity["title"] for entity in results.entities
])
async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
# self.assertItemsEqual isn't available yet, or I am just silly
# so we sort the lists and compare them
assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
entity["name"] for entity in results.entities
entity["title"] for entity in results.entities
])
async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):

View File

@@ -4,7 +4,6 @@
from graphrag.index.flows.compute_communities import (
compute_communities,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.compute_communities import (
workflow_name,
)
@@ -16,37 +15,15 @@ from .util import (
)
async def test_compute_communities():
def test_compute_communities():
edges = load_test_table("base_relationship_edges")
expected = load_test_table("base_communities")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"]
actual = await compute_communities(
edges, storage=context.storage, clustering_strategy=clustering_strategy
)
actual = compute_communities(edges, clustering_strategy=clustering_strategy)
columns = list(expected.columns.values)
compare_outputs(actual, expected, columns)
assert len(actual.columns) == len(expected.columns)
async def test_compute_communities_with_snapshots():
edges = load_test_table("base_relationship_edges")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"]
await compute_communities(
edges,
storage=context.storage,
clustering_strategy=clustering_strategy,
snapshot_transient_enabled=True,
)
assert context.storage.keys() == [
"base_communities.parquet",
], "Community snapshot keys differ"

View File

@@ -16,10 +16,9 @@ from .util import (
def test_create_final_relationships():
edges = load_test_table("base_relationship_edges")
nodes = load_test_table("base_entity_nodes")
expected = load_test_table(workflow_name)
actual = create_final_relationships(edges, nodes)
actual = create_final_relationships(edges)
assert "id" in expected.columns
columns = list(expected.columns.values)