Incremental model alignment (#1766)

* Used shared schema lists for all final columns

* Semver
Nathan Evans 2025-02-25 11:14:42 -08:00 committed by GitHub
parent 0144b3fd88
commit 61a309b182
17 changed files with 153 additions and 161 deletions


@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Use shared schema for final outputs."
}
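This is a semversioner change file; assuming the repository uses the standard semversioner CLI, an entry like this is typically created with 'semversioner add-change --type patch --description "Use shared schema for final outputs."' and folded into the changelog at release time.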


@@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -96,7 +96,7 @@
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
")\n",
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
"# we're also persistint the frequency column\n",
"# we're also persisting the frequency column\n",
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].count()\n",
"\n",
"\n",


@@ -1,18 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Common field name definitions for community reports."""
"""Common field name definitions for data frames."""
ID = "id"
SHORT_ID = "human_readable_id"
TITLE = "title"
DESCRIPTION = "description"
TYPE = "type"
# POST-PREP NODE TABLE SCHEMA
NODE_DEGREE = "degree"
NODE_FREQUENCY = "frequency"
NODE_DETAILS = "node_details"
NODE_PARENT_COMMUNITY = "parent_community"
NODE_X = "x"
NODE_Y = "y"
# POST-PREP EDGE TABLE SCHEMA
EDGE_SOURCE = "source"
@@ -23,13 +25,11 @@ EDGE_WEIGHT = "weight"
# POST-PREP CLAIM TABLE SCHEMA
CLAIM_SUBJECT = "subject_id"
CLAIM_TYPE = "type"
CLAIM_STATUS = "status"
CLAIM_DETAILS = "claim_details"
# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_community"
COMMUNITY_LEVEL = "level"
# COMMUNITY CONTEXT TABLE SCHEMA
ALL_CONTEXT = "all_context"
@@ -40,6 +40,8 @@ CONTEXT_EXCEED_FLAG = "context_exceed_limit"
# COMMUNITY REPORT TABLE SCHEMA
COMMUNITY_ID = "community"
COMMUNITY_LEVEL = "level"
COMMUNITY_PARENT = "parent"
COMMUNITY_CHILDREN = "children"
TITLE = "title"
SUMMARY = "summary"
FINDINGS = "findings"
@@ -48,9 +50,114 @@ EXPLANATION = "rating_explanation"
FULL_CONTENT = "full_content"
FULL_CONTENT_JSON = "full_content_json"
ENTITY_IDS = "entity_ids"
RELATIONSHIP_IDS = "relationship_ids"
TEXT_UNIT_IDS = "text_unit_ids"
COVARIATE_IDS = "covariate_ids"
DOCUMENT_IDS = "document_ids"
PERIOD = "period"
SIZE = "size"
# text units
ENTITY_DEGREE = "entity_degree"
ALL_DETAILS = "all_details"
TEXT = "text"
N_TOKENS = "n_tokens"
CREATION_DATE = "creation_date"
METADATA = "metadata"
# the following lists define the final content and ordering of columns in the data model parquet outputs
ENTITIES_FINAL_COLUMNS = [
ID,
SHORT_ID,
TITLE,
TYPE,
DESCRIPTION,
TEXT_UNIT_IDS,
NODE_FREQUENCY,
NODE_DEGREE,
NODE_X,
NODE_Y,
]
RELATIONSHIPS_FINAL_COLUMNS = [
ID,
SHORT_ID,
EDGE_SOURCE,
EDGE_TARGET,
DESCRIPTION,
EDGE_WEIGHT,
EDGE_DEGREE,
TEXT_UNIT_IDS,
]
COMMUNITIES_FINAL_COLUMNS = [
ID,
SHORT_ID,
COMMUNITY_ID,
COMMUNITY_LEVEL,
COMMUNITY_PARENT,
COMMUNITY_CHILDREN,
TITLE,
ENTITY_IDS,
RELATIONSHIP_IDS,
TEXT_UNIT_IDS,
PERIOD,
SIZE,
]
COMMUNITY_REPORTS_FINAL_COLUMNS = [
ID,
SHORT_ID,
COMMUNITY_ID,
COMMUNITY_LEVEL,
COMMUNITY_PARENT,
COMMUNITY_CHILDREN,
TITLE,
SUMMARY,
FULL_CONTENT,
RATING,
EXPLANATION,
FINDINGS,
FULL_CONTENT_JSON,
PERIOD,
SIZE,
]
COVARIATES_FINAL_COLUMNS = [
ID,
SHORT_ID,
"covariate_type",
TYPE,
DESCRIPTION,
"subject_id",
"object_id",
"status",
"start_date",
"end_date",
"source_text",
"text_unit_id",
]
TEXT_UNITS_FINAL_COLUMNS = [
ID,
SHORT_ID,
TEXT,
N_TOKENS,
DOCUMENT_IDS,
ENTITY_IDS,
RELATIONSHIP_IDS,
COVARIATE_IDS,
]
DOCUMENTS_FINAL_COLUMNS = [
ID,
SHORT_ID,
TITLE,
TEXT,
TEXT_UNIT_IDS,
CREATION_DATE,
METADATA,
]
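Taken together, these lists centralize both the membership and the ordering of the columns in each parquet output. A minimal sketch of the selection pattern the rest of the commit converges on (finalize here is a hypothetical stand-in for the various finalize_* and workflow functions):

    import pandas as pd

    from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS

    def finalize(entities: pd.DataFrame) -> pd.DataFrame:
        # .loc with a column list selects exactly those columns, in that order;
        # it raises KeyError if any listed column is missing from the frame
        return entities.loc[:, ENTITIES_FINAL_COLUMNS]

Because the selection is strict, every upstream step has to guarantee that the listed columns exist; that is why the documents and text-units workflows below pad optional columns instead of trimming the list.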


@@ -7,6 +7,8 @@ from uuid import uuid4
import pandas as pd
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
def finalize_community_reports(
reports: pd.DataFrame,
@@ -27,21 +29,5 @@ def finalize_community_reports(
return community_reports.loc[
:,
[
"id",
"human_readable_id",
"community",
"level",
"parent",
"children",
"title",
"summary",
"full_content",
"rank",
"rank_explanation",
"findings",
"full_content_json",
"period",
"size",
],
COMMUNITY_REPORTS_FINAL_COLUMNS,
]


@@ -9,6 +9,7 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
@@ -52,16 +53,5 @@ def finalize_entities(
)
return final_entities.loc[
:,
[
"id",
"human_readable_id",
"title",
"type",
"description",
"text_unit_ids",
"frequency",
"degree",
"x",
"y",
],
ENTITIES_FINAL_COLUMNS,
]


@@ -7,6 +7,7 @@ from uuid import uuid4
import pandas as pd
from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
@@ -39,14 +40,5 @@ def finalize_relationships(
return final_relationships.loc[
:,
[
"id",
"human_readable_id",
"source",
"target",
"description",
"weight",
"combined_degree",
"text_unit_ids",
],
RELATIONSHIPS_FINAL_COLUMNS,
]


@@ -78,7 +78,7 @@ async def _run_extractor(
level=level,
rank=report.rating,
title=report.title,
rank_explanation=report.rating_explanation,
rating_explanation=report.rating_explanation,
summary=report.summary,
findings=[
Finding(explanation=f.explanation, summary=f.summary)


@@ -36,7 +36,7 @@ class CommunityReport(TypedDict):
full_content_json: str
rank: float
level: int
rank_explanation: str
rating_explanation: str
findings: list[Finding]
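The rename lines the TypedDict field up with the schema constant EXPLANATION = "rating_explanation", so report rows can flow into COMMUNITY_REPORTS_FINAL_COLUMNS without a rename step. A trimmed sketch of the shape (a hypothetical subset; the real class carries more fields):

    from typing import TypedDict

    class CommunityReport(TypedDict):
        # trimmed, hypothetical subset of the real class
        title: str
        level: int
        rank: float
        rating_explanation: str  # previously rank_explanation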


@@ -5,6 +5,11 @@
import pandas as pd
from graphrag.data_model.schemas import (
COMMUNITIES_FINAL_COLUMNS,
COMMUNITY_REPORTS_FINAL_COLUMNS,
)
def _update_and_merge_communities(
old_communities: pd.DataFrame,
@@ -76,19 +81,7 @@ def _update_and_merge_communities(
merged_communities = merged_communities.loc[
:,
[
"id",
"human_readable_id",
"community",
"parent",
"level",
"title",
"entity_ids",
"relationship_ids",
"text_unit_ids",
"period",
"size",
],
COMMUNITIES_FINAL_COLUMNS,
]
return merged_communities, community_id_mapping
@@ -155,22 +148,4 @@ def _update_and_merge_community_reports(
"community"
]
return merged_community_reports.loc[
:,
[
"id",
"human_readable_id",
"community",
"parent",
"level",
"title",
"summary",
"full_content",
"rank",
"rank_explanation",
"findings",
"full_content_json",
"period",
"size",
],
]
return merged_community_reports.loc[:, COMMUNITY_REPORTS_FINAL_COLUMNS]


@@ -12,6 +12,7 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
run_graph_intelligence as run_entity_summarization,
)
@@ -79,21 +80,7 @@ def _group_and_resolve_entities(
resolved: pd.DataFrame = pd.DataFrame(aggregated)
# Modify column order to keep consistency
resolved = resolved.loc[
:,
[
"id",
"human_readable_id",
"title",
"type",
"description",
"text_unit_ids",
"frequency",
"degree",
"x",
"y",
],
]
resolved = resolved.loc[:, ENTITIES_FINAL_COLUMNS]
return resolved, id_mapping


@@ -6,6 +6,8 @@
import numpy as np
import pandas as pd
from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
def _update_and_merge_relationships(
old_relationships: pd.DataFrame, delta_relationships: pd.DataFrame
@@ -59,14 +61,5 @@ def _update_and_merge_relationships(
return final_relationships.loc[
:,
[
"id",
"human_readable_id",
"source",
"target",
"description",
"weight",
"combined_degree",
"text_unit_ids",
],
RELATIONSHIPS_FINAL_COLUMNS,
]


@@ -12,6 +12,7 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
@@ -146,18 +147,5 @@ def create_communities(
return final_communities.loc[
:,
[
"id",
"human_readable_id",
"community",
"level",
"parent",
"children",
"title",
"entity_ids",
"relationship_ids",
"text_unit_ids",
"period",
"size",
],
COMMUNITIES_FINAL_COLUMNS,
]


@@ -175,7 +175,7 @@ def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
[
schemas.SHORT_ID,
schemas.CLAIM_SUBJECT,
schemas.CLAIM_TYPE,
schemas.TYPE,
schemas.CLAIM_STATUS,
schemas.DESCRIPTION,
],
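The switch from schemas.CLAIM_TYPE to schemas.TYPE is safe because both constants were defined as the same string, "type"; call sites now go through the shared name. Illustratively, assuming graphrag is importable:

    from graphrag.data_model import schemas

    # the claim-prep selection reuses the shared TYPE constant instead of
    # a claim-specific alias for the same "type" column
    claim_columns = [
        schemas.SHORT_ID,
        schemas.CLAIM_SUBJECT,
        schemas.TYPE,
        schemas.CLAIM_STATUS,
        schemas.DESCRIPTION,
    ]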


@@ -7,6 +7,7 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
@@ -66,17 +67,7 @@ def create_final_documents(
rejoined["id"] = rejoined["id"].astype(str)
rejoined["human_readable_id"] = rejoined.index + 1
# set the final column order, but adjust for metadata
core_columns = [
"id",
"human_readable_id",
"title",
"text",
"text_unit_ids",
"creation_date",
]
final_columns = [column for column in core_columns if column in rejoined.columns]
if "metadata" in rejoined.columns:
final_columns.append("metadata")
if "metadata" not in rejoined.columns:
rejoined["metadata"] = pd.Series(dtype="object")
return rejoined.loc[:, final_columns]
return rejoined.loc[:, DOCUMENTS_FINAL_COLUMNS]
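Note the inversion here: rather than trimming the column list when metadata is absent, the frame is padded so that the shared DOCUMENTS_FINAL_COLUMNS selection always succeeds. The padding idiom as a self-contained sketch:

    import pandas as pd

    docs = pd.DataFrame({"id": ["d1"], "title": ["report.txt"]})

    # guarantee the optional column so the strict .loc selection cannot fail;
    # the index-aligned empty Series yields an all-NaN column of object dtype
    if "metadata" not in docs.columns:
        docs["metadata"] = pd.Series(dtype="object")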


@@ -7,6 +7,7 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.utils.storage import (
@@ -65,21 +66,14 @@ def create_final_text_units(
if final_covariates is not None:
covariate_join = _covariates(final_covariates)
final_joined = _join(relationship_joined, covariate_join)
else:
final_joined["covariate_ids"] = [[] for i in range(len(final_joined))]
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
return aggregated.loc[
:,
[
"id",
"human_readable_id",
"text",
"n_tokens",
"document_ids",
"entity_ids",
"relationship_ids",
*([] if final_covariates is None else ["covariate_ids"]),
],
TEXT_UNITS_FINAL_COLUMNS,
]
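The text-units workflow gets the same treatment: when no covariates were extracted, covariate_ids defaults to one empty list per row instead of being dropped from the output, which keeps TEXT_UNITS_FINAL_COLUMNS valid in both branches. A standalone sketch of the default:

    import pandas as pd

    text_units = pd.DataFrame({"id": ["t1", "t2"]})

    # one fresh empty list per row, matching the list-valued column that the
    # covariate join produces when covariates exist
    text_units["covariate_ids"] = [[] for _ in range(len(text_units))]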


@@ -12,6 +12,7 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.extract_covariates.extract_covariates import (
extract_covariates as extractor,
@@ -83,20 +84,4 @@ async def extract_covariates(
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
covariates["human_readable_id"] = covariates.index + 1
return covariates.loc[
:,
[
"id",
"human_readable_id",
"covariate_type",
"type",
"description",
"subject_id",
"object_id",
"status",
"start_date",
"end_date",
"source_text",
"text_unit_id",
],
]
return covariates.loc[:, COVARIATES_FINAL_COLUMNS]


@@ -79,4 +79,4 @@ async def test_create_community_reports():
# assert a handful of mock data items to confirm they get put in the right spot
assert actual["rank"][:1][0] == 2
assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
assert actual["rating_explanation"][:1][0] == "<rating_explanation>"