Mirror of https://github.com/microsoft/graphrag.git, synced 2025-12-26 22:48:42 +00:00
Collapse graph documents workflows (#1284)
* Copy base documents logic into final documents
* Delete create_base_documents
* Combine graph creation under create_base_entity_graph
* Delete collapsed workflows
* Migrate most graph internals to nx.Graph
* Fix None edge case
* Semver
* Remove comment typo
* Fix smoke tests
This commit is contained in:
parent
137a5cd550
commit
ce5b1207e0
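For orientation before the diff: this change folds the old create_base_extracted_entities -> create_summarized_entities -> create_base_entity_graph chain into a single create_base_entity_graph flow that passes nx.Graph objects in memory instead of persisting graphml-string tables between workflows. A minimal, self-contained sketch of that shape (hypothetical stub functions, not graphrag's real operations):

import networkx as nx

def extract(chunks: list[str]) -> list[nx.Graph]:
    # Stand-in for extract_entities: one tiny graph per text chunk.
    graphs = []
    for i, chunk in enumerate(chunks):
        g = nx.Graph()
        g.add_edge(f"ENTITY_{i}", f"ENTITY_{i + 1}", description=chunk)
        graphs.append(g)
    return graphs

def collapsed_flow(chunks: list[str]) -> nx.Graph:
    # Stand-in for merge_graphs: union the per-chunk graphs in memory,
    # where the old pipeline wrote each stage out as a graphml string.
    merged = nx.compose_all(extract(chunks))
    # summarize_descriptions and cluster_graph now operate on this
    # nx.Graph directly inside the same flow.
    return merged

print(collapsed_flow(["chunk a", "chunk b"]).edges(data=True))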
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Collapse intermediate workflow outputs."
}
@@ -49,9 +49,7 @@ from graphrag.index.config.workflow import (
    PipelineWorkflowReference,
)
from graphrag.index.workflows.default_workflows import (
    create_base_documents,
    create_base_entity_graph,
    create_base_extracted_entities,
    create_base_text_units,
    create_final_communities,
    create_final_community_reports,
@@ -61,7 +59,6 @@ from graphrag.index.workflows.default_workflows import (
    create_final_nodes,
    create_final_relationships,
    create_final_text_units,
    create_summarized_entities,
)

log = logging.getLogger(__name__)
@@ -173,17 +170,12 @@ def _document_workflows(
    )
    return [
        PipelineWorkflowReference(
            name=create_base_documents,
            name=create_final_documents,
            config={
                "document_attribute_columns": list(
                    {*(settings.input.document_attribute_columns)}
                    - builtin_document_attributes
                )
            },
        ),
        PipelineWorkflowReference(
            name=create_final_documents,
            config={
        ),
                "document_raw_content_embed": _get_embedding_settings(
                    settings.embeddings,
                    "document_raw_content",
@@ -267,10 +259,9 @@ def _graph_workflows(
    )
    return [
        PipelineWorkflowReference(
            name=create_base_extracted_entities,
            name=create_base_entity_graph,
            config={
                "graphml_snapshot": settings.snapshots.graphml,
                "raw_entity_snapshot": settings.snapshots.raw_entities,
                "entity_extract": {
                    **settings.entity_extraction.parallelization.model_dump(),
                    "async_mode": settings.entity_extraction.async_mode,
@@ -279,12 +270,6 @@ def _graph_workflows(
                ),
                "entity_types": settings.entity_extraction.entity_types,
                },
            },
        ),
        PipelineWorkflowReference(
            name=create_summarized_entities,
            config={
                "graphml_snapshot": settings.snapshots.graphml,
                "summarize_descriptions": {
                    **settings.summarize_descriptions.parallelization.model_dump(),
                    "async_mode": settings.summarize_descriptions.async_mode,
@@ -292,12 +277,6 @@ def _graph_workflows(
                    settings.root_dir,
                ),
                },
            },
        ),
        PipelineWorkflowReference(
            name=create_base_entity_graph,
            config={
                "graphml_snapshot": settings.snapshots.graphml,
                "embed_graph_enabled": settings.embed_graph.enabled,
                "cluster_graph": {
                    "strategy": settings.cluster_graph.resolved_strategy()
@@ -1,64 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Transform base documents by joining them with their text_units and adding optional attributes."""

import pandas as pd


def create_base_documents(
    documents: pd.DataFrame,
    text_units: pd.DataFrame,
    document_attribute_columns: list[str] | None = None,
) -> pd.DataFrame:
    """Transform base documents by joining them with their text_units and adding optional attributes."""
    exploded = (
        text_units.explode("document_ids")
        .loc[:, ["id", "document_ids", "text"]]
        .rename(
            columns={
                "document_ids": "chunk_doc_id",
                "id": "chunk_id",
                "text": "chunk_text",
            }
        )
    )

    joined = exploded.merge(
        documents,
        left_on="chunk_doc_id",
        right_on="id",
        how="inner",
        copy=False,
    )

    docs_with_text_units = joined.groupby("id", sort=False).agg(
        text_units=("chunk_id", list)
    )

    rejoined = docs_with_text_units.merge(
        documents,
        on="id",
        how="right",
        copy=False,
    ).reset_index(drop=True)

    rejoined.rename(columns={"text": "raw_content"}, inplace=True)
    rejoined["id"] = rejoined["id"].astype(str)

    # Convert attribute columns to strings and collapse them into a JSON object
    if document_attribute_columns:
        # Convert all specified columns to string at once
        rejoined[document_attribute_columns] = rejoined[
            document_attribute_columns
        ].astype(str)

        # Collapse the document_attribute_columns into a single JSON object column
        rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
            orient="records"
        )

        # Drop the original attribute columns after collapsing them
        rejoined.drop(columns=document_attribute_columns, inplace=True)

    return rejoined
@@ -7,26 +7,76 @@ from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    VerbCallbacks,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
    summarize_descriptions,
)
from graphrag.index.storage import PipelineStorage


async def create_base_entity_graph(
    entities: pd.DataFrame,
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_column: str,
    id_column: str,
    clustering_strategy: dict[str, Any],
    embedding_strategy: dict[str, Any] | None,
    extraction_strategy: dict[str, Any] | None = None,
    extraction_num_threads: int = 4,
    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    node_merge_config: dict[str, Any] | None = None,
    edge_merge_config: dict[str, Any] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
    summarization_num_threads: int = 4,
    embedding_strategy: dict[str, Any] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
) -> pd.DataFrame:
    """All the steps to create the base entity graph."""
    # this returns a graph for each text unit, to be merged later
    entities, entity_graphs = await extract_entities(
        text_units,
        callbacks,
        cache,
        text_column=text_column,
        id_column=id_column,
        strategy=extraction_strategy,
        async_mode=extraction_async_mode,
        entity_types=entity_types,
        to="entities",
        num_threads=extraction_num_threads,
    )

    merged_graph = merge_graphs(
        entity_graphs,
        callbacks,
        node_operations=node_merge_config,
        edge_operations=edge_merge_config,
    )

    summarized = await summarize_descriptions(
        merged_graph,
        callbacks,
        cache,
        strategy=summarization_strategy,
        num_threads=summarization_num_threads,
    )

    clustered = cluster_graph(
        entities,
        summarized,
        callbacks,
        column="entity_graph",
        strategy=clustering_strategy,
@@ -34,15 +84,6 @@ async def create_base_entity_graph(
        level_to="level",
    )

    if graphml_snapshot_enabled:
        await snapshot_rows(
            clustered,
            column="clustered_graph",
            base_name="clustered_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    if embedding_strategy:
        clustered["embeddings"] = await embed_graph(
            clustered,
@@ -51,16 +92,40 @@ async def create_base_entity_graph(
            strategy=embedding_strategy,
        )

    # take second snapshot after embedding
    # todo: this could be skipped if embedding isn't performed, other wise it is a copy of the regular graph?
    if raw_entity_snapshot_enabled:
        await snapshot(
            entities,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )

    if graphml_snapshot_enabled:
        await snapshot_graphml(
            merged_graph,
            name="merged_graph",
            storage=storage,
        )
        await snapshot_graphml(
            summarized,
            name="summarized_graph",
            storage=storage,
        )
        await snapshot_rows(
            clustered,
            column="entity_graph",
            base_name="embedded_graph",
            column="clustered_graph",
            base_name="clustered_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )
    if embedding_strategy:
        await snapshot_rows(
            clustered,
            column="entity_graph",
            base_name="embedded_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    final_columns = ["level", "clustered_graph"]
    if embedding_strategy:
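For reference, a hypothetical invocation of the consolidated flow above. It is commented out because callbacks, cache, and storage come from the pipeline runtime; the argument names follow the signature shown in this hunk, the values are illustrative only.

# clustered = await create_base_entity_graph(
#     text_units,                      # one row per text chunk
#     callbacks, cache, storage,       # provided by the pipeline runtime
#     text_column="chunk",
#     id_column="chunk_id",
#     clustering_strategy={"type": "leiden"},
#     embedding_strategy=None,         # only set when embed_graph is enabled
#     extraction_strategy=None,        # None -> default extraction strategy
#     summarization_strategy=None,
#     graphml_snapshot_enabled=True,   # also writes merged/summarized graphml
# )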
@@ -1,79 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to extract and format covariates."""

from typing import Any

import pandas as pd
from datashaper import (
    AsyncType,
    VerbCallbacks,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.storage import PipelineStorage


async def create_base_extracted_entities(
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    column: str,
    id_column: str,
    nodes: dict[str, Any],
    edges: dict[str, Any],
    extraction_strategy: dict[str, Any] | None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    num_threads: int = 4,
) -> pd.DataFrame:
    """All the steps to extract and format covariates."""
    entity_graph = await extract_entities(
        text_units,
        callbacks,
        cache,
        column=column,
        id_column=id_column,
        strategy=extraction_strategy,
        async_mode=async_mode,
        entity_types=entity_types,
        to="entities",
        graph_to="entity_graph",
        num_threads=num_threads,
    )

    if raw_entity_snapshot_enabled:
        await snapshot(
            entity_graph,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )

    merged_graph = merge_graphs(
        entity_graph,
        callbacks,
        column="entity_graph",
        to="entity_graph",
        nodes=nodes,
        edges=edges,
    )

    if graphml_snapshot_enabled:
        await snapshot_rows(
            merged_graph,
            base_name="merged_graph",
            column="entity_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return merged_graph
@@ -14,20 +14,71 @@ from graphrag.index.operations.embed_text import embed_text

async def create_final_documents(
    documents: pd.DataFrame,
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    document_attribute_columns: list[str] | None = None,
    raw_content_text_embed: dict | None = None,
) -> pd.DataFrame:
    """All the steps to transform final documents."""
    documents.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
    exploded = (
        text_units.explode("document_ids")
        .loc[:, ["id", "document_ids", "text"]]
        .rename(
            columns={
                "document_ids": "chunk_doc_id",
                "id": "chunk_id",
                "text": "chunk_text",
            }
        )
    )

    joined = exploded.merge(
        documents,
        left_on="chunk_doc_id",
        right_on="id",
        how="inner",
        copy=False,
    )

    docs_with_text_units = joined.groupby("id", sort=False).agg(
        text_units=("chunk_id", list)
    )

    rejoined = docs_with_text_units.merge(
        documents,
        on="id",
        how="right",
        copy=False,
    ).reset_index(drop=True)

    rejoined.rename(
        columns={"text": "raw_content", "text_units": "text_unit_ids"}, inplace=True
    )
    rejoined["id"] = rejoined["id"].astype(str)

    # Convert attribute columns to strings and collapse them into a JSON object
    if document_attribute_columns:
        # Convert all specified columns to string at once
        rejoined[document_attribute_columns] = rejoined[
            document_attribute_columns
        ].astype(str)

        # Collapse the document_attribute_columns into a single JSON object column
        rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
            orient="records"
        )

        # Drop the original attribute columns after collapsing them
        rejoined.drop(columns=document_attribute_columns, inplace=True)

    if raw_content_text_embed:
        documents["raw_content_embedding"] = await embed_text(
            documents,
        rejoined["raw_content_embedding"] = await embed_text(
            rejoined,
            callbacks,
            cache,
            column="raw_content",
            strategy=raw_content_text_embed["strategy"],
        )

    return documents
    return rejoined
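The attribute collapse above hinges on DataFrame.to_dict(orient="records"), which returns one dict per row. A tiny self-contained illustration of the pattern (toy column names, not graphrag data):

import pandas as pd

# String-ify the attribute columns, fold them into one dict-per-row
# column, then drop the originals -- the same three steps as the flow above.
df = pd.DataFrame({"id": ["d1", "d2"], "year": [2023, 2024], "lang": ["en", "fr"]})
attribute_columns = ["year", "lang"]

df[attribute_columns] = df[attribute_columns].astype(str)
df["attributes"] = df[attribute_columns].to_dict(orient="records")
df.drop(columns=attribute_columns, inplace=True)

print(df)
# id: d1, attributes: {'year': '2023', 'lang': 'en'}
# id: d2, attributes: {'year': '2024', 'lang': 'fr'}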
@@ -1,50 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to summarize entities."""

from typing import Any

import pandas as pd
from datashaper import (
    VerbCallbacks,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
    summarize_descriptions,
)
from graphrag.index.storage import PipelineStorage


async def create_summarized_entities(
    entities: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    summarization_strategy: dict[str, Any] | None = None,
    num_threads: int = 4,
    graphml_snapshot_enabled: bool = False,
) -> pd.DataFrame:
    """All the steps to summarize entities."""
    summarized = await summarize_descriptions(
        entities,
        callbacks,
        cache,
        column="entity_graph",
        to="entity_graph",
        strategy=summarization_strategy,
        num_threads=num_threads,
    )

    if graphml_snapshot_enabled:
        await snapshot_rows(
            summarized,
            column="entity_graph",
            base_name="summarized_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return summarized
@@ -14,7 +14,7 @@ from datashaper import VerbCallbacks, progress_iterable
from graspologic.partition import hierarchical_leiden

from graphrag.index.graph.utils import stable_largest_connected_component
from graphrag.index.utils import gen_uuid, load_graph
from graphrag.index.utils import gen_uuid

Communities = list[tuple[int, str, list[str]]]

@@ -33,7 +33,7 @@ log = logging.getLogger(__name__)


def cluster_graph(
    input: pd.DataFrame,
    input: nx.Graph,
    callbacks: VerbCallbacks,
    strategy: dict[str, Any],
    column: str,
@@ -41,63 +41,64 @@ def cluster_graph(
    level_to: str | None = None,
) -> pd.DataFrame:
    """Apply a hierarchical clustering algorithm to a graph."""
    results = input[column].apply(lambda graph: run_layout(strategy, graph))
    output = pd.DataFrame()
    # TODO: for back-compat, downstream expects a graphml string
    output[column] = ["\n".join(nx.generate_graphml(input))]
    communities = run_layout(strategy, input)

    community_map_to = "communities"
    input[community_map_to] = results
    output[community_map_to] = [communities]

    level_to = level_to or f"{to}_level"
    input[level_to] = input.apply(
    output[level_to] = output.apply(
        lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
    )
    input[to] = None
    output[to] = None

    num_total = len(input)
    num_total = len(output)

    # Create a seed for this run (if not provided)
    seed = strategy.get("seed", Random().randint(0, 0xFFFFFFFF))  # noqa S311

    # Go through each of the rows
    graph_level_pairs_column: list[list[tuple[int, str]]] = []
    for _, row in progress_iterable(input.iterrows(), callbacks.progress, num_total):
    for _, row in progress_iterable(output.iterrows(), callbacks.progress, num_total):
        levels = row[level_to]
        graph_level_pairs: list[tuple[int, str]] = []

        # For each of the levels, get the graph and add it to the list
        for level in levels:
            graph = "\n".join(
            graphml = "\n".join(
                nx.generate_graphml(
                    apply_clustering(
                        cast(str, row[column]),
                        input,
                        cast(Communities, row[community_map_to]),
                        level,
                        seed=seed,
                    )
                )
            )
            graph_level_pairs.append((level, graph))
            graph_level_pairs.append((level, graphml))
        graph_level_pairs_column.append(graph_level_pairs)
    input[to] = graph_level_pairs_column
    output[to] = graph_level_pairs_column

    # explode the list of (level, graph) pairs into separate rows
    input = input.explode(to, ignore_index=True)
    output = output.explode(to, ignore_index=True)

    # split the (level, graph) pairs into separate columns
    # TODO: There is probably a better way to do this
    input[[level_to, to]] = pd.DataFrame(input[to].tolist(), index=input.index)
    output[[level_to, to]] = pd.DataFrame(output[to].tolist(), index=output.index)

    # clean up the community map
    input.drop(columns=[community_map_to], inplace=True)
    return input
    output.drop(columns=[community_map_to], inplace=True)
    return output


# TODO: This should support str | nx.Graph as a graphml param
def apply_clustering(
    graphml: str, communities: Communities, level: int = 0, seed: int | None = None
    graph: nx.Graph, communities: Communities, level: int = 0, seed: int | None = None
) -> nx.Graph:
    """Apply clustering to a graphml string."""
    """Apply clustering to a graph."""
    random = Random(seed)  # noqa S311
    graph = nx.parse_graphml(graphml)
    for community_level, community_id, nodes in communities:
        if level == community_level:
            for node in nodes:
@@ -121,11 +122,8 @@ def apply_clustering(
    return graph


def run_layout(
    strategy: dict[str, Any], graphml_or_graph: str | nx.Graph
) -> Communities:
def run_layout(strategy: dict[str, Any], graph: nx.Graph) -> Communities:
    """Run layout method definition."""
    graph = load_graph(graphml_or_graph)
    if len(graph.nodes) == 0:
        log.warning("Graph has no nodes")
        return []
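The Communities alias above is the hand-off format between run_layout and apply_clustering: (level, community_id, node_names) triples. A hypothetical value, for illustration only:

# Hypothetical Communities value: (hierarchy level, community id, member nodes).
communities = [
    (0, "0", ["COMPANY_A", "COMPANY_B", "PERSON_C"]),  # coarse level
    (1, "1", ["COMPANY_A", "COMPANY_B"]),              # finer split
    (1, "2", ["PERSON_C"]),
]
levels = sorted({level for level, _, _ in communities})  # -> [0, 1]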
@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing import Any

import networkx as nx
import pandas as pd
from datashaper import (
    AsyncType,
@@ -41,15 +42,14 @@ async def extract_entities(
    input: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    column: str,
    text_column: str,
    id_column: str,
    to: str,
    strategy: dict[str, Any] | None,
    graph_to: str | None = None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types=DEFAULT_ENTITY_TYPES,
    num_threads: int = 4,
) -> pd.DataFrame:
) -> tuple[pd.DataFrame, list[nx.Graph]]:
    """
    Extract entities from a piece of text.

@@ -59,7 +59,6 @@ async def extract_entities(
    column: the_document_text_column_to_extract_entities_from
    id_column: the_column_with_the_unique_id_for_each_row
    to: the_column_to_output_the_entities_to
    graph_to: the_column_to_output_the_graphml_to
    strategy: <strategy_config>, see strategies section below
    summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */
    entity_types:
@@ -124,7 +123,7 @@ async def extract_entities(

    async def run_strategy(row):
        nonlocal num_started
        text = row[column]
        text = row[text_column]
        id = row[id_column]
        result = await strategy_exec(
            [Document(text=text, id=id)],
@@ -134,7 +133,7 @@ async def extract_entities(
            strategy_config,
        )
        num_started += 1
        return [result.entities, result.graphml_graph]
        return [result.entities, result.graph]

    results = await derive_from_rows(
        input,
@@ -145,20 +144,18 @@ async def extract_entities(
    )

    to_result = []
    graph_to_result = []
    graphs = []
    for result in results:
        if result:
            to_result.append(result[0])
            graph_to_result.append(result[1])
            graphs.append(result[1])
        else:
            to_result.append(None)
            graph_to_result.append(None)
            graphs.append(None)

    input[to] = to_result
    if graph_to is not None:
        input[graph_to] = graph_to_result

    return input.reset_index(drop=True)
    return (input.reset_index(drop=True), graphs)


def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -3,7 +3,6 @@

"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence."""

import networkx as nx
from datashaper import VerbCallbacks

import graphrag.config.defaults as defs
@@ -113,8 +112,7 @@ async def run_extract_entities(
        if item is not None
    ]

    graph_data = "".join(nx.generate_graphml(graph))
    return EntityExtractionResult(entities, graph_data)
    return EntityExtractionResult(entities, graph)


def _create_text_splitter(
@@ -57,5 +57,5 @@ async def run(  # noqa RUF029 async is required for interface
            {"type": entity_type, "name": name}
            for name, entity_type in entity_map.items()
        ],
        graphml_graph="".join(nx.generate_graphml(graph)),
        graph=graph,
    )
@@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any

import networkx as nx
from datashaper import VerbCallbacks

from graphrag.index.cache import PipelineCache
@@ -29,7 +30,7 @@ class EntityExtractionResult:
    """Entity extraction result class definition."""

    entities: list[ExtractedEntity]
    graphml_graph: str | None
    graph: nx.Graph | None


EntityExtractStrategy = Callable[
@@ -3,14 +3,11 @@

"""A module containing merge_graphs, merge_nodes, merge_edges, merge_attributes, apply_merge_operation and _get_detailed_attribute_merge_operation methods definitions."""

from typing import Any, cast
from typing import Any

import networkx as nx
import pandas as pd
from datashaper import VerbCallbacks, progress_iterable

from graphrag.index.utils import load_graph

from .typing import (
    BasicMergeOperation,
    DetailedAttributeMergeOperation,
@@ -35,15 +32,13 @@ DEFAULT_CONCAT_SEPARATOR = ","


def merge_graphs(
    input: pd.DataFrame,
    graphs: list[nx.Graph],
    callbacks: VerbCallbacks,
    column: str,
    to: str,
    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
) -> pd.DataFrame:
    node_operations: dict[str, Any] | None,
    edge_operations: dict[str, Any] | None,
) -> nx.Graph:
    """
    Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
    Merge multiple graphs together. The graphs are expected to be in nx.Graph format. The verb outputs a new column containing the merged graph.

    > Note: This will merge all rows into a single graph.

@@ -51,8 +46,6 @@ def merge_graphs(
    ```yaml
    verb: merge_graph
    args:
        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
        to: merged_graph # The name of the column to output the merged graph to
        nodes: <node operations> # See node operations section below
        edges: <edge operations> # See edge operations section below
    ```
@@ -90,8 +83,8 @@ def merge_graphs(
    - __average__: This operation takes the mean of the attribute with the last value seen.
    - __multiply__: This operation multiplies the attribute with the last value seen.
    """
    output = pd.DataFrame()

    nodes = node_operations or DEFAULT_NODE_OPERATIONS
    edges = edge_operations or DEFAULT_EDGE_OPERATIONS
    node_ops = {
        attrib: _get_detailed_attribute_merge_operation(value)
        for attrib, value in nodes.items()
@@ -102,15 +95,12 @@ def merge_graphs(
    }

    mega_graph = nx.Graph()
    num_total = len(input)
    for graphml in progress_iterable(input[column], callbacks.progress, num_total):
        graph = load_graph(cast(str | nx.Graph, graphml))
    num_total = len(graphs)
    for graph in progress_iterable(graphs, callbacks.progress, num_total):
        merge_nodes(mega_graph, graph, node_ops)
        merge_edges(mega_graph, graph, edge_ops)

    output[to] = ["\n".join(nx.generate_graphml(mega_graph))]

    return output
    return mega_graph


def merge_nodes(
graphrag/index/operations/snapshot_graphml.py (new file, +18)
@@ -0,0 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing snapshot method definition."""

import networkx as nx

from graphrag.index.storage import PipelineStorage


async def snapshot_graphml(
    input: str | nx.Graph,
    name: str,
    storage: PipelineStorage,
) -> None:
    """Take a entire snapshot of a graph to standard graphml format."""
    graphml = input if isinstance(input, str) else "\n".join(nx.generate_graphml(input))
    await storage.set(name + ".graphml", graphml)
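snapshot_graphml serializes with nx.generate_graphml, which yields an iterator of lines. A small self-contained round-trip showing the pattern (standard networkx API, toy graph):

import networkx as nx

# Round-trip a graph through graphml text, the same serialization
# snapshot_graphml uses before handing the string to storage.
g = nx.Graph()
g.add_edge("COMPANY_A", "PERSON_C", description="director of")

graphml = "\n".join(nx.generate_graphml(g))   # graph -> graphml string
restored = nx.parse_graphml(graphml)          # graphml string -> graph

assert set(restored.nodes) == {"COMPANY_A", "PERSON_C"}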
@@ -5,10 +5,9 @@

import asyncio
import logging
from typing import Any, cast
from typing import Any

import networkx as nx
import pandas as pd
from datashaper import (
    ProgressTicker,
    VerbCallbacks,
@@ -16,10 +15,8 @@ from datashaper import (
)

from graphrag.index.cache import PipelineCache
from graphrag.index.utils import load_graph

from .typing import (
    DescriptionSummarizeRow,
    SummarizationStrategy,
    SummarizeStrategyType,
)
@@ -28,14 +25,12 @@ log = logging.getLogger(__name__)


async def summarize_descriptions(
    input: pd.DataFrame,
    input: nx.Graph,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    column: str,
    to: str,
    strategy: dict[str, Any] | None = None,
    **kwargs,
) -> pd.DataFrame:
    num_threads: int = 4,
) -> nx.Graph:
    """
    Summarize entity and relationship descriptions from an entity graph.

@@ -43,25 +38,10 @@ async def summarize_descriptions(

    To turn this feature ON please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`.

    ### json

    ```json
    {
        "verb": "",
        "args": {
            "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */
            "to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */
            "strategy": {...} <strategy_config>, see strategies section below
        }
    }
    ```

    ### yaml

    ```yaml
    args:
        column: the_document_text_column_to_extract_descriptions_from
        to: the_column_to_output_the_summarized_descriptions_to
        strategy: <strategy_config>, see strategies section below
    ```

@@ -99,9 +79,7 @@ async def summarize_descriptions(
    )
    strategy_config = {**strategy}

    async def get_resolved_entities(row, semaphore: asyncio.Semaphore):
        graph: nx.Graph = load_graph(cast(str | nx.Graph, getattr(row, column)))

    async def get_resolved_entities(graph: nx.Graph, semaphore: asyncio.Semaphore):
        ticker_length = len(graph.nodes) + len(graph.edges)

        ticker = progress_ticker(callbacks.progress, ticker_length)
@@ -134,9 +112,7 @@ async def summarize_descriptions(
        elif isinstance(graph_item, tuple) and graph_item in graph.edges():
            graph.edges[graph_item]["description"] = result.description

        return DescriptionSummarizeRow(
            graph="\n".join(nx.generate_graphml(graph)),
        )
        return graph

    async def do_summarize_descriptions(
        graph_item: str | tuple[str, str],
@@ -155,25 +131,9 @@ async def summarize_descriptions(
            ticker(1)
        return results

    # Graph is always on row 0, so here a derive from rows does not work
    # This iteration will only happen once, but avoids hardcoding a iloc[0]
    # Since parallelization is at graph level (nodes and edges), we can't use
    # the parallelization of the derive_from_rows
    semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4))
    semaphore = asyncio.Semaphore(num_threads)

    results = [
        await get_resolved_entities(row, semaphore) for row in input.itertuples()
    ]

    to_result = []

    for result in results:
        if result:
            to_result.append(result.graph)
        else:
            to_result.append(None)
    input[to] = to_result
    return input
    return await get_resolved_entities(input, semaphore)


def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy:
@@ -7,24 +7,12 @@

from .v1.subflows import *  # noqa

from .typing import WorkflowDefinitions
from .v1.create_base_documents import (
    build_steps as build_create_base_documents_steps,
)
from .v1.create_base_documents import (
    workflow_name as create_base_documents,
)
from .v1.create_base_entity_graph import (
    build_steps as build_create_base_entity_graph_steps,
)
from .v1.create_base_entity_graph import (
    workflow_name as create_base_entity_graph,
)
from .v1.create_base_extracted_entities import (
    build_steps as build_create_base_extracted_entities_steps,
)
from .v1.create_base_extracted_entities import (
    workflow_name as create_base_extracted_entities,
)
from .v1.create_base_text_units import (
    build_steps as build_create_base_text_units_steps,
)
@@ -79,16 +67,9 @@ from .v1.create_final_text_units import (
from .v1.create_final_text_units import (
    workflow_name as create_final_text_units,
)
from .v1.create_summarized_entities import (
    build_steps as build_create_summarized_entities_steps,
)
from .v1.create_summarized_entities import (
    workflow_name as create_summarized_entities,
)


default_workflows: WorkflowDefinitions = {
    create_base_extracted_entities: build_create_base_extracted_entities_steps,
    create_base_entity_graph: build_create_base_entity_graph_steps,
    create_base_text_units: build_create_base_text_units_steps,
    create_final_text_units: build_create_final_text_units,
@@ -97,8 +78,6 @@ default_workflows: WorkflowDefinitions = {
    create_final_relationships: build_create_final_relationships_steps,
    create_final_documents: build_create_final_documents_steps,
    create_final_covariates: build_create_final_covariates_steps,
    create_base_documents: build_create_base_documents_steps,
    create_final_entities: build_create_final_entities_steps,
    create_final_communities: build_create_final_communities_steps,
    create_summarized_entities: build_create_summarized_entities_steps,
}
@@ -1,34 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from datashaper import DEFAULT_INPUT_NAME

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_base_documents"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the documents table.

    ## Dependencies
    * `workflow:create_final_text_units`
    """
    document_attribute_columns = config.get("document_attribute_columns", [])
    return [
        {
            "verb": "create_base_documents",
            "args": {
                "document_attribute_columns": document_attribute_columns,
            },
            "input": {
                "source": DEFAULT_INPUT_NAME,
                "text_units": "workflow:create_final_text_units",
            },
        },
    ]
@@ -3,6 +3,10 @@

"""A module containing build_steps method definition."""

from datashaper import (
    AsyncType,
)

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_base_entity_graph"
@@ -15,8 +19,53 @@ def build_steps(
    Create the base table for the entity graph.

    ## Dependencies
    * `workflow:create_base_extracted_entities`
    * `workflow:create_base_summarized_entities`
    """
    entity_extraction_config = config.get("entity_extract", {})
    text_column = entity_extraction_config.get("text_column", "chunk")
    id_column = entity_extraction_config.get("id_column", "chunk_id")
    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
    extraction_strategy = entity_extraction_config.get("strategy")
    extraction_num_threads = entity_extraction_config.get("num_threads", 4)
    entity_types = entity_extraction_config.get("entity_types")

    graph_merge_operations_config = config.get(
        "graph_merge_operations",
        {
            "nodes": {
                "source_id": {
                    "operation": "concat",
                    "delimiter": ", ",
                    "distinct": True,
                },
                "description": ({
                    "operation": "concat",
                    "separator": "\n",
                    "distinct": False,
                }),
            },
            "edges": {
                "source_id": {
                    "operation": "concat",
                    "delimiter": ", ",
                    "distinct": True,
                },
                "description": ({
                    "operation": "concat",
                    "separator": "\n",
                    "distinct": False,
                }),
                "weight": "sum",
            },
        },
    )
    node_merge_config = graph_merge_operations_config.get("nodes")
    edge_merge_config = graph_merge_operations_config.get("edges")

    summarize_descriptions_config = config.get("summarize_descriptions", {})
    summarization_strategy = summarize_descriptions_config.get("strategy")
    summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)

    clustering_config = config.get(
        "cluster_graph",
        {"strategy": {"type": "leiden"}},
@@ -40,17 +89,29 @@ def build_steps(
    embed_graph_enabled = config.get("embed_graph_enabled", False) or False

    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
    raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

    return [
        {
            "verb": "create_base_entity_graph",
            "args": {
                "text_column": text_column,
                "id_column": id_column,
                "extraction_strategy": extraction_strategy,
                "extraction_num_threads": extraction_num_threads,
                "extraction_async_mode": async_mode,
                "entity_types": entity_types,
                "node_merge_config": node_merge_config,
                "edge_merge_config": edge_merge_config,
                "summarization_strategy": summarization_strategy,
                "summarization_num_threads": summarization_num_threads,
                "clustering_strategy": clustering_strategy,
                "graphml_snapshot_enabled": graphml_snapshot_enabled,
                "embedding_strategy": embedding_strategy
                if embed_graph_enabled
                else None,
                "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
                "graphml_snapshot_enabled": graphml_snapshot_enabled,
            },
            "input": ({"source": "workflow:create_summarized_entities"}),
            "input": ({"source": "workflow:create_base_text_units"}),
        },
    ]
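build_steps above only reads a handful of config keys. A hypothetical config value showing the shape it expects (keys mirror the config.get(...) lookups above; the values, including the entity_types list, are illustrative only):

config = {
    "entity_extract": {
        "text_column": "chunk",
        "id_column": "chunk_id",
        "strategy": None,        # None -> the default extraction strategy
        "num_threads": 4,
        "entity_types": ["organization", "person", "geo", "event"],
    },
    "summarize_descriptions": {"strategy": None, "num_threads": 4},
    "cluster_graph": {"strategy": {"type": "leiden"}},
    "embed_graph_enabled": False,
    "graphml_snapshot": True,
    "raw_entity_snapshot": False,
}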
@@ -1,84 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from datashaper import AsyncType

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_base_extracted_entities"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the base table for extracted entities.

    ## Dependencies
    * `workflow:create_base_text_units`
    """
    entity_extraction_config = config.get("entity_extract", {})

    column = entity_extraction_config.get("text_column", "chunk")
    id_column = entity_extraction_config.get("id_column", "chunk_id")
    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
    extraction_strategy = entity_extraction_config.get("strategy")
    num_threads = entity_extraction_config.get("num_threads", 4)
    entity_types = entity_extraction_config.get("entity_types")

    graph_merge_operations_config = config.get(
        "graph_merge_operations",
        {
            "nodes": {
                "source_id": {
                    "operation": "concat",
                    "delimiter": ", ",
                    "distinct": True,
                },
                "description": ({
                    "operation": "concat",
                    "separator": "\n",
                    "distinct": False,
                }),
            },
            "edges": {
                "source_id": {
                    "operation": "concat",
                    "delimiter": ", ",
                    "distinct": True,
                },
                "description": ({
                    "operation": "concat",
                    "separator": "\n",
                    "distinct": False,
                }),
                "weight": "sum",
            },
        },
    )
    nodes = graph_merge_operations_config.get("nodes")
    edges = graph_merge_operations_config.get("edges")

    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
    raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

    return [
        {
            "verb": "create_base_extracted_entities",
            "args": {
                "column": column,
                "id_column": id_column,
                "async_mode": async_mode,
                "extraction_strategy": extraction_strategy,
                "num_threads": num_threads,
                "entity_types": entity_types,
                "nodes": nodes,
                "edges": edges,
                "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
                "graphml_snapshot_enabled": graphml_snapshot_enabled,
            },
            "input": {"source": "workflow:create_base_text_units"},
        },
    ]
@@ -20,7 +20,6 @@ def build_steps(

    ## Dependencies
    * `workflow:create_base_text_units`
    * `workflow:create_base_extracted_entities`
    """
    claim_extract_config = config.get("claim_extract", {})
    extraction_strategy = claim_extract_config.get("strategy")
@@ -3,6 +3,8 @@

"""A module containing build_steps method definition."""

from datashaper import DEFAULT_INPUT_NAME

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_final_documents"
@@ -15,21 +17,26 @@ def build_steps(
    Create the final documents table.

    ## Dependencies
    * `workflow:create_base_documents`
    * `workflow:create_final_text_units`
    """
    base_text_embed = config.get("text_embed", {})
    document_raw_content_embed_config = config.get(
        "document_raw_content_embed", base_text_embed
    )
    skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
    document_attribute_columns = config.get("document_attribute_columns", [])
    return [
        {
            "verb": "create_final_documents",
            "args": {
                "document_attribute_columns": document_attribute_columns,
                "raw_content_text_embed": document_raw_content_embed_config
                if not skip_raw_content_embedding
                else None,
            },
            "input": {"source": "workflow:create_base_documents"},
            "input": {
                "source": DEFAULT_INPUT_NAME,
                "text_units": "workflow:create_final_text_units",
            },
        },
    ]
@@ -1,36 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_summarized_entities"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the base table for extracted entities.

    ## Dependencies
    * `workflow:create_base_text_units`
    """
    summarize_descriptions_config = config.get("summarize_descriptions", {})
    summarization_strategy = summarize_descriptions_config.get("strategy")
    num_threads = summarize_descriptions_config.get("num_threads", 4)

    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False

    return [
        {
            "verb": "create_summarized_entities",
            "args": {
                "summarization_strategy": summarization_strategy,
                "num_threads": num_threads,
                "graphml_snapshot_enabled": graphml_snapshot_enabled,
            },
            "input": {"source": "workflow:create_base_extracted_entities"},
        },
    ]
@@ -3,9 +3,7 @@

"""The Indexing Engine workflows -> subflows package root."""

from .create_base_documents import create_base_documents
from .create_base_entity_graph import create_base_entity_graph
from .create_base_extracted_entities import create_base_extracted_entities
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
@@ -17,12 +15,9 @@ from .create_final_relationships import (
    create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .create_summarized_entities import create_summarized_entities

__all__ = [
    "create_base_documents",
    "create_base_entity_graph",
    "create_base_extracted_entities",
    "create_base_text_units",
    "create_final_communities",
    "create_final_community_reports",
@@ -32,5 +27,4 @@ __all__ = [
    "create_final_nodes",
    "create_final_relationships",
    "create_final_text_units",
    "create_summarized_entities",
]
@@ -1,41 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform base documents."""

from typing import cast

import pandas as pd
from datashaper import (
    Table,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.flows.create_base_documents import (
    create_base_documents as create_base_documents_flow,
)
from graphrag.index.utils.ds_util import get_required_input_table


@verb(name="create_base_documents", treats_input_tables_as_immutable=True)
def create_base_documents(
    input: VerbInput,
    document_attribute_columns: list[str] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform base documents."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)

    output = create_base_documents_flow(
        source, text_units, document_attribute_columns=document_attribute_columns
    )

    return create_verb_result(
        cast(
            Table,
            output,
        )
    )
@@ -7,6 +7,7 @@ from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    VerbInput,
@@ -14,6 +15,7 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_base_entity_graph import (
    create_base_entity_graph as create_base_entity_graph_flow,
)
@@ -27,10 +29,22 @@ from graphrag.index.storage import PipelineStorage
async def create_base_entity_graph(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_column: str,
    id_column: str,
    clustering_strategy: dict[str, Any],
    embedding_strategy: dict[str, Any] | None,
    extraction_strategy: dict[str, Any] | None,
    extraction_num_threads: int = 4,
    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    node_merge_config: dict[str, Any] | None = None,
    edge_merge_config: dict[str, Any] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
    summarization_num_threads: int = 4,
    embedding_strategy: dict[str, Any] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to create the base entity graph."""
@@ -39,10 +53,22 @@ async def create_base_entity_graph(
    output = await create_base_entity_graph_flow(
        source,
        callbacks,
        cache,
        storage,
        clustering_strategy,
        embedding_strategy,
        text_column,
        id_column,
        clustering_strategy=clustering_strategy,
        extraction_strategy=extraction_strategy,
        extraction_num_threads=extraction_num_threads,
        extraction_async_mode=extraction_async_mode,
        entity_types=entity_types,
        node_merge_config=node_merge_config,
        edge_merge_config=edge_merge_config,
        summarization_strategy=summarization_strategy,
        summarization_num_threads=summarization_num_threads,
        embedding_strategy=embedding_strategy,
        graphml_snapshot_enabled=graphml_snapshot_enabled,
        raw_entity_snapshot_enabled=raw_entity_snapshot_enabled,
    )

    return create_verb_result(cast(Table, output))
@@ -1,63 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to extract and format base entities."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_base_extracted_entities import (
    create_base_extracted_entities as create_base_extracted_entities_flow,
)
from graphrag.index.storage import PipelineStorage


@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
async def create_base_extracted_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    column: str,
    id_column: str,
    nodes: dict[str, Any],
    edges: dict[str, Any],
    extraction_strategy: dict[str, Any] | None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    num_threads: int = 4,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to extract and format base entities."""
    source = cast(pd.DataFrame, input.get_input())

    output = await create_base_extracted_entities_flow(
        source,
        callbacks,
        cache,
        storage,
        column,
        id_column,
        nodes,
        edges,
        extraction_strategy,
        async_mode=async_mode,
        entity_types=entity_types,
        graphml_snapshot_enabled=graphml_snapshot_enabled,
        raw_entity_snapshot_enabled=raw_entity_snapshot_enabled,
        num_threads=num_threads,
    )

    return create_verb_result(cast(Table, output))
@@ -18,6 +18,7 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_documents import (
    create_final_documents as create_final_documents_flow,
)
from graphrag.index.utils.ds_util import get_required_input_table


@verb(
@@ -28,16 +29,20 @@ async def create_final_documents(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    document_attribute_columns: list[str] | None = None,
    raw_content_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final documents."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)

    output = await create_final_documents_flow(
        source,
        text_units,
        callbacks,
        cache,
        document_attribute_columns=document_attribute_columns,
        raw_content_text_embed=raw_content_text_embed,
    )
@@ -1,51 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to summarize entities."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_summarized_entities import (
    create_summarized_entities as create_summarized_entities_flow,
)
from graphrag.index.storage import PipelineStorage


@verb(
    name="create_summarized_entities",
    treats_input_tables_as_immutable=True,
)
async def create_summarized_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    summarization_strategy: dict[str, Any] | None = None,
    num_threads: int = 4,
    graphml_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to summarize entities."""
    source = cast(pd.DataFrame, input.get_input())

    output = await create_summarized_entities_flow(
        source,
        callbacks,
        cache,
        storage,
        summarization_strategy,
        num_threads=num_threads,
        graphml_snapshot_enabled=graphml_snapshot_enabled,
    )

    return create_verb_result(cast(Table, output))
tests/fixtures/min-csv/config.json (vendored, 26 lines changed)
@@ -10,29 +10,13 @@
        "subworkflows": 1,
        "max_runtime": 10
    },
    "create_base_extracted_entities": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_summarized_entities": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_base_entity_graph": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 10
        "max_runtime": 300
    },
    "create_final_entities": {
        "row_range": [
@@ -108,14 +92,6 @@
        "subworkflows": 1,
        "max_runtime": 100
    },
    "create_base_documents": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 10
    },
    "create_final_documents": {
        "row_range": [
            1,
tests/fixtures/text/config.json (vendored, 26 lines changed)
@@ -10,14 +10,6 @@
        "subworkflows": 1,
        "max_runtime": 10
    },
    "create_base_extracted_entities": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_final_covariates": {
        "row_range": [
            1,
@@ -35,21 +27,13 @@
        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_summarized_entities": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_base_entity_graph": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 10
        "max_runtime": 300
    },
    "create_final_entities": {
        "row_range": [
@@ -125,14 +109,6 @@
        "subworkflows": 1,
        "max_runtime": 100
    },
    "create_base_documents": {
        "row_range": [
            1,
            2000
        ],
        "subworkflows": 1,
        "max_runtime": 10
    },
    "create_final_documents": {
        "row_range": [
            1,
@ -19,9 +19,10 @@ workflows:
      # Just lump everything together
      chunk_by: []

  - name: create_base_extracted_entities
  - name: create_base_entity_graph
    config:
      graphml_snapshot: True
      embed_graph_enabled: True
      entity_extract:
        strategy:
          type: graph_intelligence
@ -37,9 +38,6 @@ workflows:
          ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
          ##
          ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))'

  - name: create_summarized_entities
    config:
      summarize_descriptions:
        strategy:
          type: graph_intelligence
@ -47,11 +45,6 @@ workflows:
            type: static_response
            responses:
              - This is a MOCK response for the LLM. It is summarized!

  - name: create_base_entity_graph
    config:
      graphml_snapshot: True
      embed_graph_enabled: True
      cluster_graph:
        strategy:
          type: leiden
@ -59,8 +52,6 @@ workflows:

  - name: create_final_nodes

  - name: create_base_documents

  - name: create_final_communities
  - name: create_final_text_units
    config:
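The mock LLM responses in this pipeline config use graphrag's tuple-delimited extraction format: records separated by ## and fields separated by <|>. A rough parser for records shaped like these examples (illustrative only, not the library's implementation):

# Illustrative parser for records like ("entity"<|>NAME<|>TYPE<|>DESCRIPTION).
def parse_records(text: str) -> list[tuple[str, ...]]:
    records = []
    for chunk in text.split("##"):
        chunk = chunk.strip().strip("()")
        if chunk:
            records.append(tuple(field.strip('"') for field in chunk.split("<|>")))
    return records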
@ -2,8 +2,6 @@
# Licensed under the MIT License
import unittest

import networkx as nx

from graphrag.index.operations.extract_entities.strategies.graph_intelligence import (
    run_extract_entities,
)
@ -119,8 +117,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):

        # self.assertItemsEqual isn't available yet, or I am just silly
        # so we sort the lists and compare them
        assert results.graphml_graph is not None, "No graphml graph returned!"
        graph = nx.parse_graphml(results.graphml_graph)  # type: ignore
        graph = results.graph
        assert graph is not None, "No graph returned!"

        # convert to strings for more visual comparison
        edges_str = sorted([f"{edge[0]} -> {edge[1]}" for edge in graph.edges])
@ -162,8 +160,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
            ),
        )

        assert results.graphml_graph is not None, "No graphml graph returned!"
        graph = nx.parse_graphml(results.graphml_graph)  # type: ignore
        graph = results.graph  # type: ignore
        assert graph is not None, "No graph returned!"

        # TODO: The edges might come back in any order, but we're assuming they're coming
        # back in the order that we passed in the docs, that might not be true
@ -173,9 +171,11 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
        assert (
            graph.nodes["TEST_ENTITY_2"].get("source_id") == "1"
        )  # TEST_ENTITY_2 should be in just 1
        assert sorted(
            graph.nodes["TEST_ENTITY_1"].get("source_id").split(",")
        ) == sorted(["1", "2"])  # TEST_ENTITY_1 should be 1 and 2
        ids_str = graph.nodes["TEST_ENTITY_1"].get("source_id") or ""
        assert sorted(ids_str.split(",")) == sorted([
            "1",
            "2",
        ])  # TEST_ENTITY_1 should be 1 and 2

    async def test_run_extract_entities_multiple_documents_correct_edge_source_ids_mapped(
        self,
@ -210,8 +210,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
            ),
        )

        assert results.graphml_graph is not None, "No graphml graph returned!"
        graph = nx.parse_graphml(results.graphml_graph)  # type: ignore
        graph = results.graph  # type: ignore
        assert graph is not None, "No graph returned!"
        edges = list(graph.edges(data=True))

        # should only have 2 edges

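These test changes track the commit's migration from serialized GraphML strings to in-memory nx.Graph objects, and the `or ""` guard covers the None edge case called out in the commit message. For reference, the old round trip the tests no longer need looked roughly like this:

import networkx as nx

# Old style: serialize the graph to GraphML text, then parse it back to assert.
graph = nx.Graph()
graph.add_edge("TEST_ENTITY_1", "TEST_ENTITY_2")
graphml_text = "\n".join(nx.generate_graphml(graph))
roundtripped = nx.parse_graphml(graphml_text)
assert roundtripped.number_of_edges() == graph.number_of_edges()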
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,58 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_base_documents import (
    build_steps,
    workflow_name,
)

from .util import (
    compare_outputs,
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
)


async def test_create_base_documents():
    input_tables = load_input_tables(["workflow:create_final_text_units"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    compare_outputs(actual, expected)


async def test_create_base_documents_with_attribute_columns():
    input_tables = load_input_tables(["workflow:create_final_text_units"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["document_attribute_columns"] = ["title"]

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    # we should have dropped "title" and added "attributes"
    # our test dataframe does not have attributes, so we'll assert without it
    # and separately confirm it is in the output
    compare_outputs(actual, expected, columns=["id", "text_units", "raw_content"])
    assert len(actual.columns) == 4
    assert "attributes" in actual.columns
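For orientation, compare_outputs in these deleted tests asserts column-wise equality between the workflow output and a stored expected table. An illustrative stand-in, assuming pandas semantics (not the actual util implementation):

import pandas as pd

# Illustrative stand-in for the compare_outputs test helper.
def compare_outputs(actual: pd.DataFrame, expected: pd.DataFrame, columns=None) -> None:
    cols = columns if columns is not None else list(expected.columns)
    for col in cols:
        pd.testing.assert_series_equal(
            actual[col].reset_index(drop=True),
            expected[col].reset_index(drop=True),
            check_dtype=False,
            check_names=False,
        )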
@ -2,7 +2,9 @@
# Licensed under the MIT License

import networkx as nx
import pytest

from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_base_entity_graph import (
    build_steps,
@ -16,16 +18,48 @@ from .util import (
    load_input_tables,
)

MOCK_LLM_ENTITY_RESPONSES = [
    """
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]

MOCK_LLM_ENTITY_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_ENTITY_RESPONSES,
}

MOCK_LLM_SUMMARIZATION_RESPONSES = [
    """
This is a MOCK response for the LLM. It is summarized!
""".strip()
]

MOCK_LLM_SUMMARIZATION_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_SUMMARIZATION_RESPONSES,
}


async def test_create_base_entity_graph():
    input_tables = load_input_tables([
        "workflow:create_summarized_entities",
        "workflow:create_base_text_units",
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)
    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG

    steps = build_steps(config)

@ -37,34 +71,36 @@ async def test_create_base_entity_graph():
        storage=storage,
    )

    # the serialization of the graph may differ so we can't assert the dataframes directly
    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    assert len(actual.columns) == len(
        expected.columns
    ), "Graph dataframe columns differ"
    # let's parse a sample of the raw graphml
    actual_graphml_0 = actual["clustered_graph"][:1][0]
    actual_graph_0 = nx.parse_graphml(actual_graphml_0)

    expected_graphml_0 = expected["clustered_graph"][:1][0]
    expected_graph_0 = nx.parse_graphml(expected_graphml_0)
    assert actual_graph_0.number_of_nodes() == 3
    assert actual_graph_0.number_of_edges() == 2

    assert (
        actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
    ), "Graphml node count differs"
    assert (
        actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
    ), "Graphml edge count differs"
    # TODO: with the combined verb we can't force summarization
    # this is because the mock responses always result in a single description, which is returned verbatim rather than summarized
    # we need to update the mocking to provide somewhat unique graphs so a true merge happens
    # the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below
    nodes = list(actual_graph_0.nodes(data=True))
    assert nodes[0][1]["description"] == "Company_A is a test company"

    assert len(storage.keys()) == 0, "Storage should be empty"


async def test_create_base_entity_graph_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_summarized_entities",
        "workflow:create_base_text_units",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
    config["embed_graph_enabled"] = True

    steps = build_steps(config)
@ -84,19 +120,22 @@ async def test_create_base_entity_graph_with_embeddings():

async def test_create_base_entity_graph_with_snapshots():
    input_tables = load_input_tables([
        "workflow:create_summarized_entities",
        "workflow:create_base_text_units",
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
    config["raw_entity_snapshot"] = True
    config["graphml_snapshot"] = True
    config["embed_graph_enabled"] = True  # need this on in order to see the snapshot

    steps = build_steps(config)

    actual = await get_workflow_output(
    await get_workflow_output(
        input_tables,
        {
            "steps": steps,
@ -104,15 +143,31 @@ async def test_create_base_entity_graph_with_snapshots():
        storage=storage,
    )

    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    assert storage.keys() == [
        "clustered_graph.0.graphml",
        "clustered_graph.1.graphml",
        "clustered_graph.2.graphml",
        "clustered_graph.3.graphml",
        "embedded_graph.0.graphml",
        "embedded_graph.1.graphml",
        "embedded_graph.2.graphml",
        "embedded_graph.3.graphml",
        "raw_extracted_entities.json",
        "merged_graph.graphml",
        "summarized_graph.graphml",
        "clustered_graph.graphml",
        "embedded_graph.graphml",
    ], "Graph snapshot keys differ"


async def test_create_base_entity_graph_missing_llm_throws():
    input_tables = load_input_tables([
        "workflow:create_base_text_units",
    ])

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    del config["summarize_descriptions"]["strategy"]["llm"]

    steps = build_steps(config)

    with pytest.raises(ValueError):  # noqa PT011
        await get_workflow_output(
            input_tables,
            {
                "steps": steps,
            },
        )

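The numbered snapshot keys asserted above (clustered_graph.0.graphml through .3, and likewise for the embedded graph) correspond to one GraphML file per hierarchical community level from Leiden clustering. A sketch of how such per-level snapshots might be written, assuming a storage object with an async set method (names illustrative):

import networkx as nx

# Illustrative: write one GraphML snapshot per community hierarchy level.
async def snapshot_graph_levels(storage, graphs_by_level: dict[int, nx.Graph]) -> None:
    for level, graph in graphs_by_level.items():
        graphml_text = "\n".join(nx.generate_graphml(graph))
        await storage.set(f"clustered_graph.{level}.graphml", graphml_text)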
@ -1,118 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import networkx as nx
import pytest
from datashaper.errors import VerbParallelizationError

from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_base_extracted_entities import (
    build_steps,
    workflow_name,
)

from .util import (
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
)

MOCK_LLM_RESPONSES = [
    """
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]

MOCK_LLM_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_RESPONSES,
}


async def test_create_base_extracted_entities():
    input_tables = load_input_tables(["workflow:create_base_text_units"])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        storage=storage,
    )

    # let's parse a sample of the raw graphml
    actual_graphml_0 = actual["entity_graph"][:1][0]
    actual_graph_0 = nx.parse_graphml(actual_graphml_0)

    assert actual_graph_0.number_of_nodes() == 3
    assert actual_graph_0.number_of_edges() == 2

    assert actual.columns == expected.columns

    assert len(storage.keys()) == 0, "Storage should be empty"


async def test_create_base_extracted_entities_with_snapshots():
    input_tables = load_input_tables(["workflow:create_base_text_units"])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
    config["raw_entity_snapshot"] = True
    config["graphml_snapshot"] = True

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        storage=storage,
    )

    print(storage.keys())

    assert actual.columns == expected.columns

    assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"]


async def test_create_base_extracted_entities_missing_llm_throws():
    input_tables = load_input_tables(["workflow:create_base_text_units"])

    config = get_config_for_workflow(workflow_name)

    del config["entity_extract"]["strategy"]["llm"]

    steps = build_steps(config)

    with pytest.raises(VerbParallelizationError):
        await get_workflow_output(
            input_tables,
            {
                "steps": steps,
            },
        )
@ -17,7 +17,7 @@ from .util import (

async def test_create_final_documents():
    input_tables = load_input_tables([
        "workflow:create_base_documents",
        "workflow:create_final_text_units",
    ])
    expected = load_expected(workflow_name)

@ -39,7 +39,7 @@ async def test_create_final_documents():

async def test_create_final_documents_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_documents",
        "workflow:create_final_text_units",
    ])
    expected = load_expected(workflow_name)

@ -63,3 +63,28 @@ async def test_create_final_documents_with_embeddings():
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["raw_content_embedding"][:1][0]) == 3


async def test_create_final_documents_with_attribute_columns():
    input_tables = load_input_tables(["workflow:create_final_text_units"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["document_attribute_columns"] = ["title"]

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    # we should have dropped "title" and added "attributes"
    # our test dataframe does not have attributes, so we'll assert without it
    # and separately confirm it is in the output
    compare_outputs(actual, expected, columns=["id", "text_unit_ids", "raw_content"])
    assert len(actual.columns) == 4
    assert "attributes" in actual.columns

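The new attribute-columns test exercises behavior carried over from the deleted base-documents workflow: configured input columns are folded into a single attributes field. A rough sketch of that fold (illustrative, not the flow's exact code):

import pandas as pd

# Illustrative: collapse configured columns into one JSON "attributes" column.
def fold_attribute_columns(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    df = df.copy()
    df["attributes"] = df[columns].apply(lambda row: row.to_json(), axis=1)
    return df.drop(columns=columns)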
@ -1,129 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import networkx as nx
import pytest

from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_summarized_entities import (
    build_steps,
    workflow_name,
)

from .util import (
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
)

MOCK_LLM_RESPONSES = [
    """
This is a MOCK response for the LLM. It is summarized!
""".strip()
]

MOCK_LLM_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_RESPONSES,
}


async def test_create_summarized_entities():
    input_tables = load_input_tables([
        "workflow:create_base_extracted_entities",
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)

    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        storage=storage,
    )

    # the serialization of the graph may differ so we can't assert the dataframes directly
    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    # let's parse a sample of the raw graphml
    actual_graphml_0 = actual["entity_graph"][:1][0]
    actual_graph_0 = nx.parse_graphml(actual_graphml_0)

    expected_graphml_0 = expected["entity_graph"][:1][0]
    expected_graph_0 = nx.parse_graphml(expected_graphml_0)

    assert (
        actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
    ), "Graphml node count differs"
    assert (
        actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
    ), "Graphml edge count differs"

    # ensure the mock summary was injected to the nodes
    nodes = list(actual_graph_0.nodes(data=True))
    assert (
        nodes[0][1]["description"]
        == "This is a MOCK response for the LLM. It is summarized!"
    )

    assert len(storage.keys()) == 0, "Storage should be empty"


async def test_create_summarized_entities_with_snapshots():
    input_tables = load_input_tables([
        "workflow:create_base_extracted_entities",
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()

    config = get_config_for_workflow(workflow_name)

    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG
    config["graphml_snapshot"] = True

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        storage=storage,
    )

    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    assert storage.keys() == [
        "summarized_graph.graphml",
    ], "Graph snapshot keys differ"


async def test_create_summarized_entities_missing_llm_throws():
    input_tables = load_input_tables([
        "workflow:create_base_extracted_entities",
    ])

    config = get_config_for_workflow(workflow_name)

    del config["summarize_descriptions"]["strategy"]["llm"]

    steps = build_steps(config)

    with pytest.raises(ValueError):  # noqa PT011
        await get_workflow_output(
            input_tables,
            {
                "steps": steps,
            },
        )
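All of these tests lean on LLMType.StaticResponse, which swaps the real model for canned text so assertions stay deterministic. The idea, as a minimal illustrative stub (not graphrag's actual mock implementation):

# Illustrative static-response stub standing in for an LLM client.
class StaticResponseLLM:
    def __init__(self, responses: list[str]):
        self._responses = responses
        self._call_count = 0

    async def __call__(self, prompt: str) -> str:
        # Cycle through the canned responses regardless of the prompt.
        response = self._responses[self._call_count % len(self._responses)]
        self._call_count += 1
        return response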
@ -26,7 +26,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
    # all workflows implicitly receive the `input` source, which is formatted as a dataframe after loading from storage
    # we'll simulate that by just loading one of our output parquets and converting back to equivalent dataframe
    # so we aren't dealing with storage vagaries (which would become an integration test)
    source = pd.read_parquet("tests/verbs/data/create_base_documents.parquet")
    source = pd.read_parquet("tests/verbs/data/create_final_documents.parquet")
    source.rename(columns={"raw_content": "text"}, inplace=True)
    input_tables["source"] = cast(pd.DataFrame, source[["id", "text", "title"]])