Mirror of https://github.com/microsoft/graphrag.git
Synced 2025-11-03 19:30:10 +00:00
Flow cleanup (#1510)

* Move snapshots out of flows into verbs
* Move degree compute out of extract_graph
* Move entity/relationship df merging into extract
* Move "title" to extraction source
* Move text_unit_ids agg closer to extraction
* Move data definition
* Update test data
* Semver
* Update smoke tests
* Fix empty degree field and update smoke tests and verb data
* Move extractors (#1516)
  * Consolidate graph embedding and umap
  * Consolidate claim extraction
  * Consolidate graph extractor
  * Move graph utils
  * Move summarizers
  * Semver
* Fix syntax typo

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Parent: d0543d1fd6
Commit: c1c09bab80
@@ -0,0 +1,4 @@
+{
+  "type": "patch",
+  "description": "Streamline flows."
+}
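The unifying idea across these diffs: flows become pure, synchronous DataFrame transforms, and the verb wrappers own the storage side effects. A minimal sketch of the target shape, with hypothetical names (my_flow, my_verb) standing in for the real ones; snapshot and PipelineStorage are the graphrag helpers visible in the hunks below:

import pandas as pd

from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


def my_flow(edges: pd.DataFrame) -> pd.DataFrame:
    """A pure transform: no storage handle, no async, no snapshot flags."""
    return edges.copy()


async def my_verb(
    edges: pd.DataFrame,
    storage: PipelineStorage,
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    """The verb calls the flow, then performs any I/O itself."""
    output = my_flow(edges)
    if snapshot_transient_enabled:
        await snapshot(output, name="my_flow", storage=storage, formats=["parquet"])
    return output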
@@ -9,15 +9,11 @@ import pandas as pd

 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage


-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)

-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
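For reference, the explode("title")/astype(int) tail of compute_communities unpacks one row per community (holding a list of member titles) into one row per member, and pins the community id back to int after clustering. A toy run with made-up data:

import pandas as pd

communities = pd.DataFrame({
    "level": [0, 0],
    "community": [0.0, 1.0],
    "title": [["A", "B"], ["C"]],
})
base_communities = communities.explode("title")
# clustering output can arrive as float; the flow pins it back to int
base_communities["community"] = base_communities["community"].astype(int)
# three rows: (0, 0, "A"), (0, 0, "B"), (0, 1, "C")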
@@ -15,18 +15,14 @@ from datashaper import (
 )

 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage


-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)

-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))


 # TODO: would be nice to inline this completely in the main method with pandas
@@ -10,6 +10,7 @@ from datashaper import (
     VerbCallbacks,
 )

+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
+
+    degrees = compute_degree(graph)
+
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
     )

-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
-
-    return joined.loc[
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",
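The fillna(...).astype(int) lines exist because each left merge leaves NaN where the right-hand frame has no row (for example, a disconnected node gets no degree row), and a single NaN upcasts an int column to float. A toy illustration of the pattern with made-up frames:

import pandas as pd

nodes = pd.DataFrame({"title": ["A", "B"]})
degrees = pd.DataFrame({"title": ["A"], "degree": [3]})

merged = nodes.merge(degrees, on="title", how="left")
# merged["degree"] is now [3.0, NaN], dtype float64
merged["degree"] = merged["degree"].fillna(0).astype(int)
# back to [3, 0], dtype int64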
@@ -5,20 +5,25 @@

 import pandas as pd

+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph


 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
+
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
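create_final_relationships now derives node degrees itself instead of taking base_entity_nodes. Assuming combined_degree is the sum of the two endpoint degrees, which is what the node_degree_column/edge_source_column parameter names suggest (the real logic lives in compute_edge_combined_degree), an equivalent computation in plain pandas would look like:

import pandas as pd

edges = pd.DataFrame({"source": ["A", "A"], "target": ["B", "C"]})
degrees = pd.DataFrame({"title": ["A", "B", "C"], "degree": [2, 1, 1]})

# attach each endpoint's degree, then add them
combined = edges.merge(
    degrees.rename(columns={"title": "source", "degree": "source_degree"}), on="source"
).merge(
    degrees.rename(columns={"title": "target", "degree": "target_degree"}), on="target"
)
combined["combined_degree"] = combined["source_degree"] + combined["target_degree"]
# A-B -> 3, A-C -> 3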
@@ -6,7 +6,6 @@
 from typing import Any
 from uuid import uuid4

-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
@@ -14,33 +13,26 @@ from datashaper import (
 )

 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage


 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )

-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)

-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)

-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )

-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)

-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)

     return (base_entity_nodes, base_relationship_edges)


-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
         .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges


-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
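One detail worth pulling out of the hunk above: _prep_edges keeps the id scheme where the deduplicated row index doubles as human_readable_id and a fresh UUID becomes the durable id. A standalone rendering of that step on made-up edges:

from uuid import uuid4

import pandas as pd

edges = pd.DataFrame({"source": ["A", "A", "A"], "target": ["B", "C", "B"]})
edges = edges.drop_duplicates(subset=["source", "target"]).reset_index(drop=True)
edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
# two rows (A-B, A-C), human_readable_id 0 and 1, each with a unique UUID string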
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
             strategy=text_embed_config["strategy"],
         )

-        data = data.loc[:, ["id", "embedding"]]
-
         if snapshot_embeddings_enabled is True:
+            data = data.loc[:, ["id", "embedding"]]
             await snapshot(
                 data,
                 name=f"embeddings.{name}",
							
								
								
									
graphrag/index/operations/compute_degree.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing compute_degree definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])
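A quick sanity check of the new operation on a small graph (networkx's Graph.degree iterates (node, degree) pairs, which is what the comprehension consumes):

import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph()
graph.add_edges_from([("A", "B"), ("A", "C")])

print(compute_degree(graph))
#   title  degree
# 0     A       2
# 1     B       1
# 2     C       1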
@@ -37,7 +37,7 @@ async def extract_entities(
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.

@@ -138,7 +138,10 @@ async def extract_entities(
             entity_dfs.append(pd.DataFrame(result[0]))
             relationship_dfs.append(pd.DataFrame(result[1]))

-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+
+    return (entities, relationships)


 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+
+
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+
+
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )
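The relocated _merge_* helpers also switch from dict-style .agg({...}) to pandas named aggregation, which renames source_id to text_unit_ids in the same step. A toy run of the entity variant with made-up extraction output:

import pandas as pd

entity_dfs = [
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["d1"], "source_id": ["tu1"]}),
    pd.DataFrame({"title": ["A"], "type": ["PERSON"], "description": ["d2"], "source_id": ["tu2"]}),
]

merged = (
    pd.concat(entity_dfs, ignore_index=True)
    .groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
# one row: title "A", description ["d1", "d2"], text_unit_ids ["tu1", "tu2"]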
@@ -106,7 +106,7 @@ async def run_extract_entities(
             )

     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]
@@ -58,7 +58,7 @@ async def run(  # noqa RUF029 async is required for interface

     return EntityExtractionResult(
         entities=[
-            {"type": entity_type, "name": name}
+            {"type": entity_type, "title": name}
             for name, entity_type in entity_map.items()
         ],
         relationships=[],
@@ -88,7 +88,7 @@ async def summarize_descriptions(

         node_futures = [
             do_summarize_descriptions(
-                str(row[1]["name"]),
+                str(row[1]["title"]),
                 sorted(set(row[1]["description"])),
                 ticker,
                 semaphore,
@@ -100,7 +100,7 @@ async def summarize_descriptions(

         node_descriptions = [
             {
-                "name": result.id,
+                "title": result.id,
                 "description": result.description,
             }
             for result in node_results
@@ -14,6 +14,7 @@ from datashaper.table_store.types import VerbResult, create_verb_result

 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage

 workflow_name = "compute_communities"
@@ -62,13 +63,19 @@ async def workflow(
     """All the steps to create the base entity graph."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")

-    base_communities = await compute_communities(
+    base_communities = compute_communities(
         base_relationship_edges,
-        storage,
         clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )

     await runtime_storage.set("base_communities", base_communities)

+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
@@ -19,6 +19,7 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units,
 )
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage

 workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@ async def workflow(
     """All the steps to transform base text_units."""
     source = cast("pd.DataFrame", input.get_input())

-    output = await create_base_text_units(
+    output = create_base_text_units(
         source,
         callbacks,
-        storage,
         chunk_by_columns,
         chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )

     await runtime_storage.set("base_text_units", output)

+    if snapshot_transient_enabled:
+        await snapshot(
+            output,
+            name="create_base_text_units",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(
         cast(
             "Table",
@@ -53,8 +53,7 @@ async def workflow(
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")

-    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+    output = create_final_relationships(base_relationship_edges)

     return create_verb_result(cast("Table", output))
@@ -19,6 +19,9 @@ from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.extract_graph import (
     extract_graph,
 )
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.storage.pipeline_storage import PipelineStorage

 workflow_name = "extract_graph"
@@ -90,18 +93,38 @@ async def workflow(
         text_units,
         callbacks,
         cache,
-        storage,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extraction_num_threads,
         extraction_async_mode=extraction_async_mode,
         entity_types=entity_types,
         summarization_strategy=summarization_strategy,
         summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )

     await runtime_storage.set("base_entity_nodes", base_entity_nodes)
     await runtime_storage.set("base_relationship_edges", base_relationship_edges)

+    if snapshot_graphml_enabled:
+        # todo: extract graphs at each level, and add in meta like descriptions
+        graph = create_graph(base_relationship_edges)
+        await snapshot_graphml(
+            graph,
+            name="graph",
+            storage=storage,
+        )
+
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_entity_nodes,
+            name="base_entity_nodes",
+            storage=storage,
+            formats=["parquet"],
+        )
+        await snapshot(
+            base_relationship_edges,
+            name="base_relationship_edges",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
							
								
								
									
tests/fixtures/min-csv/config.json (vendored, 4 lines changed)
@@ -58,8 +58,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
							
								
								
									
tests/fixtures/text/config.json (vendored, 4 lines changed)
@@ -76,8 +76,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
@@ -42,7 +42,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])

     async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])

     async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):
										
12 changed binary files not shown.
										
									
								
@@ -4,7 +4,6 @@
 from graphrag.index.flows.compute_communities import (
     compute_communities,
 )
-from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.compute_communities import (
     workflow_name,
 )
@@ -16,37 +15,15 @@ from .util import (
 )


-async def test_compute_communities():
+def test_compute_communities():
     edges = load_test_table("base_relationship_edges")
     expected = load_test_table("base_communities")

-    context = create_run_context(None, None, None)
     config = get_config_for_workflow(workflow_name)
     clustering_strategy = config["cluster_graph"]["strategy"]

-    actual = await compute_communities(
-        edges, storage=context.storage, clustering_strategy=clustering_strategy
-    )
+    actual = compute_communities(edges, clustering_strategy=clustering_strategy)

     columns = list(expected.columns.values)
     compare_outputs(actual, expected, columns)
     assert len(actual.columns) == len(expected.columns)
-
-
-async def test_compute_communities_with_snapshots():
-    edges = load_test_table("base_relationship_edges")
-
-    context = create_run_context(None, None, None)
-    config = get_config_for_workflow(workflow_name)
-    clustering_strategy = config["cluster_graph"]["strategy"]
-
-    await compute_communities(
-        edges,
-        storage=context.storage,
-        clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=True,
-    )
-
-    assert context.storage.keys() == [
-        "base_communities.parquet",
-    ], "Community snapshot keys differ"
@@ -16,10 +16,9 @@ from .util import (

 def test_create_final_relationships():
     edges = load_test_table("base_relationship_edges")
-    nodes = load_test_table("base_entity_nodes")
     expected = load_test_table(workflow_name)

-    actual = create_final_relationships(edges, nodes)
+    actual = create_final_relationships(edges)

     assert "id" in expected.columns
     columns = list(expected.columns.values)