Pyright fixes

This commit is contained in:
Alonso Guevara 2024-10-03 15:40:02 -06:00
parent 082db97614
commit d70457087e

View File

@ -211,21 +211,17 @@ def _group_and_resolve_entities(
# Group by name and resolve conflicts
aggregated = (
combined.groupby("name")
.agg(
{
"id": "first",
"type": "first",
"human_readable_id": "first",
"graph_embedding": "first",
"description": lambda x: os.linesep.join(x.astype(str)), # Ensure str
# Concatenate nd.array into a single list
"text_unit_ids": lambda x: ",".join(
str(i) for j in x.tolist() for i in j
),
# Keep only descriptions where the original value wasn't modified
"description_embedding": lambda x: x.iloc[0] if len(x) == 1 else np.nan,
}
)
.agg({
"id": "first",
"type": "first",
"human_readable_id": "first",
"graph_embedding": "first",
"description": lambda x: os.linesep.join(x.astype(str)), # Ensure str
# Concatenate nd.array into a single list
"text_unit_ids": lambda x: ",".join(str(i) for j in x.tolist() for i in j),
# Keep only descriptions where the original value wasn't modified
"description_embedding": lambda x: x.iloc[0] if len(x) == 1 else np.nan,
})
.reset_index()
)
@ -375,9 +371,9 @@ def _merge_and_update_nodes(
)
# Replace existing human_readable_id with the new one from merged_entities_df
delta_nodes["human_readable_id"] = delta_nodes[
"human_readable_id_new"
].combine_first(delta_nodes["human_readable_id"])
delta_nodes["human_readable_id"] = delta_nodes.loc[
:, "human_readable_id_new"
].combine_first(delta_nodes.loc[:, "human_readable_id"])
# Drop the auxiliary column from the merge
delta_nodes.drop(columns=["name", "human_readable_id_new"], inplace=True)
@ -411,12 +407,10 @@ def _merge_and_update_nodes(
}
# Specify custom aggregation for description and source_id
columns_to_agg.update(
{
"description": lambda x: os.linesep.join(x.astype(str)),
"source_id": lambda x: ",".join(str(i) for i in x.tolist()),
}
)
columns_to_agg.update({
"description": lambda x: os.linesep.join(x.astype(str)),
"source_id": lambda x: ",".join(str(i) for i in x.tolist()),
})
old_nodes = (
concat_nodes.groupby(["level", "title"]).agg(columns_to_agg).reset_index()