Collapse create summarized entities (#1237)

* Collapse entity summarize

* Semver
This commit is contained in:
Nathan Evans 2024-09-30 17:17:44 -07:00 committed by GitHub
parent 5220bb7ecc
commit 630679f8e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 192 additions and 22 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse entity summarize."
}

View File

@ -53,6 +53,29 @@ async def summarize_descriptions(
strategy: dict[str, Any] | None = None,
**kwargs,
) -> TableContainer:
"""Summarize entity and relationship descriptions from an entity graph."""
source = cast(pd.DataFrame, input.get_input())
output = await summarize_descriptions_df(
source,
cache,
callbacks,
column=column,
to=to,
strategy=strategy,
**kwargs,
)
return TableContainer(table=output)
async def summarize_descriptions_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
to: str,
strategy: dict[str, Any] | None = None,
**kwargs,
) -> pd.DataFrame:
"""
Summarize entity and relationship descriptions from an entity graph.
@ -111,7 +134,6 @@ async def summarize_descriptions(
```
"""
log.debug("summarize_descriptions strategy=%s", strategy)
output = cast(pd.DataFrame, input.get_input())
strategy = strategy or {}
strategy_exec = load_strategy(
strategy.get("type", SummarizeStrategyType.graph_intelligence)
@ -181,7 +203,7 @@ async def summarize_descriptions(
semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4))
results = [
await get_resolved_entities(row, semaphore) for row in output.itertuples()
await get_resolved_entities(row, semaphore) for row in input.itertuples()
]
to_result = []
@ -191,8 +213,8 @@ async def summarize_descriptions(
to_result.append(result.graph)
else:
to_result.append(None)
output[to] = to_result
return TableContainer(table=output)
input[to] = to_result
return input
def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy:

View File

@ -3,8 +3,6 @@
"""A module containing build_steps method definition."""
from datashaper import AsyncType
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_summarized_entities"
@ -20,28 +18,19 @@ def build_steps(
* `workflow:create_base_text_units`
"""
summarize_descriptions_config = config.get("summarize_descriptions", {})
strategy = summarize_descriptions_config.get("strategy", {})
num_threads = strategy.get("num_threads", 4)
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
return [
{
"verb": "summarize_descriptions",
"verb": "create_summarized_entities",
"args": {
**summarize_descriptions_config,
"column": "entity_graph",
"to": "entity_graph",
"async_mode": summarize_descriptions_config.get(
"async_mode", AsyncType.AsyncIO
),
"strategy": strategy,
"num_threads": num_threads,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": {"source": "workflow:create_base_extracted_entities"},
},
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "summarized_graph",
"column": "entity_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
]

View File

@ -16,6 +16,7 @@ from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .create_summarized_entities import create_summarized_entities
__all__ = [
"create_base_documents",
@ -29,4 +30,5 @@ __all__ = [
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",
"create_summarized_entities",
]

View File

@ -0,0 +1,61 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to create the summarized entities."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.summarize.description_summarize import (
summarize_descriptions_df,
)
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
@verb(
    name="create_summarized_entities",
    treats_input_tables_as_immutable=True,
)
async def create_summarized_entities(
    input: VerbInput,
    cache: PipelineCache,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    strategy: dict[str, Any] | None = None,
    num_threads: int = 4,
    graphml_snapshot_enabled: bool = False,
    **_kwargs: Any,
) -> VerbResult:
    """Summarize the descriptions in the entity graph and optionally snapshot it.

    Args:
        input: Verb input whose table carries the ``entity_graph`` column.
        cache: Pipeline cache, forwarded to the summarization step.
        callbacks: Verb callbacks, forwarded to the summarization step.
        storage: Pipeline storage that receives the graphml snapshot.
        strategy: Summarization strategy settings, forwarded unchanged.
        num_threads: Forwarded to the summarization step as ``num_threads``.
        graphml_snapshot_enabled: When true, the summarized ``entity_graph``
            column is snapshotted to ``storage`` under the base name
            ``summarized_graph`` with a ``graphml`` extension.
        **_kwargs: Any additional step arguments; accepted and ignored.

    Returns:
        A verb result wrapping the summarized table.
    """
    source = cast(pd.DataFrame, input.get_input())

    # The summary is written back over the column it was read from.
    # NOTE(review): summarize_descriptions_df appears to assign the `to` column
    # on the frame it is given rather than on a copy — confirm that this is
    # compatible with treats_input_tables_as_immutable=True above.
    summarized = await summarize_descriptions_df(
        source,
        cache,
        callbacks,
        column="entity_graph",
        to="entity_graph",
        strategy=strategy,
        num_threads=num_threads,
    )

    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            summarized,
            column="entity_graph",
            base_name="summarized_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return create_verb_result(cast(Table, summarized))

View File

@ -0,0 +1,92 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import networkx as nx
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_summarized_entities import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)
async def test_create_summarized_entities():
    """Run the workflow without snapshots and spot-check the first summarized graph."""
    tables = load_input_tables(["workflow:create_base_extracted_entities"])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    # the llm settings are dropped before the steps are built; the last
    # assertion below expects the MOCK LLM response
    del workflow_config["summarize_descriptions"]["strategy"]["llm"]

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    actual = await get_workflow_output(tables, {"steps": enabled_steps})

    # graph serialization may differ, so the dataframes are not compared directly
    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    # parse the first row's raw graphml from each side
    actual_first_graphml = actual["entity_graph"][:1][0]
    actual_graph = nx.parse_graphml(actual_first_graphml)
    expected_first_graphml = expected["entity_graph"][:1][0]
    expected_graph = nx.parse_graphml(expected_first_graphml)

    actual_node_count = actual_graph.number_of_nodes()
    expected_node_count = expected_graph.number_of_nodes()
    assert actual_node_count == expected_node_count, "Graphml node count differs"

    actual_edge_count = actual_graph.number_of_edges()
    expected_edge_count = expected_graph.number_of_edges()
    assert actual_edge_count == expected_edge_count, "Graphml edge count differs"

    # the mock summary should have been written onto the nodes
    first_node = list(actual_graph.nodes(data=True))[0]
    first_node_attributes = first_node[1]
    assert (
        first_node_attributes["description"]
        == "This is a MOCK response for the LLM. It is summarized!"
    )
async def test_create_summarized_entities_with_snapshots():
    """Run the workflow with graphml snapshots on and check what lands in storage."""
    tables = load_input_tables(["workflow:create_base_extracted_entities"])
    expected = load_expected(workflow_name)

    # in-memory storage so the written snapshot keys can be inspected afterwards
    storage = MemoryPipelineStorage()

    workflow_config = get_config_for_workflow(workflow_name)
    del workflow_config["summarize_descriptions"]["strategy"]["llm"]
    workflow_config["graphml_snapshot"] = True

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    workflow = {"steps": enabled_steps}
    actual = await get_workflow_output(tables, workflow, storage=storage)

    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    expected_keys = ["summarized_graph.graphml"]
    assert storage.keys() == expected_keys, "Graph snapshot keys differ"