mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-01 09:30:31 +00:00
Collapse create summarized entities (#1237)
* Collapse entity summarize * Semver
This commit is contained in:
parent
5220bb7ecc
commit
630679f8e3
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Collapse entity summarize."
|
||||
}
|
||||
@ -53,6 +53,29 @@ async def summarize_descriptions(
|
||||
strategy: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> TableContainer:
|
||||
"""Summarize entity and relationship descriptions from an entity graph."""
|
||||
source = cast(pd.DataFrame, input.get_input())
|
||||
output = await summarize_descriptions_df(
|
||||
source,
|
||||
cache,
|
||||
callbacks,
|
||||
column=column,
|
||||
to=to,
|
||||
strategy=strategy,
|
||||
**kwargs,
|
||||
)
|
||||
return TableContainer(table=output)
|
||||
|
||||
|
||||
async def summarize_descriptions_df(
|
||||
input: pd.DataFrame,
|
||||
cache: PipelineCache,
|
||||
callbacks: VerbCallbacks,
|
||||
column: str,
|
||||
to: str,
|
||||
strategy: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Summarize entity and relationship descriptions from an entity graph.
|
||||
|
||||
@ -111,7 +134,6 @@ async def summarize_descriptions(
|
||||
```
|
||||
"""
|
||||
log.debug("summarize_descriptions strategy=%s", strategy)
|
||||
output = cast(pd.DataFrame, input.get_input())
|
||||
strategy = strategy or {}
|
||||
strategy_exec = load_strategy(
|
||||
strategy.get("type", SummarizeStrategyType.graph_intelligence)
|
||||
@ -181,7 +203,7 @@ async def summarize_descriptions(
|
||||
semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4))
|
||||
|
||||
results = [
|
||||
await get_resolved_entities(row, semaphore) for row in output.itertuples()
|
||||
await get_resolved_entities(row, semaphore) for row in input.itertuples()
|
||||
]
|
||||
|
||||
to_result = []
|
||||
@ -191,8 +213,8 @@ async def summarize_descriptions(
|
||||
to_result.append(result.graph)
|
||||
else:
|
||||
to_result.append(None)
|
||||
output[to] = to_result
|
||||
return TableContainer(table=output)
|
||||
input[to] = to_result
|
||||
return input
|
||||
|
||||
|
||||
def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy:
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
|
||||
"""A module containing build_steps method definition."""
|
||||
|
||||
from datashaper import AsyncType
|
||||
|
||||
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
|
||||
|
||||
workflow_name = "create_summarized_entities"
|
||||
@ -20,28 +18,19 @@ def build_steps(
|
||||
* `workflow:create_base_text_units`
|
||||
"""
|
||||
summarize_descriptions_config = config.get("summarize_descriptions", {})
|
||||
strategy = summarize_descriptions_config.get("strategy", {})
|
||||
num_threads = strategy.get("num_threads", 4)
|
||||
|
||||
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
|
||||
|
||||
return [
|
||||
{
|
||||
"verb": "summarize_descriptions",
|
||||
"verb": "create_summarized_entities",
|
||||
"args": {
|
||||
**summarize_descriptions_config,
|
||||
"column": "entity_graph",
|
||||
"to": "entity_graph",
|
||||
"async_mode": summarize_descriptions_config.get(
|
||||
"async_mode", AsyncType.AsyncIO
|
||||
),
|
||||
"strategy": strategy,
|
||||
"num_threads": num_threads,
|
||||
"graphml_snapshot_enabled": graphml_snapshot_enabled,
|
||||
},
|
||||
"input": {"source": "workflow:create_base_extracted_entities"},
|
||||
},
|
||||
{
|
||||
"verb": "snapshot_rows",
|
||||
"enabled": graphml_snapshot_enabled,
|
||||
"args": {
|
||||
"base_name": "summarized_graph",
|
||||
"column": "entity_graph",
|
||||
"formats": [{"format": "text", "extension": "graphml"}],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@ -16,6 +16,7 @@ from .create_final_relationships import (
|
||||
create_final_relationships,
|
||||
)
|
||||
from .create_final_text_units import create_final_text_units
|
||||
from .create_summarized_entities import create_summarized_entities
|
||||
|
||||
__all__ = [
|
||||
"create_base_documents",
|
||||
@ -29,4 +30,5 @@ __all__ = [
|
||||
"create_final_nodes",
|
||||
"create_final_relationships",
|
||||
"create_final_text_units",
|
||||
"create_summarized_entities",
|
||||
]
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final documents."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.storage import PipelineStorage
|
||||
from graphrag.index.verbs.entities.summarize.description_summarize import (
|
||||
summarize_descriptions_df,
|
||||
)
|
||||
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
|
||||
|
||||
|
||||
@verb(
|
||||
name="create_summarized_entities",
|
||||
treats_input_tables_as_immutable=True,
|
||||
)
|
||||
async def create_summarized_entities(
|
||||
input: VerbInput,
|
||||
cache: PipelineCache,
|
||||
callbacks: VerbCallbacks,
|
||||
storage: PipelineStorage,
|
||||
strategy: dict[str, Any] | None = None,
|
||||
num_threads: int = 4,
|
||||
graphml_snapshot_enabled: bool = False,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final documents."""
|
||||
source = cast(pd.DataFrame, input.get_input())
|
||||
|
||||
summarized = await summarize_descriptions_df(
|
||||
source,
|
||||
cache,
|
||||
callbacks,
|
||||
column="entity_graph",
|
||||
to="entity_graph",
|
||||
strategy=strategy,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
if graphml_snapshot_enabled:
|
||||
await snapshot_rows_df(
|
||||
summarized,
|
||||
column="entity_graph",
|
||||
base_name="summarized_graph",
|
||||
storage=storage,
|
||||
formats=[{"format": "text", "extension": "graphml"}],
|
||||
)
|
||||
|
||||
return create_verb_result(cast(Table, summarized))
|
||||
92
tests/verbs/test_create_summarized_entities.py
Normal file
92
tests/verbs/test_create_summarized_entities.py
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||
from graphrag.index.workflows.v1.create_summarized_entities import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
)
|
||||
|
||||
from .util import (
|
||||
get_config_for_workflow,
|
||||
get_workflow_output,
|
||||
load_expected,
|
||||
load_input_tables,
|
||||
remove_disabled_steps,
|
||||
)
|
||||
|
||||
|
||||
async def test_create_summarized_entities():
|
||||
input_tables = load_input_tables([
|
||||
"workflow:create_base_extracted_entities",
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
del config["summarize_descriptions"]["strategy"]["llm"]
|
||||
|
||||
steps = remove_disabled_steps(build_steps(config))
|
||||
|
||||
actual = await get_workflow_output(
|
||||
input_tables,
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
)
|
||||
|
||||
# the serialization of the graph may differ so we can't assert the dataframes directly
|
||||
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
|
||||
|
||||
# let's parse a sample of the raw graphml
|
||||
actual_graphml_0 = actual["entity_graph"][:1][0]
|
||||
actual_graph_0 = nx.parse_graphml(actual_graphml_0)
|
||||
|
||||
expected_graphml_0 = expected["entity_graph"][:1][0]
|
||||
expected_graph_0 = nx.parse_graphml(expected_graphml_0)
|
||||
|
||||
assert (
|
||||
actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
|
||||
), "Graphml node count differs"
|
||||
assert (
|
||||
actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
|
||||
), "Graphml edge count differs"
|
||||
|
||||
# ensure the mock summary was injected to the nodes
|
||||
nodes = list(actual_graph_0.nodes(data=True))
|
||||
assert (
|
||||
nodes[0][1]["description"]
|
||||
== "This is a MOCK response for the LLM. It is summarized!"
|
||||
)
|
||||
|
||||
|
||||
async def test_create_summarized_entities_with_snapshots():
|
||||
input_tables = load_input_tables([
|
||||
"workflow:create_base_extracted_entities",
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
storage = MemoryPipelineStorage()
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
del config["summarize_descriptions"]["strategy"]["llm"]
|
||||
config["graphml_snapshot"] = True
|
||||
|
||||
steps = remove_disabled_steps(build_steps(config))
|
||||
|
||||
actual = await get_workflow_output(
|
||||
input_tables,
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
storage=storage,
|
||||
)
|
||||
|
||||
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
|
||||
|
||||
assert storage.keys() == [
|
||||
"summarized_graph.graphml",
|
||||
], "Graph snapshot keys differ"
|
||||
Loading…
x
Reference in New Issue
Block a user