Mirror of https://github.com/microsoft/graphrag.git (synced 2025-11-09 14:24:07 +00:00)
Flow cleanup (#1510)
* Move snapshots out of flows into verbs
* Move degree compute out of extract_graph
* Move entity/relationship df merging into extract
* Move "title" to extraction source
* Move text_unit_ids agg closer to extraction
* Move data definition
* Update test data
* Semver
* Update smoke tests
* Fix empty degree field and update smoke tests and verb data
* Move extractors (#1516)
  * Consolidate graph embedding and umap
  * Consolidate claim extraction
  * Consolidate graph extractor
  * Move graph utils
  * Move summarizers
  * Semver
* Fix syntax typo

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in: parent d0543d1fd6, commit c1c09bab80
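Not part of the commit itself: the first bullet ("move snapshots out of flows into verbs") is the pattern that repeats through the diff below. Flows become plain dataframe transforms (often no longer async), and the workflow verb owns I/O such as transient snapshots. A minimal, self-contained sketch of that split, with hypothetical names rather than the repository's code:

import asyncio

import pandas as pd


def transform(df: pd.DataFrame) -> pd.DataFrame:
    """Stand-in for the dataframe work a flow does (hypothetical)."""
    return df.assign(processed=True)


async def snapshot(df: pd.DataFrame, name: str) -> None:
    """Stand-in for graphrag's snapshot operation (hypothetical)."""
    print(f"snapshot {name}: {len(df)} rows")


# Before: the flow persisted its own transient output, so it had to be async
# and carry snapshot flags and storage around.
async def flow_before(df: pd.DataFrame, snapshot_enabled: bool = False) -> pd.DataFrame:
    result = transform(df)
    if snapshot_enabled:
        await snapshot(result, "result")
    return result


# After: the flow is pure dataframe work; the workflow verb decides whether to snapshot.
def flow_after(df: pd.DataFrame) -> pd.DataFrame:
    return transform(df)


async def workflow_verb(df: pd.DataFrame, snapshot_enabled: bool = False) -> pd.DataFrame:
    result = flow_after(df)
    if snapshot_enabled:
        await snapshot(result, "result")
    return result


asyncio.run(workflow_verb(pd.DataFrame({"id": [1, 2]}), snapshot_enabled=True))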
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}
@@ -9,15 +9,11 @@ import pandas as pd
 
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
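Not part of the diff: a toy pandas illustration of the last two statements in the hunk above, showing how explode("title") flattens list-valued title columns into one row per (community, title) pair before the community id is cast back to int. The data here is made up for illustration.

import pandas as pd

communities = pd.DataFrame({
    "community": ["0", "1"],
    "level": [0, 0],
    "title": [["A", "B"], ["C"]],
})

base_communities = communities.explode("title")
base_communities["community"] = base_communities["community"].astype(int)
print(base_communities)
#    community  level title
# 0          0      0     A
# 0          0      0     B
# 1          1      0     C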
@@ -15,18 +15,14 @@ from datashaper import (
 )
 
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
 
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 
 
 # TODO: would be nice to inline this completely in the main method with pandas
@@ -10,6 +10,7 @@ from datashaper import (
     VerbCallbacks,
 )
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
-    )
 
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
+    degrees = compute_degree(graph)
 
-    return joined.loc[
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
+    )
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
            "id",
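Not part of the diff: a small self-contained pandas example of the merge chain introduced above, showing how the left joins plus fillna give disconnected or community-less nodes degree 0, community -1, and level 0 instead of NaN. All names and values here are toy data.

import pandas as pd

base_entity_nodes = pd.DataFrame({"title": ["A", "B", "C"]})
layout = pd.DataFrame({"label": ["A", "B", "C"], "x": [0.1, 0.2, 0.3], "y": [0.0, 0.5, 1.0]})
degrees = pd.DataFrame({"title": ["A", "B"], "degree": [2, 1]})  # "C" is disconnected
base_communities = pd.DataFrame({"title": ["A", "B"], "community": [0, 0], "level": [0, 0]})

nodes = (
    base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
    .merge(degrees, on="title", how="left")
    .merge(base_communities, on="title", how="left")
)
nodes["level"] = nodes["level"].fillna(0).astype(int)
nodes["community"] = nodes["community"].fillna(-1).astype(int)
nodes["degree"] = nodes["degree"].fillna(0).astype(int)
print(nodes[["title", "degree", "community", "level"]])
#   title  degree  community  level
# 0     A       2          0      0
# 1     B       1          0      0
# 2     C       0         -1      0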
@@ -5,20 +5,25 @@
 
 import pandas as pd
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 
 
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
+
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
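Not part of the diff: create_final_relationships now derives node degrees from a graph built from the edge list instead of reading them off base_entity_nodes. The toy snippet below recomputes equivalent inputs with plain pandas and networkx, and it assumes that an edge's combined degree is the sum of its endpoints' degrees; that definition is my assumption for illustration, not something stated in this hunk.

import networkx as nx
import pandas as pd

edges = pd.DataFrame({"source": ["A", "A", "B"], "target": ["B", "C", "C"]})
graph = nx.from_pandas_edgelist(edges, "source", "target")

# Same shape as the new compute_degree operation: one row per node with its degree.
degrees = pd.DataFrame(
    [{"title": node, "degree": int(degree)} for node, degree in graph.degree]
)

# Assumed definition: combined_degree = degree(source) + degree(target).
relationships = (
    edges.merge(degrees.rename(columns={"title": "source", "degree": "source_degree"}), on="source")
    .merge(degrees.rename(columns={"title": "target", "degree": "target_degree"}), on="target")
)
relationships["combined_degree"] = relationships["source_degree"] + relationships["target_degree"]
print(relationships[["source", "target", "combined_degree"]])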
@@ -6,7 +6,6 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
@@ -14,33 +13,26 @@ from datashaper import (
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
         .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
 
-    data = data.loc[:, ["id", "embedding"]]
-
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",
graphrag/index/operations/compute_degree.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing create_graph definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])
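Not part of the diff: a quick usage sketch of the new operation on a tiny graph. It assumes graphrag is installed; the output columns follow the file above.

import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph()
graph.add_edges_from([("A", "B"), ("A", "C")])
print(compute_degree(graph))
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1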
@@ -37,7 +37,7 @@
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.
 
@@ -138,7 +138,10 @@
         entity_dfs.append(pd.DataFrame(result[0]))
         relationship_dfs.append(pd.DataFrame(result[1]))
 
-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+
+    return (entities, relationships)
 
 
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+
+
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+
+
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )
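Not part of the diff: a toy demonstration of the named aggregation the new _merge_entities helper uses, showing how duplicate mentions of an entity collapse into one row with list-valued description and text_unit_ids columns. Data is invented for illustration.

import pandas as pd

entity_dfs = [
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["desc 1"], "source_id": ["tu-1"]}),
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["desc 2"], "source_id": ["tu-2"]}),
]

all_entities = pd.concat(entity_dfs, ignore_index=True)
entities = (
    all_entities.groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
print(entities)
#   title    type       description text_unit_ids
# 0     A  PERSON  [desc 1, desc 2]  [tu-1, tu-2]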
@@ -106,7 +106,7 @@ async def run_extract_entities(
     )
 
     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]
@@ -58,7 +58,7 @@ async def run(  # noqa RUF029 async is required for interface
 
     return EntityExtractionResult(
         entities=[
-            {"type": entity_type, "name": name}
+            {"type": entity_type, "title": name}
             for name, entity_type in entity_map.items()
         ],
         relationships=[],
@@ -88,7 +88,7 @@ async def summarize_descriptions(
 
     node_futures = [
         do_summarize_descriptions(
-            str(row[1]["name"]),
+            str(row[1]["title"]),
             sorted(set(row[1]["description"])),
             ticker,
             semaphore,
@@ -100,7 +100,7 @@ async def summarize_descriptions(
 
     node_descriptions = [
         {
-            "name": result.id,
+            "title": result.id,
             "description": result.description,
         }
         for result in node_results
@@ -14,6 +14,7 @@ from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "compute_communities"
@@ -62,13 +63,19 @@ async def workflow(
     """All the steps to create the base entity graph."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
 
-    base_communities = await compute_communities(
+    base_communities = compute_communities(
         base_relationship_edges,
-        storage,
         clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_communities", base_communities)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
@@ -19,6 +19,7 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units,
 )
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@ async def workflow(
     """All the steps to transform base text_units."""
     source = cast("pd.DataFrame", input.get_input())
 
-    output = await create_base_text_units(
+    output = create_base_text_units(
         source,
         callbacks,
-        storage,
         chunk_by_columns,
         chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_text_units", output)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            output,
+            name="create_base_text_units",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(
         cast(
             "Table",
@@ -53,8 +53,7 @@ async def workflow(
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
 
-    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+    output = create_final_relationships(base_relationship_edges)
 
     return create_verb_result(cast("Table", output))
@@ -19,6 +19,9 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.extract_graph import (
     extract_graph,
 )
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "extract_graph"
@@ -90,18 +93,38 @@ async def workflow(
         text_units,
         callbacks,
         cache,
-        storage,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extraction_num_threads,
         extraction_async_mode=extraction_async_mode,
         entity_types=entity_types,
         summarization_strategy=summarization_strategy,
         summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_entity_nodes", base_entity_nodes)
     await runtime_storage.set("base_relationship_edges", base_relationship_edges)
 
+    if snapshot_graphml_enabled:
+        # todo: extract graphs at each level, and add in meta like descriptions
+        graph = create_graph(base_relationship_edges)
+        await snapshot_graphml(
+            graph,
+            name="graph",
+            storage=storage,
+        )
+
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_entity_nodes,
+            name="base_entity_nodes",
+            storage=storage,
+            formats=["parquet"],
+        )
+        await snapshot(
+            base_relationship_edges,
+            name="base_relationship_edges",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
tests/fixtures/min-csv/config.json (vendored, 4 changes)
@@ -58,8 +58,8 @@
         ],
         "nan_allowed_columns": [
             "description",
-            "community",
-            "level"
+            "x",
+            "y"
         ],
         "subworkflows": 1,
         "max_runtime": 150,
tests/fixtures/text/config.json (vendored, 4 changes)
@@ -76,8 +76,8 @@
         ],
         "nan_allowed_columns": [
             "description",
-            "community",
-            "level"
+            "x",
+            "y"
         ],
         "subworkflows": 1,
         "max_runtime": 150,
@@ -42,7 +42,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
 
     async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
        ])
 
     async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):
Binary files not shown.
@@ -4,7 +4,6 @@
 from graphrag.index.flows.compute_communities import (
     compute_communities,
 )
-from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.compute_communities import (
     workflow_name,
 )
@@ -16,37 +15,15 @@ from .util import (
 )
 
 
-async def test_compute_communities():
+def test_compute_communities():
     edges = load_test_table("base_relationship_edges")
     expected = load_test_table("base_communities")
 
-    context = create_run_context(None, None, None)
     config = get_config_for_workflow(workflow_name)
     clustering_strategy = config["cluster_graph"]["strategy"]
 
-    actual = await compute_communities(
-        edges, storage=context.storage, clustering_strategy=clustering_strategy
-    )
+    actual = compute_communities(edges, clustering_strategy=clustering_strategy)
 
     columns = list(expected.columns.values)
     compare_outputs(actual, expected, columns)
     assert len(actual.columns) == len(expected.columns)
-
-
-async def test_compute_communities_with_snapshots():
-    edges = load_test_table("base_relationship_edges")
-
-    context = create_run_context(None, None, None)
-    config = get_config_for_workflow(workflow_name)
-    clustering_strategy = config["cluster_graph"]["strategy"]
-
-    await compute_communities(
-        edges,
-        storage=context.storage,
-        clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=True,
-    )
-
-    assert context.storage.keys() == [
-        "base_communities.parquet",
-    ], "Community snapshot keys differ"
@@ -16,10 +16,9 @@ from .util import (
 
 def test_create_final_relationships():
     edges = load_test_table("base_relationship_edges")
-    nodes = load_test_table("base_entity_nodes")
     expected = load_test_table(workflow_name)
 
-    actual = create_final_relationships(edges, nodes)
+    actual = create_final_relationships(edges)
 
     assert "id" in expected.columns
     columns = list(expected.columns.values)