mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Incremental model alignment (#1766)
* Used shared schema lists for all final columns * Semver
This commit is contained in:
parent
0144b3fd88
commit
61a309b182
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Use shared schema for final outputs."
|
||||
}
|
@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -96,7 +96,7 @@
|
||||
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
|
||||
")\n",
|
||||
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
|
||||
"# we're also persistint the frequency column\n",
|
||||
"# we're also persisting the frequency column\n",
|
||||
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].count()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -1,18 +1,20 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""Common field name definitions for community reports."""
|
||||
"""Common field name definitions for data frames."""
|
||||
|
||||
ID = "id"
|
||||
SHORT_ID = "human_readable_id"
|
||||
TITLE = "title"
|
||||
DESCRIPTION = "description"
|
||||
|
||||
TYPE = "type"
|
||||
|
||||
# POST-PREP NODE TABLE SCHEMA
|
||||
NODE_DEGREE = "degree"
|
||||
NODE_FREQUENCY = "frequency"
|
||||
NODE_DETAILS = "node_details"
|
||||
|
||||
NODE_PARENT_COMMUNITY = "parent_community"
|
||||
NODE_X = "x"
|
||||
NODE_Y = "y"
|
||||
|
||||
# POST-PREP EDGE TABLE SCHEMA
|
||||
EDGE_SOURCE = "source"
|
||||
@ -23,13 +25,11 @@ EDGE_WEIGHT = "weight"
|
||||
|
||||
# POST-PREP CLAIM TABLE SCHEMA
|
||||
CLAIM_SUBJECT = "subject_id"
|
||||
CLAIM_TYPE = "type"
|
||||
CLAIM_STATUS = "status"
|
||||
CLAIM_DETAILS = "claim_details"
|
||||
|
||||
# COMMUNITY HIERARCHY TABLE SCHEMA
|
||||
SUB_COMMUNITY = "sub_community"
|
||||
COMMUNITY_LEVEL = "level"
|
||||
|
||||
# COMMUNITY CONTEXT TABLE SCHEMA
|
||||
ALL_CONTEXT = "all_context"
|
||||
@ -40,6 +40,8 @@ CONTEXT_EXCEED_FLAG = "context_exceed_limit"
|
||||
# COMMUNITY REPORT TABLE SCHEMA
|
||||
COMMUNITY_ID = "community"
|
||||
COMMUNITY_LEVEL = "level"
|
||||
COMMUNITY_PARENT = "parent"
|
||||
COMMUNITY_CHILDREN = "children"
|
||||
TITLE = "title"
|
||||
SUMMARY = "summary"
|
||||
FINDINGS = "findings"
|
||||
@ -48,9 +50,114 @@ EXPLANATION = "rating_explanation"
|
||||
FULL_CONTENT = "full_content"
|
||||
FULL_CONTENT_JSON = "full_content_json"
|
||||
|
||||
ENTITY_IDS = "entity_ids"
|
||||
RELATIONSHIP_IDS = "relationship_ids"
|
||||
TEXT_UNIT_IDS = "text_unit_ids"
|
||||
COVARIATE_IDS = "covariate_ids"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
|
||||
PERIOD = "period"
|
||||
SIZE = "size"
|
||||
|
||||
# text units
|
||||
ENTITY_DEGREE = "entity_degree"
|
||||
ALL_DETAILS = "all_details"
|
||||
TEXT = "text"
|
||||
N_TOKENS = "n_tokens"
|
||||
|
||||
CREATION_DATE = "creation_date"
|
||||
METADATA = "metadata"
|
||||
|
||||
# the following lists define the final content and ordering of columns in the data model parquet outputs
|
||||
ENTITIES_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
TITLE,
|
||||
TYPE,
|
||||
DESCRIPTION,
|
||||
TEXT_UNIT_IDS,
|
||||
NODE_FREQUENCY,
|
||||
NODE_DEGREE,
|
||||
NODE_X,
|
||||
NODE_Y,
|
||||
]
|
||||
|
||||
RELATIONSHIPS_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
EDGE_SOURCE,
|
||||
EDGE_TARGET,
|
||||
DESCRIPTION,
|
||||
EDGE_WEIGHT,
|
||||
EDGE_DEGREE,
|
||||
TEXT_UNIT_IDS,
|
||||
]
|
||||
|
||||
COMMUNITIES_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
COMMUNITY_ID,
|
||||
COMMUNITY_LEVEL,
|
||||
COMMUNITY_PARENT,
|
||||
COMMUNITY_CHILDREN,
|
||||
TITLE,
|
||||
ENTITY_IDS,
|
||||
RELATIONSHIP_IDS,
|
||||
TEXT_UNIT_IDS,
|
||||
PERIOD,
|
||||
SIZE,
|
||||
]
|
||||
|
||||
COMMUNITY_REPORTS_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
COMMUNITY_ID,
|
||||
COMMUNITY_LEVEL,
|
||||
COMMUNITY_PARENT,
|
||||
COMMUNITY_CHILDREN,
|
||||
TITLE,
|
||||
SUMMARY,
|
||||
FULL_CONTENT,
|
||||
RATING,
|
||||
EXPLANATION,
|
||||
FINDINGS,
|
||||
FULL_CONTENT_JSON,
|
||||
PERIOD,
|
||||
SIZE,
|
||||
]
|
||||
|
||||
COVARIATES_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
"covariate_type",
|
||||
TYPE,
|
||||
DESCRIPTION,
|
||||
"subject_id",
|
||||
"object_id",
|
||||
"status",
|
||||
"start_date",
|
||||
"end_date",
|
||||
"source_text",
|
||||
"text_unit_id",
|
||||
]
|
||||
|
||||
TEXT_UNITS_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
TEXT,
|
||||
N_TOKENS,
|
||||
DOCUMENT_IDS,
|
||||
ENTITY_IDS,
|
||||
RELATIONSHIP_IDS,
|
||||
COVARIATE_IDS,
|
||||
]
|
||||
|
||||
DOCUMENTS_FINAL_COLUMNS = [
|
||||
ID,
|
||||
SHORT_ID,
|
||||
TITLE,
|
||||
TEXT,
|
||||
TEXT_UNIT_IDS,
|
||||
CREATION_DATE,
|
||||
METADATA,
|
||||
]
|
||||
|
@ -7,6 +7,8 @@ from uuid import uuid4
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
|
||||
|
||||
|
||||
def finalize_community_reports(
|
||||
reports: pd.DataFrame,
|
||||
@ -27,21 +29,5 @@ def finalize_community_reports(
|
||||
|
||||
return community_reports.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"level",
|
||||
"parent",
|
||||
"children",
|
||||
"title",
|
||||
"summary",
|
||||
"full_content",
|
||||
"rank",
|
||||
"rank_explanation",
|
||||
"findings",
|
||||
"full_content_json",
|
||||
"period",
|
||||
"size",
|
||||
],
|
||||
COMMUNITY_REPORTS_FINAL_COLUMNS,
|
||||
]
|
||||
|
@ -9,6 +9,7 @@ import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
|
||||
from graphrag.index.operations.compute_degree import compute_degree
|
||||
from graphrag.index.operations.create_graph import create_graph
|
||||
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
|
||||
@ -52,16 +53,5 @@ def finalize_entities(
|
||||
)
|
||||
return final_entities.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"title",
|
||||
"type",
|
||||
"description",
|
||||
"text_unit_ids",
|
||||
"frequency",
|
||||
"degree",
|
||||
"x",
|
||||
"y",
|
||||
],
|
||||
ENTITIES_FINAL_COLUMNS,
|
||||
]
|
||||
|
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
|
||||
from graphrag.index.operations.compute_degree import compute_degree
|
||||
from graphrag.index.operations.compute_edge_combined_degree import (
|
||||
compute_edge_combined_degree,
|
||||
@ -39,14 +40,5 @@ def finalize_relationships(
|
||||
|
||||
return final_relationships.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"source",
|
||||
"target",
|
||||
"description",
|
||||
"weight",
|
||||
"combined_degree",
|
||||
"text_unit_ids",
|
||||
],
|
||||
RELATIONSHIPS_FINAL_COLUMNS,
|
||||
]
|
||||
|
@ -78,7 +78,7 @@ async def _run_extractor(
|
||||
level=level,
|
||||
rank=report.rating,
|
||||
title=report.title,
|
||||
rank_explanation=report.rating_explanation,
|
||||
rating_explanation=report.rating_explanation,
|
||||
summary=report.summary,
|
||||
findings=[
|
||||
Finding(explanation=f.explanation, summary=f.summary)
|
||||
|
@ -36,7 +36,7 @@ class CommunityReport(TypedDict):
|
||||
full_content_json: str
|
||||
rank: float
|
||||
level: int
|
||||
rank_explanation: str
|
||||
rating_explanation: str
|
||||
findings: list[Finding]
|
||||
|
||||
|
||||
|
@ -5,6 +5,11 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.data_model.schemas import (
|
||||
COMMUNITIES_FINAL_COLUMNS,
|
||||
COMMUNITY_REPORTS_FINAL_COLUMNS,
|
||||
)
|
||||
|
||||
|
||||
def _update_and_merge_communities(
|
||||
old_communities: pd.DataFrame,
|
||||
@ -76,19 +81,7 @@ def _update_and_merge_communities(
|
||||
|
||||
merged_communities = merged_communities.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"parent",
|
||||
"level",
|
||||
"title",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
"text_unit_ids",
|
||||
"period",
|
||||
"size",
|
||||
],
|
||||
COMMUNITIES_FINAL_COLUMNS,
|
||||
]
|
||||
return merged_communities, community_id_mapping
|
||||
|
||||
@ -155,22 +148,4 @@ def _update_and_merge_community_reports(
|
||||
"community"
|
||||
]
|
||||
|
||||
return merged_community_reports.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"parent",
|
||||
"level",
|
||||
"title",
|
||||
"summary",
|
||||
"full_content",
|
||||
"rank",
|
||||
"rank_explanation",
|
||||
"findings",
|
||||
"full_content_json",
|
||||
"period",
|
||||
"size",
|
||||
],
|
||||
]
|
||||
return merged_community_reports.loc[:, COMMUNITY_REPORTS_FINAL_COLUMNS]
|
||||
|
@ -12,6 +12,7 @@ import pandas as pd
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
|
||||
from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
|
||||
run_graph_intelligence as run_entity_summarization,
|
||||
)
|
||||
@ -79,21 +80,7 @@ def _group_and_resolve_entities(
|
||||
resolved: pd.DataFrame = pd.DataFrame(aggregated)
|
||||
|
||||
# Modify column order to keep consistency
|
||||
resolved = resolved.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"title",
|
||||
"type",
|
||||
"description",
|
||||
"text_unit_ids",
|
||||
"frequency",
|
||||
"degree",
|
||||
"x",
|
||||
"y",
|
||||
],
|
||||
]
|
||||
resolved = resolved.loc[:, ENTITIES_FINAL_COLUMNS]
|
||||
|
||||
return resolved, id_mapping
|
||||
|
||||
|
@ -6,6 +6,8 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
|
||||
|
||||
|
||||
def _update_and_merge_relationships(
|
||||
old_relationships: pd.DataFrame, delta_relationships: pd.DataFrame
|
||||
@ -59,14 +61,5 @@ def _update_and_merge_relationships(
|
||||
|
||||
return final_relationships.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"source",
|
||||
"target",
|
||||
"description",
|
||||
"weight",
|
||||
"combined_degree",
|
||||
"text_unit_ids",
|
||||
],
|
||||
RELATIONSHIPS_FINAL_COLUMNS,
|
||||
]
|
||||
|
@ -12,6 +12,7 @@ import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.operations.cluster_graph import cluster_graph
|
||||
from graphrag.index.operations.create_graph import create_graph
|
||||
@ -146,18 +147,5 @@ def create_communities(
|
||||
|
||||
return final_communities.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"community",
|
||||
"level",
|
||||
"parent",
|
||||
"children",
|
||||
"title",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
"text_unit_ids",
|
||||
"period",
|
||||
"size",
|
||||
],
|
||||
COMMUNITIES_FINAL_COLUMNS,
|
||||
]
|
||||
|
@ -175,7 +175,7 @@ def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
|
||||
[
|
||||
schemas.SHORT_ID,
|
||||
schemas.CLAIM_SUBJECT,
|
||||
schemas.CLAIM_TYPE,
|
||||
schemas.TYPE,
|
||||
schemas.CLAIM_STATUS,
|
||||
schemas.DESCRIPTION,
|
||||
],
|
||||
|
@ -7,6 +7,7 @@ import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.typing import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
@ -66,17 +67,7 @@ def create_final_documents(
|
||||
rejoined["id"] = rejoined["id"].astype(str)
|
||||
rejoined["human_readable_id"] = rejoined.index + 1
|
||||
|
||||
# set the final column order, but adjust for metadata
|
||||
core_columns = [
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"title",
|
||||
"text",
|
||||
"text_unit_ids",
|
||||
"creation_date",
|
||||
]
|
||||
final_columns = [column for column in core_columns if column in rejoined.columns]
|
||||
if "metadata" in rejoined.columns:
|
||||
final_columns.append("metadata")
|
||||
if "metadata" not in rejoined.columns:
|
||||
rejoined["metadata"] = pd.Series(dtype="object")
|
||||
|
||||
return rejoined.loc[:, final_columns]
|
||||
return rejoined.loc[:, DOCUMENTS_FINAL_COLUMNS]
|
||||
|
@ -7,6 +7,7 @@ import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.typing import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import (
|
||||
@ -65,21 +66,14 @@ def create_final_text_units(
|
||||
if final_covariates is not None:
|
||||
covariate_join = _covariates(final_covariates)
|
||||
final_joined = _join(relationship_joined, covariate_join)
|
||||
else:
|
||||
final_joined["covariate_ids"] = [[] for i in range(len(final_joined))]
|
||||
|
||||
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
|
||||
|
||||
return aggregated.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"text",
|
||||
"n_tokens",
|
||||
"document_ids",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
*([] if final_covariates is None else ["covariate_ids"]),
|
||||
],
|
||||
TEXT_UNITS_FINAL_COLUMNS,
|
||||
]
|
||||
|
||||
|
||||
|
@ -12,6 +12,7 @@ from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.operations.extract_covariates.extract_covariates import (
|
||||
extract_covariates as extractor,
|
||||
@ -83,20 +84,4 @@ async def extract_covariates(
|
||||
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
|
||||
covariates["human_readable_id"] = covariates.index + 1
|
||||
|
||||
return covariates.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"human_readable_id",
|
||||
"covariate_type",
|
||||
"type",
|
||||
"description",
|
||||
"subject_id",
|
||||
"object_id",
|
||||
"status",
|
||||
"start_date",
|
||||
"end_date",
|
||||
"source_text",
|
||||
"text_unit_id",
|
||||
],
|
||||
]
|
||||
return covariates.loc[:, COVARIATES_FINAL_COLUMNS]
|
||||
|
@ -79,4 +79,4 @@ async def test_create_community_reports():
|
||||
|
||||
# assert a handful of mock data items to confirm they get put in the right spot
|
||||
assert actual["rank"][:1][0] == 2
|
||||
assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
|
||||
assert actual["rating_explanation"][:1][0] == "<rating_explanation>"
|
||||
|
Loading…
x
Reference in New Issue
Block a user