Collapse graph documents workflows (#1284)

* Copy base documents logic into final documents

* Delete create_base_documents

* Combine graph creation under create_base_entity_graph

* Delete collapsed workflows

* Migrate most graph internals to nx.Graph

* Fix None edge case

* Semver

* Remove comment typo

* Fix smoke tests
Nathan Evans 2024-10-15 12:58:58 -07:00 committed by GitHub
parent 137a5cd550
commit ce5b1207e0
41 changed files with 446 additions and 1098 deletions
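
For orientation, the net effect of this change is that entity extraction, graph merging, description summarization, clustering, and optional graph embedding now run inside a single create_base_entity_graph workflow instead of three chained ones, and the intermediate results are passed around as nx.Graph objects rather than graphml columns. A hedged sketch of that collapsed sequence, assembled from the flow in the diff below (the function name is hypothetical, the "chunk"/"chunk_id" defaults and to="clustered_graph" come from the build_steps defaults further down, and snapshotting plus the embedding call are omitted):

```python
from typing import Any

import pandas as pd
from datashaper import VerbCallbacks

from graphrag.index.cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.summarize_descriptions import summarize_descriptions


async def collapsed_entity_graph_sketch(
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    clustering_strategy: dict[str, Any],
    extraction_strategy: dict[str, Any] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
    # 1. Extract entities per text unit; each unit now yields its own nx.Graph.
    #    (The `entities` frame feeds the raw-entity snapshot in the real flow.)
    entities, entity_graphs = await extract_entities(
        text_units,
        callbacks,
        cache,
        text_column="chunk",
        id_column="chunk_id",
        to="entities",
        strategy=extraction_strategy,
    )
    # 2. Merge the per-unit graphs into one graph; None falls back to the
    #    default concat/sum node and edge operations.
    merged = merge_graphs(entity_graphs, callbacks, node_operations=None, edge_operations=None)
    # 3. Summarize node and edge descriptions directly on the merged graph.
    summarized = await summarize_descriptions(
        merged, callbacks, cache, strategy=summarization_strategy
    )
    # 4. Cluster with hierarchical Leiden, producing one graphml row per community level.
    clustered = cluster_graph(
        summarized,
        callbacks,
        column="entity_graph",
        strategy=clustering_strategy,
        to="clustered_graph",
        level_to="level",
    )
    # An optional embed_graph step then adds an "embeddings" column when an
    # embedding strategy is configured (see the full flow in the diff).
    return clustered
```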

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse intermediate workflow outputs."
}

View File

@ -49,9 +49,7 @@ from graphrag.index.config.workflow import (
PipelineWorkflowReference,
)
from graphrag.index.workflows.default_workflows import (
create_base_documents,
create_base_entity_graph,
create_base_extracted_entities,
create_base_text_units,
create_final_communities,
create_final_community_reports,
@ -61,7 +59,6 @@ from graphrag.index.workflows.default_workflows import (
create_final_nodes,
create_final_relationships,
create_final_text_units,
create_summarized_entities,
)
log = logging.getLogger(__name__)
@ -173,17 +170,12 @@ def _document_workflows(
)
return [
PipelineWorkflowReference(
name=create_base_documents,
name=create_final_documents,
config={
"document_attribute_columns": list(
{*(settings.input.document_attribute_columns)}
- builtin_document_attributes
)
},
),
PipelineWorkflowReference(
name=create_final_documents,
config={
),
"document_raw_content_embed": _get_embedding_settings(
settings.embeddings,
"document_raw_content",
@ -267,10 +259,9 @@ def _graph_workflows(
)
return [
PipelineWorkflowReference(
name=create_base_extracted_entities,
name=create_base_entity_graph,
config={
"graphml_snapshot": settings.snapshots.graphml,
"raw_entity_snapshot": settings.snapshots.raw_entities,
"entity_extract": {
**settings.entity_extraction.parallelization.model_dump(),
"async_mode": settings.entity_extraction.async_mode,
@ -279,12 +270,6 @@ def _graph_workflows(
),
"entity_types": settings.entity_extraction.entity_types,
},
},
),
PipelineWorkflowReference(
name=create_summarized_entities,
config={
"graphml_snapshot": settings.snapshots.graphml,
"summarize_descriptions": {
**settings.summarize_descriptions.parallelization.model_dump(),
"async_mode": settings.summarize_descriptions.async_mode,
@ -292,12 +277,6 @@ def _graph_workflows(
settings.root_dir,
),
},
},
),
PipelineWorkflowReference(
name=create_base_entity_graph,
config={
"graphml_snapshot": settings.snapshots.graphml,
"embed_graph_enabled": settings.embed_graph.enabled,
"cluster_graph": {
"strategy": settings.cluster_graph.resolved_strategy()

View File

@ -1,64 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Transform base documents by joining them with their text_units and adding optional attributes."""
import pandas as pd
def create_base_documents(
documents: pd.DataFrame,
text_units: pd.DataFrame,
document_attribute_columns: list[str] | None = None,
) -> pd.DataFrame:
"""Transform base documents by joining them with their text_units and adding optional attributes."""
exploded = (
text_units.explode("document_ids")
.loc[:, ["id", "document_ids", "text"]]
.rename(
columns={
"document_ids": "chunk_doc_id",
"id": "chunk_id",
"text": "chunk_text",
}
)
)
joined = exploded.merge(
documents,
left_on="chunk_doc_id",
right_on="id",
how="inner",
copy=False,
)
docs_with_text_units = joined.groupby("id", sort=False).agg(
text_units=("chunk_id", list)
)
rejoined = docs_with_text_units.merge(
documents,
on="id",
how="right",
copy=False,
).reset_index(drop=True)
rejoined.rename(columns={"text": "raw_content"}, inplace=True)
rejoined["id"] = rejoined["id"].astype(str)
# Convert attribute columns to strings and collapse them into a JSON object
if document_attribute_columns:
# Convert all specified columns to string at once
rejoined[document_attribute_columns] = rejoined[
document_attribute_columns
].astype(str)
# Collapse the document_attribute_columns into a single JSON object column
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
orient="records"
)
# Drop the original attribute columns after collapsing them
rejoined.drop(columns=document_attribute_columns, inplace=True)
return rejoined

View File

@ -7,26 +7,76 @@ from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.storage import PipelineStorage
async def create_base_entity_graph(
entities: pd.DataFrame,
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_column: str,
id_column: str,
clustering_strategy: dict[str, Any],
embedding_strategy: dict[str, Any] | None,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
node_merge_config: dict[str, Any] | None = None,
edge_merge_config: dict[str, Any] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
embedding_strategy: dict[str, Any] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entities, entity_graphs = await extract_entities(
text_units,
callbacks,
cache,
text_column=text_column,
id_column=id_column,
strategy=extraction_strategy,
async_mode=extraction_async_mode,
entity_types=entity_types,
to="entities",
num_threads=extraction_num_threads,
)
merged_graph = merge_graphs(
entity_graphs,
callbacks,
node_operations=node_merge_config,
edge_operations=edge_merge_config,
)
summarized = await summarize_descriptions(
merged_graph,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)
clustered = cluster_graph(
entities,
summarized,
callbacks,
column="entity_graph",
strategy=clustering_strategy,
@ -34,15 +84,6 @@ async def create_base_entity_graph(
level_to="level",
)
if graphml_snapshot_enabled:
await snapshot_rows(
clustered,
column="clustered_graph",
base_name="clustered_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
if embedding_strategy:
clustered["embeddings"] = await embed_graph(
clustered,
@ -51,16 +92,40 @@ async def create_base_entity_graph(
strategy=embedding_strategy,
)
# take second snapshot after embedding
# todo: this could be skipped if embedding isn't performed, other wise it is a copy of the regular graph?
if raw_entity_snapshot_enabled:
await snapshot(
entities,
name="raw_extracted_entities",
storage=storage,
formats=["json"],
)
if graphml_snapshot_enabled:
await snapshot_graphml(
merged_graph,
name="merged_graph",
storage=storage,
)
await snapshot_graphml(
summarized,
name="summarized_graph",
storage=storage,
)
await snapshot_rows(
clustered,
column="entity_graph",
base_name="embedded_graph",
column="clustered_graph",
base_name="clustered_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
if embedding_strategy:
await snapshot_rows(
clustered,
column="entity_graph",
base_name="embedded_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
final_columns = ["level", "clustered_graph"]
if embedding_strategy:

View File

@ -1,79 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to extract and format covariates."""
from typing import Any
import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.storage import PipelineStorage
async def create_base_extracted_entities(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
column: str,
id_column: str,
nodes: dict[str, Any],
edges: dict[str, Any],
extraction_strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
num_threads: int = 4,
) -> pd.DataFrame:
"""All the steps to extract and format covariates."""
entity_graph = await extract_entities(
text_units,
callbacks,
cache,
column=column,
id_column=id_column,
strategy=extraction_strategy,
async_mode=async_mode,
entity_types=entity_types,
to="entities",
graph_to="entity_graph",
num_threads=num_threads,
)
if raw_entity_snapshot_enabled:
await snapshot(
entity_graph,
name="raw_extracted_entities",
storage=storage,
formats=["json"],
)
merged_graph = merge_graphs(
entity_graph,
callbacks,
column="entity_graph",
to="entity_graph",
nodes=nodes,
edges=edges,
)
if graphml_snapshot_enabled:
await snapshot_rows(
merged_graph,
base_name="merged_graph",
column="entity_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
return merged_graph

View File

@ -14,20 +14,71 @@ from graphrag.index.operations.embed_text import embed_text
async def create_final_documents(
documents: pd.DataFrame,
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
document_attribute_columns: list[str] | None = None,
raw_content_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final documents."""
documents.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
exploded = (
text_units.explode("document_ids")
.loc[:, ["id", "document_ids", "text"]]
.rename(
columns={
"document_ids": "chunk_doc_id",
"id": "chunk_id",
"text": "chunk_text",
}
)
)
joined = exploded.merge(
documents,
left_on="chunk_doc_id",
right_on="id",
how="inner",
copy=False,
)
docs_with_text_units = joined.groupby("id", sort=False).agg(
text_units=("chunk_id", list)
)
rejoined = docs_with_text_units.merge(
documents,
on="id",
how="right",
copy=False,
).reset_index(drop=True)
rejoined.rename(
columns={"text": "raw_content", "text_units": "text_unit_ids"}, inplace=True
)
rejoined["id"] = rejoined["id"].astype(str)
# Convert attribute columns to strings and collapse them into a JSON object
if document_attribute_columns:
# Convert all specified columns to string at once
rejoined[document_attribute_columns] = rejoined[
document_attribute_columns
].astype(str)
# Collapse the document_attribute_columns into a single JSON object column
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
orient="records"
)
# Drop the original attribute columns after collapsing them
rejoined.drop(columns=document_attribute_columns, inplace=True)
if raw_content_text_embed:
documents["raw_content_embedding"] = await embed_text(
documents,
rejoined["raw_content_embedding"] = await embed_text(
rejoined,
callbacks,
cache,
column="raw_content",
strategy=raw_content_text_embed["strategy"],
)
return documents
return rejoined
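
As in the old create_base_documents flow, any configured document_attribute_columns are stringified and folded into a single per-row attributes dict. A standalone pandas sketch of just that step (the data and the "title" column are illustrative):

```python
import pandas as pd

docs = pd.DataFrame({
    "id": ["d1", "d2"],
    "raw_content": ["first doc", "second doc"],
    "title": ["Doc One", "Doc Two"],  # example attribute column
})
document_attribute_columns = ["title"]

# Stringify the attribute columns, fold them into one dict per row,
# then drop the originals, mirroring the block above.
docs[document_attribute_columns] = docs[document_attribute_columns].astype(str)
docs["attributes"] = docs[document_attribute_columns].to_dict(orient="records")
docs = docs.drop(columns=document_attribute_columns)

# docs["attributes"] now holds e.g. {"title": "Doc One"} for the first row.
```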

View File

@ -1,50 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to summarize entities."""
from typing import Any
import pandas as pd
from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.storage import PipelineStorage
async def create_summarized_entities(
entities: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
summarization_strategy: dict[str, Any] | None = None,
num_threads: int = 4,
graphml_snapshot_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to summarize entities."""
summarized = await summarize_descriptions(
entities,
callbacks,
cache,
column="entity_graph",
to="entity_graph",
strategy=summarization_strategy,
num_threads=num_threads,
)
if graphml_snapshot_enabled:
await snapshot_rows(
summarized,
column="entity_graph",
base_name="summarized_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
return summarized

View File

@ -14,7 +14,7 @@ from datashaper import VerbCallbacks, progress_iterable
from graspologic.partition import hierarchical_leiden
from graphrag.index.graph.utils import stable_largest_connected_component
from graphrag.index.utils import gen_uuid, load_graph
from graphrag.index.utils import gen_uuid
Communities = list[tuple[int, str, list[str]]]
@ -33,7 +33,7 @@ log = logging.getLogger(__name__)
def cluster_graph(
input: pd.DataFrame,
input: nx.Graph,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
column: str,
@ -41,63 +41,64 @@ def cluster_graph(
level_to: str | None = None,
) -> pd.DataFrame:
"""Apply a hierarchical clustering algorithm to a graph."""
results = input[column].apply(lambda graph: run_layout(strategy, graph))
output = pd.DataFrame()
# TODO: for back-compat, downstream expects a graphml string
output[column] = ["\n".join(nx.generate_graphml(input))]
communities = run_layout(strategy, input)
community_map_to = "communities"
input[community_map_to] = results
output[community_map_to] = [communities]
level_to = level_to or f"{to}_level"
input[level_to] = input.apply(
output[level_to] = output.apply(
lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
)
input[to] = None
output[to] = None
num_total = len(input)
num_total = len(output)
# Create a seed for this run (if not provided)
seed = strategy.get("seed", Random().randint(0, 0xFFFFFFFF)) # noqa S311
# Go through each of the rows
graph_level_pairs_column: list[list[tuple[int, str]]] = []
for _, row in progress_iterable(input.iterrows(), callbacks.progress, num_total):
for _, row in progress_iterable(output.iterrows(), callbacks.progress, num_total):
levels = row[level_to]
graph_level_pairs: list[tuple[int, str]] = []
# For each of the levels, get the graph and add it to the list
for level in levels:
graph = "\n".join(
graphml = "\n".join(
nx.generate_graphml(
apply_clustering(
cast(str, row[column]),
input,
cast(Communities, row[community_map_to]),
level,
seed=seed,
)
)
)
graph_level_pairs.append((level, graph))
graph_level_pairs.append((level, graphml))
graph_level_pairs_column.append(graph_level_pairs)
input[to] = graph_level_pairs_column
output[to] = graph_level_pairs_column
# explode the list of (level, graph) pairs into separate rows
input = input.explode(to, ignore_index=True)
output = output.explode(to, ignore_index=True)
# split the (level, graph) pairs into separate columns
# TODO: There is probably a better way to do this
input[[level_to, to]] = pd.DataFrame(input[to].tolist(), index=input.index)
output[[level_to, to]] = pd.DataFrame(output[to].tolist(), index=output.index)
# clean up the community map
input.drop(columns=[community_map_to], inplace=True)
return input
output.drop(columns=[community_map_to], inplace=True)
return output
# TODO: This should support str | nx.Graph as a graphml param
def apply_clustering(
graphml: str, communities: Communities, level: int = 0, seed: int | None = None
graph: nx.Graph, communities: Communities, level: int = 0, seed: int | None = None
) -> nx.Graph:
"""Apply clustering to a graphml string."""
"""Apply clustering to a graph."""
random = Random(seed) # noqa S311
graph = nx.parse_graphml(graphml)
for community_level, community_id, nodes in communities:
if level == community_level:
for node in nodes:
@ -121,11 +122,8 @@ def apply_clustering(
return graph
def run_layout(
strategy: dict[str, Any], graphml_or_graph: str | nx.Graph
) -> Communities:
def run_layout(strategy: dict[str, Any], graph: nx.Graph) -> Communities:
"""Run layout method definition."""
graph = load_graph(graphml_or_graph)
if len(graph.nodes) == 0:
log.warning("Graph has no nodes")
return []
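
The Communities value returned by run_layout is a list of (level, community_id, node_ids) tuples. A small sketch of how the per-level tagging works, assuming the node attributes written by apply_clustering are "cluster" and "level" (those assignments fall outside the visible hunks, so the names are assumptions):

```python
import networkx as nx

# Illustrative hierarchical Leiden output: (level, community_id, node ids).
communities = [
    (0, "0", ["COMPANY_A", "COMPANY_B", "PERSON_C"]),
    (1, "1", ["COMPANY_A", "COMPANY_B"]),
    (1, "2", ["PERSON_C"]),
]

graph = nx.Graph()
graph.add_edges_from([("COMPANY_A", "COMPANY_B"), ("COMPANY_A", "PERSON_C")])

graphml_by_level: dict[int, str] = {}
for level in sorted({lvl for lvl, _, _ in communities}):
    tagged = graph.copy()
    for lvl, community_id, nodes in communities:
        if lvl == level:
            for node in nodes:
                tagged.nodes[node]["cluster"] = community_id  # assumed attribute name
                tagged.nodes[node]["level"] = level           # assumed attribute name
    # One graphml string per level, which cluster_graph then explodes into rows.
    graphml_by_level[level] = "\n".join(nx.generate_graphml(tagged))
```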

View File

@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing import Any
import networkx as nx
import pandas as pd
from datashaper import (
AsyncType,
@ -41,15 +42,14 @@ async def extract_entities(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
text_column: str,
id_column: str,
to: str,
strategy: dict[str, Any] | None,
graph_to: str | None = None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types=DEFAULT_ENTITY_TYPES,
num_threads: int = 4,
) -> pd.DataFrame:
) -> tuple[pd.DataFrame, list[nx.Graph]]:
"""
Extract entities from a piece of text.
@ -59,7 +59,6 @@ async def extract_entities(
column: the_document_text_column_to_extract_entities_from
id_column: the_column_with_the_unique_id_for_each_row
to: the_column_to_output_the_entities_to
graph_to: the_column_to_output_the_graphml_to
strategy: <strategy_config>, see strategies section below
summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */
entity_types:
@ -124,7 +123,7 @@ async def extract_entities(
async def run_strategy(row):
nonlocal num_started
text = row[column]
text = row[text_column]
id = row[id_column]
result = await strategy_exec(
[Document(text=text, id=id)],
@ -134,7 +133,7 @@ async def extract_entities(
strategy_config,
)
num_started += 1
return [result.entities, result.graphml_graph]
return [result.entities, result.graph]
results = await derive_from_rows(
input,
@ -145,20 +144,18 @@ async def extract_entities(
)
to_result = []
graph_to_result = []
graphs = []
for result in results:
if result:
to_result.append(result[0])
graph_to_result.append(result[1])
graphs.append(result[1])
else:
to_result.append(None)
graph_to_result.append(None)
graphs.append(None)
input[to] = to_result
if graph_to is not None:
input[graph_to] = graph_to_result
return input.reset_index(drop=True)
return (input.reset_index(drop=True), graphs)
def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:

View File

@ -3,7 +3,6 @@
"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence."""
import networkx as nx
from datashaper import VerbCallbacks
import graphrag.config.defaults as defs
@ -113,8 +112,7 @@ async def run_extract_entities(
if item is not None
]
graph_data = "".join(nx.generate_graphml(graph))
return EntityExtractionResult(entities, graph_data)
return EntityExtractionResult(entities, graph)
def _create_text_splitter(

View File

@ -57,5 +57,5 @@ async def run( # noqa RUF029 async is required for interface
{"type": entity_type, "name": name}
for name, entity_type in entity_map.items()
],
graphml_graph="".join(nx.generate_graphml(graph)),
graph=graph,
)

View File

@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
import networkx as nx
from datashaper import VerbCallbacks
from graphrag.index.cache import PipelineCache
@ -29,7 +30,7 @@ class EntityExtractionResult:
"""Entity extraction result class definition."""
entities: list[ExtractedEntity]
graphml_graph: str | None
graph: nx.Graph | None
EntityExtractStrategy = Callable[

View File

@ -3,14 +3,11 @@
"""A module containing merge_graphs, merge_nodes, merge_edges, merge_attributes, apply_merge_operation and _get_detailed_attribute_merge_operation methods definitions."""
from typing import Any, cast
from typing import Any
import networkx as nx
import pandas as pd
from datashaper import VerbCallbacks, progress_iterable
from graphrag.index.utils import load_graph
from .typing import (
BasicMergeOperation,
DetailedAttributeMergeOperation,
@ -35,15 +32,13 @@ DEFAULT_CONCAT_SEPARATOR = ","
def merge_graphs(
input: pd.DataFrame,
graphs: list[nx.Graph],
callbacks: VerbCallbacks,
column: str,
to: str,
nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
) -> pd.DataFrame:
node_operations: dict[str, Any] | None,
edge_operations: dict[str, Any] | None,
) -> nx.Graph:
"""
Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
Merge multiple graphs together. The graphs are expected to be in nx.Graph format. The operation returns the merged graph.
> Note: This will merge all rows into a single graph.
@ -51,8 +46,6 @@ def merge_graphs(
```yaml
verb: merge_graph
args:
column: clustered_graph # The name of the column containing the graph, should be a graphml graph
to: merged_graph # The name of the column to output the merged graph to
nodes: <node operations> # See node operations section below
edges: <edge operations> # See edge operations section below
```
@ -90,8 +83,8 @@ def merge_graphs(
- __average__: This operation takes the mean of the attribute with the last value seen.
- __multiply__: This operation multiplies the attribute with the last value seen.
"""
output = pd.DataFrame()
nodes = node_operations or DEFAULT_NODE_OPERATIONS
edges = edge_operations or DEFAULT_EDGE_OPERATIONS
node_ops = {
attrib: _get_detailed_attribute_merge_operation(value)
for attrib, value in nodes.items()
@ -102,15 +95,12 @@ def merge_graphs(
}
mega_graph = nx.Graph()
num_total = len(input)
for graphml in progress_iterable(input[column], callbacks.progress, num_total):
graph = load_graph(cast(str | nx.Graph, graphml))
num_total = len(graphs)
for graph in progress_iterable(graphs, callbacks.progress, num_total):
merge_nodes(mega_graph, graph, node_ops)
merge_edges(mega_graph, graph, edge_ops)
output[to] = ["\n".join(nx.generate_graphml(mega_graph))]
return output
return mega_graph
def merge_nodes(

View File

@ -0,0 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing snapshot method definition."""
import networkx as nx
from graphrag.index.storage import PipelineStorage
async def snapshot_graphml(
input: str | nx.Graph,
name: str,
storage: PipelineStorage,
) -> None:
"""Take a entire snapshot of a graph to standard graphml format."""
graphml = input if isinstance(input, str) else "\n".join(nx.generate_graphml(input))
await storage.set(name + ".graphml", graphml)
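
A minimal round-trip demo of the new helper, reusing MemoryPipelineStorage from the tests in this commit (import paths are taken from this diff; the example graph is made up):

```python
import asyncio

import networkx as nx

from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage


async def demo() -> None:
    graph = nx.Graph()
    graph.add_edge("COMPANY_A", "COMPANY_B", weight=2)
    storage = MemoryPipelineStorage()
    # Accepts either an nx.Graph or an already-serialized graphml string.
    await snapshot_graphml(graph, name="merged_graph", storage=storage)
    assert "merged_graph.graphml" in storage.keys()


asyncio.run(demo())
```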

View File

@ -5,10 +5,9 @@
import asyncio
import logging
from typing import Any, cast
from typing import Any
import networkx as nx
import pandas as pd
from datashaper import (
ProgressTicker,
VerbCallbacks,
@ -16,10 +15,8 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.utils import load_graph
from .typing import (
DescriptionSummarizeRow,
SummarizationStrategy,
SummarizeStrategyType,
)
@ -28,14 +25,12 @@ log = logging.getLogger(__name__)
async def summarize_descriptions(
input: pd.DataFrame,
input: nx.Graph,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
to: str,
strategy: dict[str, Any] | None = None,
**kwargs,
) -> pd.DataFrame:
num_threads: int = 4,
) -> nx.Graph:
"""
Summarize entity and relationship descriptions from an entity graph.
@ -43,25 +38,10 @@ async def summarize_descriptions(
To turn this feature ON please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`.
### json
```json
{
"verb": "",
"args": {
"column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */
"to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */
"strategy": {...} <strategy_config>, see strategies section below
}
}
```
### yaml
```yaml
args:
column: the_document_text_column_to_extract_descriptions_from
to: the_column_to_output_the_summarized_descriptions_to
strategy: <strategy_config>, see strategies section below
```
@ -99,9 +79,7 @@ async def summarize_descriptions(
)
strategy_config = {**strategy}
async def get_resolved_entities(row, semaphore: asyncio.Semaphore):
graph: nx.Graph = load_graph(cast(str | nx.Graph, getattr(row, column)))
async def get_resolved_entities(graph: nx.Graph, semaphore: asyncio.Semaphore):
ticker_length = len(graph.nodes) + len(graph.edges)
ticker = progress_ticker(callbacks.progress, ticker_length)
@ -134,9 +112,7 @@ async def summarize_descriptions(
elif isinstance(graph_item, tuple) and graph_item in graph.edges():
graph.edges[graph_item]["description"] = result.description
return DescriptionSummarizeRow(
graph="\n".join(nx.generate_graphml(graph)),
)
return graph
async def do_summarize_descriptions(
graph_item: str | tuple[str, str],
@ -155,25 +131,9 @@ async def summarize_descriptions(
ticker(1)
return results
# Graph is always on row 0, so here a derive from rows does not work
# This iteration will only happen once, but avoids hardcoding a iloc[0]
# Since parallelization is at graph level (nodes and edges), we can't use
# the parallelization of the derive_from_rows
semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4))
semaphore = asyncio.Semaphore(num_threads)
results = [
await get_resolved_entities(row, semaphore) for row in input.itertuples()
]
to_result = []
for result in results:
if result:
to_result.append(result.graph)
else:
to_result.append(None)
input[to] = to_result
return input
return await get_resolved_entities(input, semaphore)
def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy:

View File

@ -7,24 +7,12 @@
from .v1.subflows import * # noqa
from .typing import WorkflowDefinitions
from .v1.create_base_documents import (
build_steps as build_create_base_documents_steps,
)
from .v1.create_base_documents import (
workflow_name as create_base_documents,
)
from .v1.create_base_entity_graph import (
build_steps as build_create_base_entity_graph_steps,
)
from .v1.create_base_entity_graph import (
workflow_name as create_base_entity_graph,
)
from .v1.create_base_extracted_entities import (
build_steps as build_create_base_extracted_entities_steps,
)
from .v1.create_base_extracted_entities import (
workflow_name as create_base_extracted_entities,
)
from .v1.create_base_text_units import (
build_steps as build_create_base_text_units_steps,
)
@ -79,16 +67,9 @@ from .v1.create_final_text_units import (
from .v1.create_final_text_units import (
workflow_name as create_final_text_units,
)
from .v1.create_summarized_entities import (
build_steps as build_create_summarized_entities_steps,
)
from .v1.create_summarized_entities import (
workflow_name as create_summarized_entities,
)
default_workflows: WorkflowDefinitions = {
create_base_extracted_entities: build_create_base_extracted_entities_steps,
create_base_entity_graph: build_create_base_entity_graph_steps,
create_base_text_units: build_create_base_text_units_steps,
create_final_text_units: build_create_final_text_units,
@ -97,8 +78,6 @@ default_workflows: WorkflowDefinitions = {
create_final_relationships: build_create_final_relationships_steps,
create_final_documents: build_create_final_documents_steps,
create_final_covariates: build_create_final_covariates_steps,
create_base_documents: build_create_base_documents_steps,
create_final_entities: build_create_final_entities_steps,
create_final_communities: build_create_final_communities_steps,
create_summarized_entities: build_create_summarized_entities_steps,
}

View File

@ -1,34 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
from datashaper import DEFAULT_INPUT_NAME
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_base_documents"
def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
"""
Create the documents table.
## Dependencies
* `workflow:create_final_text_units`
"""
document_attribute_columns = config.get("document_attribute_columns", [])
return [
{
"verb": "create_base_documents",
"args": {
"document_attribute_columns": document_attribute_columns,
},
"input": {
"source": DEFAULT_INPUT_NAME,
"text_units": "workflow:create_final_text_units",
},
},
]

View File

@ -3,6 +3,10 @@
"""A module containing build_steps method definition."""
from datashaper import (
AsyncType,
)
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_base_entity_graph"
@ -15,8 +19,53 @@ def build_steps(
Create the base table for the entity graph.
## Dependencies
* `workflow:create_base_extracted_entities`
* `workflow:create_base_summarized_entities`
"""
entity_extraction_config = config.get("entity_extract", {})
text_column = entity_extraction_config.get("text_column", "chunk")
id_column = entity_extraction_config.get("id_column", "chunk_id")
async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
extraction_strategy = entity_extraction_config.get("strategy")
extraction_num_threads = entity_extraction_config.get("num_threads", 4)
entity_types = entity_extraction_config.get("entity_types")
graph_merge_operations_config = config.get(
"graph_merge_operations",
{
"nodes": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
},
"edges": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
"weight": "sum",
},
},
)
node_merge_config = graph_merge_operations_config.get("nodes")
edge_merge_config = graph_merge_operations_config.get("edges")
summarize_descriptions_config = config.get("summarize_descriptions", {})
summarization_strategy = summarize_descriptions_config.get("strategy")
summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)
clustering_config = config.get(
"cluster_graph",
{"strategy": {"type": "leiden"}},
@ -40,17 +89,29 @@ def build_steps(
embed_graph_enabled = config.get("embed_graph_enabled", False) or False
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False
return [
{
"verb": "create_base_entity_graph",
"args": {
"text_column": text_column,
"id_column": id_column,
"extraction_strategy": extraction_strategy,
"extraction_num_threads": extraction_num_threads,
"extraction_async_mode": async_mode,
"entity_types": entity_types,
"node_merge_config": node_merge_config,
"edge_merge_config": edge_merge_config,
"summarization_strategy": summarization_strategy,
"summarization_num_threads": summarization_num_threads,
"clustering_strategy": clustering_strategy,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
"embedding_strategy": embedding_strategy
if embed_graph_enabled
else None,
"raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": ({"source": "workflow:create_summarized_entities"}),
"input": ({"source": "workflow:create_base_text_units"}),
},
]
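
Pulling the getters above together, the config consumed by the collapsed workflow looks roughly like this. A sketch only: the strategy bodies are illustrative and mirror the mock-LLM test settings later in this commit.

```python
from graphrag.index.workflows.v1.create_base_entity_graph import build_steps

config = {
    "entity_extract": {
        "text_column": "chunk",
        "id_column": "chunk_id",
        "strategy": {"type": "graph_intelligence", "llm": {"type": "static_response", "responses": ["..."]}},
        "num_threads": 4,
    },
    "summarize_descriptions": {
        "strategy": {"type": "graph_intelligence", "llm": {"type": "static_response", "responses": ["..."]}},
        "num_threads": 4,
    },
    "cluster_graph": {"strategy": {"type": "leiden"}},
    "embed_graph_enabled": True,
    "graphml_snapshot": True,
    "raw_entity_snapshot": True,
}

# Yields a single create_base_entity_graph verb step, fed by create_base_text_units.
steps = build_steps(config)
```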

View File

@ -1,84 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
from datashaper import AsyncType
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_base_extracted_entities"
def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
"""
Create the base table for extracted entities.
## Dependencies
* `workflow:create_base_text_units`
"""
entity_extraction_config = config.get("entity_extract", {})
column = entity_extraction_config.get("text_column", "chunk")
id_column = entity_extraction_config.get("id_column", "chunk_id")
async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
extraction_strategy = entity_extraction_config.get("strategy")
num_threads = entity_extraction_config.get("num_threads", 4)
entity_types = entity_extraction_config.get("entity_types")
graph_merge_operations_config = config.get(
"graph_merge_operations",
{
"nodes": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
},
"edges": {
"source_id": {
"operation": "concat",
"delimiter": ", ",
"distinct": True,
},
"description": ({
"operation": "concat",
"separator": "\n",
"distinct": False,
}),
"weight": "sum",
},
},
)
nodes = graph_merge_operations_config.get("nodes")
edges = graph_merge_operations_config.get("edges")
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False
return [
{
"verb": "create_base_extracted_entities",
"args": {
"column": column,
"id_column": id_column,
"async_mode": async_mode,
"extraction_strategy": extraction_strategy,
"num_threads": num_threads,
"entity_types": entity_types,
"nodes": nodes,
"edges": edges,
"raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": {"source": "workflow:create_base_text_units"},
},
]

View File

@ -20,7 +20,6 @@ def build_steps(
## Dependencies
* `workflow:create_base_text_units`
* `workflow:create_base_extracted_entities`
"""
claim_extract_config = config.get("claim_extract", {})
extraction_strategy = claim_extract_config.get("strategy")

View File

@ -3,6 +3,8 @@
"""A module containing build_steps method definition."""
from datashaper import DEFAULT_INPUT_NAME
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_final_documents"
@ -15,21 +17,26 @@ def build_steps(
Create the final documents table.
## Dependencies
* `workflow:create_base_documents`
* `workflow:create_final_text_units`
"""
base_text_embed = config.get("text_embed", {})
document_raw_content_embed_config = config.get(
"document_raw_content_embed", base_text_embed
)
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
document_attribute_columns = config.get("document_attribute_columns", [])
return [
{
"verb": "create_final_documents",
"args": {
"document_attribute_columns": document_attribute_columns,
"raw_content_text_embed": document_raw_content_embed_config
if not skip_raw_content_embedding
else None,
},
"input": {"source": "workflow:create_base_documents"},
"input": {
"source": DEFAULT_INPUT_NAME,
"text_units": "workflow:create_final_text_units",
},
},
]

View File

@ -1,36 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_summarized_entities"
def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
"""
Create the base table for extracted entities.
## Dependencies
* `workflow:create_base_text_units`
"""
summarize_descriptions_config = config.get("summarize_descriptions", {})
summarization_strategy = summarize_descriptions_config.get("strategy")
num_threads = summarize_descriptions_config.get("num_threads", 4)
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
return [
{
"verb": "create_summarized_entities",
"args": {
"summarization_strategy": summarization_strategy,
"num_threads": num_threads,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": {"source": "workflow:create_base_extracted_entities"},
},
]

View File

@ -3,9 +3,7 @@
"""The Indexing Engine workflows -> subflows package root."""
from .create_base_documents import create_base_documents
from .create_base_entity_graph import create_base_entity_graph
from .create_base_extracted_entities import create_base_extracted_entities
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
@ -17,12 +15,9 @@ from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .create_summarized_entities import create_summarized_entities
__all__ = [
"create_base_documents",
"create_base_entity_graph",
"create_base_extracted_entities",
"create_base_text_units",
"create_final_communities",
"create_final_community_reports",
@ -32,5 +27,4 @@ __all__ = [
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",
"create_summarized_entities",
]

View File

@ -1,41 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform base documents."""
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.flows.create_base_documents import (
create_base_documents as create_base_documents_flow,
)
from graphrag.index.utils.ds_util import get_required_input_table
@verb(name="create_base_documents", treats_input_tables_as_immutable=True)
def create_base_documents(
input: VerbInput,
document_attribute_columns: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform base documents."""
source = cast(pd.DataFrame, input.get_input())
text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)
output = create_base_documents_flow(
source, text_units, document_attribute_columns=document_attribute_columns
)
return create_verb_result(
cast(
Table,
output,
)
)

View File

@ -7,6 +7,7 @@ from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
@ -14,6 +15,7 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_base_entity_graph import (
create_base_entity_graph as create_base_entity_graph_flow,
)
@ -27,10 +29,22 @@ from graphrag.index.storage import PipelineStorage
async def create_base_entity_graph(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_column: str,
id_column: str,
clustering_strategy: dict[str, Any],
embedding_strategy: dict[str, Any] | None,
extraction_strategy: dict[str, Any] | None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
node_merge_config: dict[str, Any] | None = None,
edge_merge_config: dict[str, Any] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
embedding_strategy: dict[str, Any] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to create the base entity graph."""
@ -39,10 +53,22 @@ async def create_base_entity_graph(
output = await create_base_entity_graph_flow(
source,
callbacks,
cache,
storage,
clustering_strategy,
embedding_strategy,
text_column,
id_column,
clustering_strategy=clustering_strategy,
extraction_strategy=extraction_strategy,
extraction_num_threads=extraction_num_threads,
extraction_async_mode=extraction_async_mode,
entity_types=entity_types,
node_merge_config=node_merge_config,
edge_merge_config=edge_merge_config,
summarization_strategy=summarization_strategy,
summarization_num_threads=summarization_num_threads,
embedding_strategy=embedding_strategy,
graphml_snapshot_enabled=graphml_snapshot_enabled,
raw_entity_snapshot_enabled=raw_entity_snapshot_enabled,
)
return create_verb_result(cast(Table, output))

View File

@ -1,63 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to extract and format base entities."""
from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_base_extracted_entities import (
create_base_extracted_entities as create_base_extracted_entities_flow,
)
from graphrag.index.storage import PipelineStorage
@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
async def create_base_extracted_entities(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
column: str,
id_column: str,
nodes: dict[str, Any],
edges: dict[str, Any],
extraction_strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
num_threads: int = 4,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to extract and format base entities."""
source = cast(pd.DataFrame, input.get_input())
output = await create_base_extracted_entities_flow(
source,
callbacks,
cache,
storage,
column,
id_column,
nodes,
edges,
extraction_strategy,
async_mode=async_mode,
entity_types=entity_types,
graphml_snapshot_enabled=graphml_snapshot_enabled,
raw_entity_snapshot_enabled=raw_entity_snapshot_enabled,
num_threads=num_threads,
)
return create_verb_result(cast(Table, output))

View File

@ -18,6 +18,7 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_documents import (
create_final_documents as create_final_documents_flow,
)
from graphrag.index.utils.ds_util import get_required_input_table
@verb(
@ -28,16 +29,20 @@ async def create_final_documents(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
document_attribute_columns: list[str] | None = None,
raw_content_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final documents."""
source = cast(pd.DataFrame, input.get_input())
text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)
output = await create_final_documents_flow(
source,
text_units,
callbacks,
cache,
document_attribute_columns=document_attribute_columns,
raw_content_text_embed=raw_content_text_embed,
)

View File

@ -1,51 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to summarize entities."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_summarized_entities import (
create_summarized_entities as create_summarized_entities_flow,
)
from graphrag.index.storage import PipelineStorage
@verb(
name="create_summarized_entities",
treats_input_tables_as_immutable=True,
)
async def create_summarized_entities(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
summarization_strategy: dict[str, Any] | None = None,
num_threads: int = 4,
graphml_snapshot_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to summarize entities."""
source = cast(pd.DataFrame, input.get_input())
output = await create_summarized_entities_flow(
source,
callbacks,
cache,
storage,
summarization_strategy,
num_threads=num_threads,
graphml_snapshot_enabled=graphml_snapshot_enabled,
)
return create_verb_result(cast(Table, output))

View File

@ -10,29 +10,13 @@
"subworkflows": 1,
"max_runtime": 10
},
"create_base_extracted_entities": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 300
},
"create_summarized_entities": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 300
},
"create_base_entity_graph": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 10
"max_runtime": 300
},
"create_final_entities": {
"row_range": [
@ -108,14 +92,6 @@
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 10
},
"create_final_documents": {
"row_range": [
1,

View File

@ -10,14 +10,6 @@
"subworkflows": 1,
"max_runtime": 10
},
"create_base_extracted_entities": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 300
},
"create_final_covariates": {
"row_range": [
1,
@ -35,21 +27,13 @@
"subworkflows": 1,
"max_runtime": 300
},
"create_summarized_entities": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 300
},
"create_base_entity_graph": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 10
"max_runtime": 300
},
"create_final_entities": {
"row_range": [
@ -125,14 +109,6 @@
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {
"row_range": [
1,
2000
],
"subworkflows": 1,
"max_runtime": 10
},
"create_final_documents": {
"row_range": [
1,

View File

@ -19,9 +19,10 @@ workflows:
# Just lump everything together
chunk_by: []
- name: create_base_extracted_entities
- name: create_base_entity_graph
config:
graphml_snapshot: True
embed_graph_enabled: True
entity_extract:
strategy:
type: graph_intelligence
@ -37,9 +38,6 @@ workflows:
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))'
- name: create_summarized_entities
config:
summarize_descriptions:
strategy:
type: graph_intelligence
@ -47,11 +45,6 @@ workflows:
type: static_response
responses:
- This is a MOCK response for the LLM. It is summarized!
- name: create_base_entity_graph
config:
graphml_snapshot: True
embed_graph_enabled: True
cluster_graph:
strategy:
type: leiden
@ -59,8 +52,6 @@ workflows:
- name: create_final_nodes
- name: create_base_documents
- name: create_final_communities
- name: create_final_text_units
config:

View File

@ -2,8 +2,6 @@
# Licensed under the MIT License
import unittest
import networkx as nx
from graphrag.index.operations.extract_entities.strategies.graph_intelligence import (
run_extract_entities,
)
@ -119,8 +117,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
# self.assertItemsEqual isn't available yet, or I am just silly
# so we sort the lists and compare them
assert results.graphml_graph is not None, "No graphml graph returned!"
graph = nx.parse_graphml(results.graphml_graph) # type: ignore
graph = results.graph
assert graph is not None, "No graph returned!"
# convert to strings for more visual comparison
edges_str = sorted([f"{edge[0]} -> {edge[1]}" for edge in graph.edges])
@ -162,8 +160,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
),
)
assert results.graphml_graph is not None, "No graphml graph returned!"
graph = nx.parse_graphml(results.graphml_graph) # type: ignore
graph = results.graph # type: ignore
assert graph is not None, "No graph returned!"
# TODO: The edges might come back in any order, but we're assuming they're coming
# back in the order that we passed in the docs, that might not be true
@ -173,9 +171,11 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
assert (
graph.nodes["TEST_ENTITY_2"].get("source_id") == "1"
) # TEST_ENTITY_2 should be in just 1
assert sorted(
graph.nodes["TEST_ENTITY_1"].get("source_id").split(",")
) == sorted(["1", "2"]) # TEST_ENTITY_1 should be 1 and 2
ids_str = graph.nodes["TEST_ENTITY_1"].get("source_id") or ""
assert sorted(ids_str.split(",")) == sorted([
"1",
"2",
]) # TEST_ENTITY_1 should be 1 and 2
async def test_run_extract_entities_multiple_documents_correct_edge_source_ids_mapped(
self,
@ -210,8 +210,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
),
)
assert results.graphml_graph is not None, "No graphml graph returned!"
graph = nx.parse_graphml(results.graphml_graph) # type: ignore
graph = results.graph # type: ignore
assert graph is not None, "No graph returned!"
edges = list(graph.edges(data=True))
# should only have 2 edges

View File

@ -1,58 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.create_base_documents import (
build_steps,
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)
async def test_create_base_documents():
input_tables = load_input_tables(["workflow:create_final_text_units"])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
compare_outputs(actual, expected)
async def test_create_base_documents_with_attribute_columns():
input_tables = load_input_tables(["workflow:create_final_text_units"])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["document_attribute_columns"] = ["title"]
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
# we should have dropped "title" and added "attributes"
# our test dataframe does not have attributes, so we'll assert without it
# and separately confirm it is in the output
compare_outputs(actual, expected, columns=["id", "text_units", "raw_content"])
assert len(actual.columns) == 4
assert "attributes" in actual.columns

View File

@ -2,7 +2,9 @@
# Licensed under the MIT License
import networkx as nx
import pytest
from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_base_entity_graph import (
build_steps,
@ -16,16 +18,48 @@ from .util import (
load_input_tables,
)
MOCK_LLM_ENTITY_RESPONSES = [
"""
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]
MOCK_LLM_ENTITY_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_ENTITY_RESPONSES,
}
MOCK_LLM_SUMMARIZATION_RESPONSES = [
"""
This is a MOCK response for the LLM. It is summarized!
""".strip()
]
MOCK_LLM_SUMMARIZATION_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_SUMMARIZATION_RESPONSES,
}
async def test_create_base_entity_graph():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
"workflow:create_base_text_units",
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
steps = build_steps(config)
@ -37,34 +71,36 @@ async def test_create_base_entity_graph():
storage=storage,
)
# the serialization of the graph may differ so we can't assert the dataframes directly
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
assert len(actual.columns) == len(
expected.columns
), "Graph dataframe columns differ"
# let's parse a sample of the raw graphml
actual_graphml_0 = actual["clustered_graph"][:1][0]
actual_graph_0 = nx.parse_graphml(actual_graphml_0)
expected_graphml_0 = expected["clustered_graph"][:1][0]
expected_graph_0 = nx.parse_graphml(expected_graphml_0)
assert actual_graph_0.number_of_nodes() == 3
assert actual_graph_0.number_of_edges() == 2
assert (
actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
), "Graphml node count differs"
assert (
actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
), "Graphml edge count differs"
# TODO: with the combined verb we can't force summarization
# this is because the mock responses always result in a single description, which is returned verbatim rather than summarized
# we need to update the mocking to provide somewhat unique graphs so a true merge happens
# the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below
nodes = list(actual_graph_0.nodes(data=True))
assert nodes[0][1]["description"] == "Company_A is a test company"
assert len(storage.keys()) == 0, "Storage should be empty"
async def test_create_base_entity_graph_with_embeddings():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
"workflow:create_base_text_units",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
config["embed_graph_enabled"] = True
steps = build_steps(config)
@ -84,19 +120,22 @@ async def test_create_base_entity_graph_with_embeddings():
async def test_create_base_entity_graph_with_snapshots():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
"workflow:create_base_text_units",
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
config["raw_entity_snapshot"] = True
config["graphml_snapshot"] = True
config["embed_graph_enabled"] = True # need this on in order to see the snapshot
steps = build_steps(config)
actual = await get_workflow_output(
await get_workflow_output(
input_tables,
{
"steps": steps,
@ -104,15 +143,31 @@ async def test_create_base_entity_graph_with_snapshots():
storage=storage,
)
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
assert storage.keys() == [
"clustered_graph.0.graphml",
"clustered_graph.1.graphml",
"clustered_graph.2.graphml",
"clustered_graph.3.graphml",
"embedded_graph.0.graphml",
"embedded_graph.1.graphml",
"embedded_graph.2.graphml",
"embedded_graph.3.graphml",
"raw_extracted_entities.json",
"merged_graph.graphml",
"summarized_graph.graphml",
"clustered_graph.graphml",
"embedded_graph.graphml",
], "Graph snapshot keys differ"
async def test_create_base_entity_graph_missing_llm_throws():
input_tables = load_input_tables([
"workflow:create_base_text_units",
])
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
del config["summarize_descriptions"]["strategy"]["llm"]
steps = build_steps(config)
with pytest.raises(ValueError): # noqa PT011
await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

View File

@ -1,118 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import networkx as nx
import pytest
from datashaper.errors import VerbParallelizationError
from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_base_extracted_entities import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)
MOCK_LLM_RESPONSES = [
"""
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]
MOCK_LLM_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_RESPONSES,
}
async def test_create_base_extracted_entities():
input_tables = load_input_tables(["workflow:create_base_text_units"])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
storage=storage,
)
# let's parse a sample of the raw graphml
actual_graphml_0 = actual["entity_graph"][:1][0]
actual_graph_0 = nx.parse_graphml(actual_graphml_0)
assert actual_graph_0.number_of_nodes() == 3
assert actual_graph_0.number_of_edges() == 2
assert actual.columns == expected.columns
assert len(storage.keys()) == 0, "Storage should be empty"
async def test_create_base_extracted_entities_with_snapshots():
input_tables = load_input_tables(["workflow:create_base_text_units"])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
config["raw_entity_snapshot"] = True
config["graphml_snapshot"] = True
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
storage=storage,
)
print(storage.keys())
assert actual.columns == expected.columns
assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"]
async def test_create_base_extracted_entities_missing_llm_throws():
input_tables = load_input_tables(["workflow:create_base_text_units"])
config = get_config_for_workflow(workflow_name)
del config["entity_extract"]["strategy"]["llm"]
steps = build_steps(config)
with pytest.raises(VerbParallelizationError):
await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

View File

@ -17,7 +17,7 @@ from .util import (
async def test_create_final_documents():
input_tables = load_input_tables([
"workflow:create_base_documents",
"workflow:create_final_text_units",
])
expected = load_expected(workflow_name)
@ -39,7 +39,7 @@ async def test_create_final_documents():
async def test_create_final_documents_with_embeddings():
input_tables = load_input_tables([
"workflow:create_base_documents",
"workflow:create_final_text_units",
])
expected = load_expected(workflow_name)
@ -63,3 +63,28 @@ async def test_create_final_documents_with_embeddings():
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["raw_content_embedding"][:1][0]) == 3
async def test_create_final_documents_with_attribute_columns():
input_tables = load_input_tables(["workflow:create_final_text_units"])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["document_attribute_columns"] = ["title"]
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
# we should have dropped "title" and added "attributes"
# our test dataframe does not have attributes, so we'll assert without it
# and separately confirm it is in the output
compare_outputs(actual, expected, columns=["id", "text_unit_ids", "raw_content"])
assert len(actual.columns) == 4
assert "attributes" in actual.columns

View File

@ -1,129 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import networkx as nx
import pytest
from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_summarized_entities import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)
MOCK_LLM_RESPONSES = [
"""
This is a MOCK response for the LLM. It is summarized!
""".strip()
]
MOCK_LLM_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_RESPONSES,
}
async def test_create_summarized_entities():
input_tables = load_input_tables([
"workflow:create_base_extracted_entities",
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
storage=storage,
)
# the serialization of the graph may differ so we can't assert the dataframes directly
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
# let's parse a sample of the raw graphml
actual_graphml_0 = actual["entity_graph"][:1][0]
actual_graph_0 = nx.parse_graphml(actual_graphml_0)
expected_graphml_0 = expected["entity_graph"][:1][0]
expected_graph_0 = nx.parse_graphml(expected_graphml_0)
assert (
actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
), "Graphml node count differs"
assert (
actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
), "Graphml edge count differs"
# ensure the mock summary was injected to the nodes
nodes = list(actual_graph_0.nodes(data=True))
assert (
nodes[0][1]["description"]
== "This is a MOCK response for the LLM. It is summarized!"
)
assert len(storage.keys()) == 0, "Storage should be empty"
async def test_create_summarized_entities_with_snapshots():
input_tables = load_input_tables([
"workflow:create_base_extracted_entities",
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG
config["graphml_snapshot"] = True
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
storage=storage,
)
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
assert storage.keys() == [
"summarized_graph.graphml",
], "Graph snapshot keys differ"
async def test_create_summarized_entities_missing_llm_throws():
input_tables = load_input_tables([
"workflow:create_base_extracted_entities",
])
config = get_config_for_workflow(workflow_name)
del config["summarize_descriptions"]["strategy"]["llm"]
steps = build_steps(config)
with pytest.raises(ValueError): # noqa PT011
await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

View File

@ -26,7 +26,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
# all workflows implicitly receive the `input` source, which is formatted as a dataframe after loading from storage
# we'll simulate that by just loading one of our output parquets and converting back to equivalent dataframe
# so we aren't dealing with storage vagaries (which would become an integration test)
source = pd.read_parquet("tests/verbs/data/create_base_documents.parquet")
source = pd.read_parquet("tests/verbs/data/create_final_documents.parquet")
source.rename(columns={"raw_content": "text"}, inplace=True)
input_tables["source"] = cast(pd.DataFrame, source[["id", "text", "title"]])