Update final nodes output

Alonso Guevara 2024-10-03 12:54:22 -06:00
parent 43ec92e173
commit 4f93aa675c
2 changed files with 191 additions and 3 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Update crete final nodes"
}

View File

@@ -5,6 +5,11 @@
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
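# Callable is imported only for type checking, so it adds no runtime cost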
if TYPE_CHECKING:
from collections.abc import Callable
import numpy as np
import pandas as pd
@@ -120,15 +125,27 @@ async def update_dataframe_outputs(
)
delta_text_units = dataframe_dict["create_final_text_units"]
merged_text_units = _update_and_merge_text_units(
merged_text_units_df = _update_and_merge_text_units(
old_text_units, delta_text_units, entity_id_mapping
)
# TODO: Using _new in the meantime, to compare outputs without overwriting the original
await storage.set(
"create_final_text_units_new.parquet", merged_text_units.to_parquet()
"create_final_text_units_new.parquet", merged_text_units_df.to_parquet()
)
# Update final nodes
old_nodes = await _load_table_from_storage("create_final_nodes.parquet", storage)
delta_nodes = dataframe_dict["create_final_nodes"]
merged_nodes = _merge_and_update_nodes(
old_nodes,
delta_nodes,
merged_relationships_df,
)
await storage.set("create_final_nodes_new.parquet", merged_nodes.to_parquet())
async def _concat_dataframes(name, dataframe_dict, storage):
"""Concatenate the dataframes.
@@ -229,7 +246,8 @@ def _group_and_resolve_entities(
def _update_and_merge_relationships(
old_relationships: pd.DataFrame, delta_relationships: pd.DataFrame
old_relationships: pd.DataFrame,
delta_relationships: pd.DataFrame,
) -> pd.DataFrame:
"""Update and merge relationships.
@@ -309,3 +327,169 @@ def _update_and_merge_text_units(
# Merge the final text units
return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False)
def _merge_and_update_nodes(
old_nodes: pd.DataFrame,
delta_nodes: pd.DataFrame,
merged_relationships_df: pd.DataFrame,
community_count_threshold: int = 2,
) -> pd.DataFrame:
"""Merge and update nodes.
Parameters
----------
old_nodes : pd.DataFrame
The old nodes.
delta_nodes : pd.DataFrame
The delta nodes.
merged_relationships_df : pd.DataFrame
The merged relationships.
community_count_threshold : int, default=2
The minimum number of times a neighboring community must appear before it is assigned to a new node.
Returns
-------
pd.DataFrame
The updated nodes.
"""
# Offset the delta community ids past the highest community id in the old nodes
old_max_community_id = old_nodes["community"].fillna(0).astype(int).max()
# Increment only the non-NaN values in delta_nodes["community"]
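# Series.where keeps entries where the condition is True (the NaN rows) and replaces the rest with the offset ids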
delta_nodes["community"] = delta_nodes["community"].where(
delta_nodes["community"].isna(),
delta_nodes["community"].fillna(0).astype(int) + old_max_community_id + 1,
)
# Set index for comparison
old_nodes_index = old_nodes.set_index(["level", "title"]).index
delta_nodes_index = delta_nodes.set_index(["level", "title"]).index
# Get all delta nodes that are not in the old nodes
new_delta_nodes_df = delta_nodes[
~delta_nodes_index.isin(old_nodes_index)
].reset_index(drop=True)
# Get all delta nodes that are in the old nodes
existing_delta_nodes_df = delta_nodes[
delta_nodes_index.isin(old_nodes_index)
].reset_index(drop=True)
# Concat the DataFrames
concat_nodes = pd.concat([old_nodes, existing_delta_nodes_df], ignore_index=True)
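# Aggregate rows that share (level, title), keeping the first value of every column except the ones handled below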
columns_to_agg: dict[str, str | Callable] = {
col: "first"
for col in concat_nodes.columns
if col not in ["description", "source_id", "level", "title"]
}
# Specify custom aggregation for description and source_id
columns_to_agg.update({
"description": lambda x: os.linesep.join(x.astype(str)),
"source_id": lambda x: ",".join(str(i) for i in x.tolist()),
})
old_nodes = (
concat_nodes.groupby(["level", "title"]).agg(columns_to_agg).reset_index()
)
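# Resolve a community for the brand-new delta nodes based on their old-node neighbors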
new_delta_nodes_df = _assign_communities(
new_delta_nodes_df,
merged_relationships_df,
old_nodes,
community_count_threshold,
)
# Concatenate the old nodes with the new delta nodes
merged_final_nodes = pd.concat(
[old_nodes, new_delta_nodes_df], ignore_index=True, copy=False
)
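# Normalize community to string, mapping missing values to empty strings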
merged_final_nodes["community"] = (
merged_final_nodes["community"].fillna("").astype(str)
)
# Merge both source and target degrees
merged_final_nodes = merged_final_nodes.merge(
merged_relationships_df[["source", "source_degree"]],
how="left",
left_on="title",
right_on="source",
).merge(
merged_relationships_df[["target", "target_degree"]],
how="left",
left_on="title",
right_on="target",
)
# Assign 'source_degree' to 'size' and 'degree'
merged_final_nodes["size"] = merged_final_nodes["source_degree"]
# Fill NaN values in 'size' and 'degree' with target_degree
merged_final_nodes["size"] = merged_final_nodes["size"].fillna(
merged_final_nodes["target_degree"]
)
merged_final_nodes["degree"] = merged_final_nodes["size"]
# Drop duplicates and the auxiliary 'source', 'target', 'source_degree' and 'target_degree' columns
return merged_final_nodes.drop(
columns=["source", "source_degree", "target", "target_degree"]
).drop_duplicates()
def _assign_communities(
new_delta_nodes_df: pd.DataFrame,
merged_relationships_df: pd.DataFrame,
old_nodes: pd.DataFrame,
community_count_threshold: int = 2,
) -> pd.DataFrame:
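"""Assign the most common neighboring old-node community to each new delta node."""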
# Find all relationships for the new delta nodes
node_relationships = merged_relationships_df[
merged_relationships_df["source"].isin(new_delta_nodes_df["title"])
| merged_relationships_df["target"].isin(new_delta_nodes_df["title"])
]
# Find old nodes that are related to these relationships
related_communities = old_nodes[
old_nodes["title"].isin(node_relationships["source"])
| old_nodes["title"].isin(node_relationships["target"])
]
# Merge with new_delta_nodes_df to get the level and community info
related_communities = related_communities.merge(
new_delta_nodes_df[["level", "title"]], on=["level", "title"], how="inner"
)
# Count the communities for each (level, title) pair
community_counts = (
related_communities.groupby(["level", "title"])["community"]
.value_counts()
.reset_index(name="count")
)
# Filter by community threshold and select the most common community for each node
most_common_communities = community_counts[
community_counts["count"] >= community_count_threshold
]
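# value_counts returns counts in descending order, so first() picks the most frequent community per node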
most_common_communities = (
most_common_communities.groupby(["level", "title"]).first().reset_index()
)
# Merge the most common community information back into new_delta_nodes_df
new_delta_nodes_df = new_delta_nodes_df.merge(
most_common_communities[["level", "title", "community"]],
on=["level", "title"],
how="left",
suffixes=("", "_new"),
)
# Update the community in new_delta_nodes_df if a common community was found
new_delta_nodes_df["community"] = new_delta_nodes_df["community_new"].combine_first(
new_delta_nodes_df["community"]
)
# Drop the auxiliary column used for merging
new_delta_nodes_df.drop(columns=["community_new"], inplace=True)
return new_delta_nodes_df