diff --git a/.semversioner/next-release/patch-20241003185355991586.json b/.semversioner/next-release/patch-20241003185355991586.json
new file mode 100644
index 00000000..f45902b1
--- /dev/null
+++ b/.semversioner/next-release/patch-20241003185355991586.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Update create final nodes"
+}
diff --git a/graphrag/index/update/dataframes.py b/graphrag/index/update/dataframes.py
index 15833fd4..b5031ee5 100644
--- a/graphrag/index/update/dataframes.py
+++ b/graphrag/index/update/dataframes.py
@@ -5,6 +5,11 @@
 import os
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 import numpy as np
 import pandas as pd
@@ -120,15 +125,27 @@ async def update_dataframe_outputs(
     )
 
     delta_text_units = dataframe_dict["create_final_text_units"]
-    merged_text_units = _update_and_merge_text_units(
+    merged_text_units_df = _update_and_merge_text_units(
         old_text_units, delta_text_units, entity_id_mapping
     )
 
     # TODO: Using _new in the meantime, to compare outputs without overwriting the original
     await storage.set(
-        "create_final_text_units_new.parquet", merged_text_units.to_parquet()
+        "create_final_text_units_new.parquet", merged_text_units_df.to_parquet()
     )
 
+    # Update final nodes
+    old_nodes = await _load_table_from_storage("create_final_nodes.parquet", storage)
+    delta_nodes = dataframe_dict["create_final_nodes"]
+
+    merged_nodes = _merge_and_update_nodes(
+        old_nodes,
+        delta_nodes,
+        merged_relationships_df,
+    )
+
+    await storage.set("create_final_nodes_new.parquet", merged_nodes.to_parquet())
+
 
 async def _concat_dataframes(name, dataframe_dict, storage):
     """Concatenate the dataframes.
@@ -229,7 +246,8 @@ def _group_and_resolve_entities(
 
 
 def _update_and_merge_relationships(
-    old_relationships: pd.DataFrame, delta_relationships: pd.DataFrame
+    old_relationships: pd.DataFrame,
+    delta_relationships: pd.DataFrame,
 ) -> pd.DataFrame:
     """Update and merge relationships.
 
@@ -309,3 +327,169 @@ def _update_and_merge_text_units(
 
     # Merge the final text units
     return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False)
+
+
+def _merge_and_update_nodes(
+    old_nodes: pd.DataFrame,
+    delta_nodes: pd.DataFrame,
+    merged_relationships_df: pd.DataFrame,
+    community_count_threshold: int = 2,
+) -> pd.DataFrame:
+    """Merge and update nodes.
+
+    Parameters
+    ----------
+    old_nodes : pd.DataFrame
+        The old nodes.
+    delta_nodes : pd.DataFrame
+        The delta nodes.
+    merged_relationships_df : pd.DataFrame
+        The merged relationships.
+    community_count_threshold : int, default=2
+        The minimum number of times a community must appear among a new node's
+        existing neighbors before it is adopted.
+
+    Returns
+    -------
+    pd.DataFrame
+        The updated nodes.
+ """ + # Increment all community ids by the max of the old nodes + old_max_community_id = old_nodes["community"].fillna(0).astype(int).max() + + # Increment only the non-NaN values in delta_nodes["community"] + delta_nodes["community"] = delta_nodes["community"].where( + delta_nodes["community"].isna(), + delta_nodes["community"].fillna(0).astype(int) + old_max_community_id + 1, + ) + + # Set index for comparison + old_nodes_index = old_nodes.set_index(["level", "title"]).index + delta_nodes_index = delta_nodes.set_index(["level", "title"]).index + + # Get all delta nodes that are not in the old nodes + new_delta_nodes_df = delta_nodes[ + ~delta_nodes_index.isin(old_nodes_index) + ].reset_index(drop=True) + + # Get all delta nodes that are in the old nodes + existing_delta_nodes_df = delta_nodes[ + delta_nodes_index.isin(old_nodes_index) + ].reset_index(drop=True) + + # Cocnat the DataFrames + concat_nodes = pd.concat([old_nodes, existing_delta_nodes_df], ignore_index=True) + columns_to_agg: dict[str, str | Callable] = { + col: "first" + for col in concat_nodes.columns + if col not in ["description", "source_id", "level", "title"] + } + + # Specify custom aggregation for description and source_id + columns_to_agg.update({ + "description": lambda x: os.linesep.join(x.astype(str)), + "source_id": lambda x: ",".join(str(i) for i in x.tolist()), + }) + + old_nodes = ( + concat_nodes.groupby(["level", "title"]).agg(columns_to_agg).reset_index() + ) + + new_delta_nodes_df = _assign_communities( + new_delta_nodes_df, + merged_relationships_df, + old_nodes, + community_count_threshold, + ) + + # Concatenate the old nodes with the new delta nodes + merged_final_nodes = pd.concat( + [old_nodes, new_delta_nodes_df], ignore_index=True, copy=False + ) + + merged_final_nodes["community"] = ( + merged_final_nodes["community"].fillna("").astype(str) + ) + + # Merge both source and target degrees + merged_final_nodes = merged_final_nodes.merge( + merged_relationships_df[["source", "source_degree"]], + how="left", + left_on="title", + right_on="source", + ).merge( + merged_relationships_df[["target", "target_degree"]], + how="left", + left_on="title", + right_on="target", + ) + + # Assign 'source_degree' to 'size' and 'degree' + merged_final_nodes["size"] = merged_final_nodes["source_degree"] + + # Fill NaN values in 'size' and 'degree' with target_degree + merged_final_nodes["size"] = merged_final_nodes["size"].fillna( + merged_final_nodes["target_degree"] + ) + merged_final_nodes["degree"] = merged_final_nodes["size"] + + # Drop duplicates and the auxiliary 'source', 'target, 'source_degree' and 'target_degree' columns + return merged_final_nodes.drop( + columns=["source", "source_degree", "target", "target_degree"] + ).drop_duplicates() + + +def _assign_communities( + new_delta_nodes_df: pd.DataFrame, + merged_relationships_df: pd.DataFrame, + old_nodes: pd.DataFrame, + community_count_threshold: int = 2, +) -> pd.DataFrame: + # Find all relationships for the new delta nodes + node_relationships = merged_relationships_df[ + merged_relationships_df["source"].isin(new_delta_nodes_df["title"]) + | merged_relationships_df["target"].isin(new_delta_nodes_df["title"]) + ] + + # Find old nodes that are related to these relationships + related_communities = old_nodes[ + old_nodes["title"].isin(node_relationships["source"]) + | old_nodes["title"].isin(node_relationships["target"]) + ] + + # Merge with new_delta_nodes_df to get the level and community info + related_communities = related_communities.merge( + 
new_delta_nodes_df[["level", "title"]], on=["level", "title"], how="inner" + ) + + # Count the communities for each (level, title) pair + community_counts = ( + related_communities.groupby(["level", "title"])["community"] + .value_counts() + .reset_index(name="count") + ) + + # Filter by community threshold and select the most common community for each node + most_common_communities = community_counts[ + community_counts["count"] >= community_count_threshold + ] + most_common_communities = ( + most_common_communities.groupby(["level", "title"]).first().reset_index() + ) + + # Merge the most common community information back into new_delta_nodes_df + new_delta_nodes_df = new_delta_nodes_df.merge( + most_common_communities[["level", "title", "community"]], + on=["level", "title"], + how="left", + suffixes=("", "_new"), + ) + + # Update the community in new_delta_nodes_df if a common community was found + new_delta_nodes_df["community"] = new_delta_nodes_df["community_new"].combine_first( + new_delta_nodes_df["community"] + ) + + # Drop the auxiliary column used for merging + new_delta_nodes_df.drop(columns=["community_new"], inplace=True) + + return new_delta_nodes_df