Various minor updates (#1932)

* Add text unit ids to Community model

* Add graph utilities

* Turn off LCC for clustering by default

* Simplify embeddings config/flow

* Semver
Nathan Evans 2025-05-16 14:48:53 -07:00 committed by GitHub
parent ee1b2db4a0
commit 36948b8d2e
16 changed files with 330 additions and 147 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "A few fixes and enhancements for better reuse and flow."
}

View File

@@ -201,8 +201,7 @@ Supported embeddings names are:
- `vector_store_id` **str** - Name of vector store definition to write to.
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all|selected|none** - Determines which set of embeddings to export.
- `names` **list[str]** - If target=selected, this should be an explicit list of the embeddings names we support.
- `names` **list[str]** - List of the embedding names to run (must be in the supported list); see the sketch below.
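For illustration, the old `target` values map onto plain lists; a minimal sketch (variable names are ours, the imports come from this change):

```python
from graphrag.config.embeddings import all_embeddings, default_embeddings

# Old `target` value -> new `names` value:
names_all = list(all_embeddings)           # target=all
names_required = list(default_embeddings)  # target=required (now the default)
names_none: list[str] = []                 # target=none
```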
### extract_graph

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
AsyncType,
AuthType,
@@ -18,7 +19,6 @@ from graphrag.config.enums import (
NounPhraseExtractorType,
OutputType,
ReportingType,
TextEmbeddingTarget,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
@@ -147,9 +147,8 @@ class EmbedTextDefaults:
model: str = "text-embedding-3-small"
batch_size: int = 16
batch_max_tokens: int = 8191
target = TextEmbeddingTarget.required
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
names: list[str] = field(default_factory=list)
names: list[str] = field(default_factory=lambda: default_embeddings)
strategy: None = None
vector_store_id: str = DEFAULT_VECTOR_STORE_ID

View File

@@ -3,9 +3,6 @@
"""A module containing embeddings values."""
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.graph_rag_config import GraphRagConfig
entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
@@ -25,60 +22,11 @@ all_embeddings: set[str] = {
community_full_content_embedding,
text_unit_text_embedding,
}
required_embeddings: set[str] = {
default_embeddings: list[str] = [
entity_description_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
"""Get the fields to embed based on the enum or specifically selected embeddings."""
match settings.embed_text.target:
case TextEmbeddingTarget.all:
return all_embeddings
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.selected:
return set(settings.embed_text.names)
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embed_text.target}"
raise ValueError(msg)
def get_embedding_settings(
settings: GraphRagConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)
vector_store_settings = settings.get_vector_store_config(
settings.embed_text.vector_store_id
).model_dump()
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.embed_text.resolved_strategy(
embeddings_llm_settings
) # get the default strategy
strategy.update({
"vector_store": {
**(vector_store_params or {}),
**(vector_store_settings),
}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
]
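Two quick checks the new module layout enables, as a sketch (the unknown name below is deliberately fake):

```python
from graphrag.config.embeddings import all_embeddings, default_embeddings

# The defaults are the old "required" trio, now an ordered list:
assert set(default_embeddings) <= all_embeddings

# Explicit selections can be screened against the supported set:
selected = ["entity.description", "entity.not_a_real_embedding"]
unknown = [name for name in selected if name not in all_embeddings]
print(unknown)  # ['entity.not_a_real_embedding']
```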
def create_collection_name(

View File

@@ -87,19 +87,6 @@ class ReportingType(str, Enum):
return f'"{self.value}"'
class TextEmbeddingTarget(str, Enum):
"""The target to use for text embeddings."""
all = "all"
required = "required"
selected = "selected"
none = "none"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class ModelType(str, Enum):
"""LLMType enum class definition."""
@@ -176,3 +163,15 @@ class NounPhraseExtractorType(str, Enum):
"""Noun phrase extractor based on dependency parsing and NER using SpaCy."""
CFG = "cfg"
"""Noun phrase extractor combining CFG-based noun-chunk extraction and NER."""
class ModularityMetric(str, Enum):
"""Enum for the modularity metric to use."""
Graph = "graph"
"""Graph modularity metric."""
LCC = "lcc"
"""Largest connected component (LCC) modularity metric."""
WeightedComponents = "weighted_components"
"""Weighted components modularity metric."""

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing get_embedding_settings."""
from graphrag.config.models.graph_rag_config import GraphRagConfig
def get_embedding_settings(
settings: GraphRagConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)
vector_store_settings = settings.get_vector_store_config(
settings.embed_text.vector_store_id
).model_dump()
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.embed_text.resolved_strategy(
embeddings_llm_settings
) # get the default strategy
strategy.update({
"vector_store": {
**(vector_store_params or {}),
**(vector_store_settings),
}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
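The call shape is unchanged from the version removed from `embeddings.py`; a usage sketch, assuming a project with a standard `settings.yaml` and the `load_config` helper (the override key is illustrative):

```python
from pathlib import Path

from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.load_config import load_config

config = load_config(Path("."))  # resolves settings.yaml in the project root
text_embed = get_embedding_settings(config)
assert "vector_store" in text_embed["strategy"]

# Per-run parameters merge into the resolved vector store block:
text_embed = get_embedding_settings(
    config, vector_store_params={"container_name": "run-42"}
)
```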

View File

@@ -6,7 +6,6 @@
from pydantic import BaseModel, Field
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.language_model_config import LanguageModelConfig
@@ -29,10 +28,6 @@ class TextEmbeddingConfig(BaseModel):
description="The batch max tokens to use.",
default=graphrag_config_defaults.embed_text.batch_max_tokens,
)
target: TextEmbeddingTarget = Field(
description="The target to use. 'all', 'required', 'selected', or 'none'.",
default=graphrag_config_defaults.embed_text.target,
)
names: list[str] = Field(
description="The specific embeddings to perform.",
default=graphrag_config_defaults.embed_text.names,
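With `target` gone, the default behavior is observable directly on the model; a sketch (the two name strings not shown in this diff, `community.full_content` and `text_unit.text`, are assumed from the package's conventions):

```python
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig

cfg = TextEmbeddingConfig()  # no `target` field anymore
print(cfg.names)
# expected: ['entity.description', 'community.full_content', 'text_unit.text']
```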

View File

@@ -28,6 +28,9 @@ class Community(Named):
relationship_ids: list[str] | None = None
"""List of relationship IDs related to the community (optional)."""
text_unit_ids: list[str] | None = None
"""List of text unit IDs related to the community (optional)."""
covariate_ids: dict[str, list[str]] | None = None
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
@@ -50,6 +53,7 @@ class Community(Named):
level_key: str = "level",
entities_key: str = "entity_ids",
relationships_key: str = "relationship_ids",
text_units_key: str = "text_unit_ids",
covariates_key: str = "covariate_ids",
parent_key: str = "parent",
children_key: str = "children",
@@ -67,6 +71,7 @@ class Community(Named):
short_id=d.get(short_id_key),
entity_ids=d.get(entities_key),
relationship_ids=d.get(relationships_key),
text_unit_ids=d.get(text_units_key),
covariate_ids=d.get(covariates_key),
attributes=d.get(attributes_key),
size=d.get(size_key),
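A round-trip sketch for the new field (key names follow the `*_key` parameters above; the import path and the `id`/`title` keys are assumptions, not shown in this hunk):

```python
from graphrag.data_model.community import Community

community = Community.from_dict({
    "id": "c0",
    "title": "Community 0",
    "level": "0",
    "parent": "",
    "children": [],
    "entity_ids": ["e1", "e2"],
    "relationship_ids": ["r1"],
    "text_unit_ids": ["t1", "t2"],  # newly carried through
})
assert community.text_unit_ids == ["t1", "t2"]
```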

View File

@@ -15,6 +15,7 @@ from graphrag.index.operations.build_noun_graph.np_extractors.base import (
BaseNounPhraseExtractor,
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.index.utils.graphs import calculate_pmi_edge_weights
from graphrag.index.utils.hashing import gen_sha512_hash
@@ -127,52 +128,6 @@ def _extract_edges(
]
if normalize_edge_weights:
# use PMI weight instead of raw weight
grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
return grouped_edge_df
def _calculate_pmi_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
node_name_col="title",
node_freq_col="frequency",
edge_weight_col="weight",
edge_source_col="source",
edge_target_col="target",
) -> pd.DataFrame:
"""
Calculate pointwise mutual information (PMI) edge weights.
pmi(x,y) = log2(p(x,y) / (p(x)p(y)))
p(x,y) = edge_weight(x,y) / total_edge_weights
p(x) = freq_occurrence(x) / total_freq_occurrences
"""
copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]
total_edge_weights = edges_df[edge_weight_col].sum()
total_freq_occurrences = nodes_df[node_freq_col].sum()
copied_nodes_df["prop_occurrence"] = (
copied_nodes_df[node_freq_col] / total_freq_occurrences
)
copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]
edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "source_prop"})
)
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "target_prop"})
)
edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
)
return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])

View File

@@ -6,6 +6,7 @@
import logging
import networkx as nx
from graspologic.partition import hierarchical_leiden
from graphrag.index.utils.stable_lcc import stable_largest_connected_component
@@ -60,9 +61,6 @@ def _compute_leiden_communities(
seed: int | None = None,
) -> tuple[dict[int, dict[str, int]], dict[int, int]]:
"""Return Leiden root communities and their hierarchy mapping."""
# NOTE: This import is done here to reduce the initial import time of the graphrag package
from graspologic.partition import hierarchical_leiden
if use_lcc:
graph = stable_largest_connected_component(graph)
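Per the commit notes, `use_lcc` now defaults off, so clustering covers every connected component; a toy sketch against the graspologic call imported above:

```python
import networkx as nx
from graspologic.partition import hierarchical_leiden

graph = nx.Graph()
graph.add_weighted_edges_from([("a", "b", 1.0), ("b", "c", 1.0), ("x", "y", 1.0)])

# Nothing is stripped before clustering, so "x" and "y" get communities too.
partitions = hierarchical_leiden(graph, max_cluster_size=10, random_seed=0xDEADBEEF)
print({p.node for p in partitions})  # all five nodes appear
```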

View File

@@ -0,0 +1,242 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Collection of graph utility functions."""
import logging
from typing import cast
import networkx as nx
import numpy as np
import pandas as pd
from graspologic.partition import hierarchical_leiden, modularity
from graspologic.utils import largest_connected_component
from graphrag.config.enums import ModularityMetric
logger = logging.getLogger(__name__)
def calculate_root_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
) -> float:
"""Calculate distance between the modularity of the graph's root clusters and the target modularity."""
hcs = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
root_clusters = hcs.first_level_hierarchical_clustering()
return modularity(graph, root_clusters)
def calculate_leaf_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
) -> float:
"""Calculate distance between the modularity of the graph's leaf clusters and the target modularity."""
hcs = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
leaf_clusters = hcs.final_level_hierarchical_clustering()
return modularity(graph, leaf_clusters)
def calculate_graph_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
use_root_modularity: bool = True,
) -> float:
"""Calculate modularity of the whole graph."""
if use_root_modularity:
return calculate_root_modularity(
graph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
return calculate_leaf_modularity(
graph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
def calculate_lcc_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
use_root_modularity: bool = True,
) -> float:
"""Calculate modularity of the largest connected component of the graph."""
lcc = cast("nx.Graph", largest_connected_component(graph))
if use_root_modularity:
return calculate_root_modularity(
lcc, max_cluster_size=max_cluster_size, random_seed=random_seed
)
return calculate_leaf_modularity(
lcc, max_cluster_size=max_cluster_size, random_seed=random_seed
)
def calculate_weighted_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
min_connected_component_size: int = 10,
use_root_modularity: bool = True,
) -> float:
"""
Calculate weighted modularity of all connected components with size greater than min_connected_component_size.
Modularity = sum(component_modularity * component_size) / total_nodes.
"""
connected_components: list[set] = list(nx.connected_components(graph))
filtered_components = [
component
for component in connected_components
if len(component) > min_connected_component_size
]
if len(filtered_components) == 0:
filtered_components = [graph]
total_nodes = sum(len(component) for component in filtered_components)
total_modularity = 0
for component in filtered_components:
if len(component) > min_connected_component_size:
subgraph = graph.subgraph(component)
if use_root_modularity:
modularity = calculate_root_modularity(
subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
else:
modularity = calculate_leaf_modularity(
subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed
)
total_modularity += modularity * len(component) / total_nodes
return total_modularity
def calculate_modularity(
graph: nx.Graph,
max_cluster_size: int = 10,
random_seed: int = 0xDEADBEEF,
use_root_modularity: bool = True,
modularity_metric: ModularityMetric = ModularityMetric.WeightedComponents,
) -> float:
"""Calculate modularity of the graph based on the modularity metric type."""
match modularity_metric:
case ModularityMetric.Graph:
logger.info("Calculating graph modularity")
return calculate_graph_modularity(
graph,
max_cluster_size=max_cluster_size,
random_seed=random_seed,
use_root_modularity=use_root_modularity,
)
case ModularityMetric.LCC:
logger.info("Calculating LCC modularity")
return calculate_lcc_modularity(
graph,
max_cluster_size=max_cluster_size,
random_seed=random_seed,
use_root_modularity=use_root_modularity,
)
case ModularityMetric.WeightedComponents:
logger.info("Calculating weighted-components modularity")
return calculate_weighted_modularity(
graph,
max_cluster_size=max_cluster_size,
random_seed=random_seed,
use_root_modularity=use_root_modularity,
)
case _:
msg = f"Unknown modularity metric type: {modularity_metric}"
raise ValueError(msg)
def calculate_pmi_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
node_name_col: str = "title",
node_freq_col: str = "frequency",
edge_weight_col: str = "weight",
edge_source_col: str = "source",
edge_target_col: str = "target",
) -> pd.DataFrame:
"""
Calculate pointwise mutual information (PMI) edge weights.
Uses a variant of PMI that accounts for bias towards low-frequency events.
pmi(x,y) = p(x,y) * log2(p(x,y) / (p(x) * p(y)))
p(x,y) = edge_weight(x,y) / total_edge_weights
p(x) = freq_occurrence(x) / total_freq_occurrences.
"""
copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]
total_edge_weights = edges_df[edge_weight_col].sum()
total_freq_occurrences = nodes_df[node_freq_col].sum()
copied_nodes_df["prop_occurrence"] = (
copied_nodes_df[node_freq_col] / total_freq_occurrences
)
copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]
edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "source_prop"})
)
edges_df = (
edges_df.merge(
copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
)
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "target_prop"})
)
edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
)
return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
def calculate_rrf_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
node_name_col="title",
node_freq_col="freq",
edge_weight_col="weight",
edge_source_col="source",
edge_target_col="target",
rrf_smoothing_factor: int = 60,
) -> pd.DataFrame:
"""Calculate reciprocal rank fusion (RRF) edge weights as a combination of PMI weight and combined freq of source and target."""
edges_df = calculate_pmi_edge_weights(
nodes_df,
edges_df,
node_name_col,
node_freq_col,
edge_weight_col,
edge_source_col,
edge_target_col,
)
edges_df["pmi_rank"] = edges_df[edge_weight_col].rank(method="min", ascending=False)
edges_df["raw_weight_rank"] = edges_df[edge_weight_col].rank(
method="min", ascending=False
)
edges_df[edge_weight_col] = edges_df.apply(
lambda x: (1 / (rrf_smoothing_factor + x["pmi_rank"]))
+ (1 / (rrf_smoothing_factor + x["raw_weight_rank"])),
axis=1,
)
return edges_df.drop(columns=["pmi_rank", "raw_weight_rank"])
def get_upper_threshold_by_std(data: list[float] | list[int], std_trim: float) -> float:
"""Get upper threshold by standard deviation."""
mean = np.mean(data)
std = np.std(data)
return cast("float", mean + std_trim * std)

View File

@@ -16,11 +16,10 @@ from graphrag.config.embeddings import (
document_text_embedding,
entity_description_embedding,
entity_title_embedding,
get_embedded_fields,
get_embedding_settings,
relationship_description_embedding,
text_unit_text_embedding,
)
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.typing.context import PipelineRunContext
@@ -57,7 +56,7 @@ async def run_workflow(
"community_reports", context.storage
)
embedded_fields = get_embedded_fields(config)
embedded_fields = config.embed_text.names
text_embed = get_embedding_settings(config)
output = await generate_text_embeddings(
@@ -92,7 +91,7 @@ async def generate_text_embeddings(
callbacks: WorkflowCallbacks,
cache: PipelineCache,
text_embed_config: dict,
embedded_fields: set[str],
embedded_fields: list[str],
) -> dict[str, pd.DataFrame]:
"""All the steps to generate all embeddings."""
embedding_param_map = {
@@ -148,20 +147,20 @@ async def generate_text_embeddings(
outputs = {}
for field in embedded_fields:
if embedding_param_map[field]["data"] is None:
msg = f"Embedding {field} is specified but data table is not in storage."
raise ValueError(msg)
outputs[field] = await _run_and_snapshot_embeddings(
name=field,
callbacks=callbacks,
cache=cache,
text_embed_config=text_embed_config,
**embedding_param_map[field],
)
msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to me here, please check for errors earlier in the logs."
log.warning(msg)
else:
outputs[field] = await _run_embeddings(
name=field,
callbacks=callbacks,
cache=cache,
text_embed_config=text_embed_config,
**embedding_param_map[field],
)
return outputs
async def _run_and_snapshot_embeddings(
async def _run_embeddings(
name: str,
data: pd.DataFrame,
embed_column: str,
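The behavioral change in a self-contained sketch: a selected embedding whose input table never landed in storage is now logged and skipped instead of failing the run (names illustrative):

```python
import logging

log = logging.getLogger(__name__)

embedded_fields = ["entity.description", "text_unit.text"]
embedding_param_map = {
    "entity.description": {"data": None},       # table missing from storage
    "text_unit.text": {"data": "a DataFrame"},  # stand-in for a loaded table
}

for field in embedded_fields:
    if embedding_param_map[field]["data"] is None:
        log.warning("Embedding %s is specified but data table is not in storage.", field)
    else:
        pass  # _run_embeddings(...) would execute here
```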

View File

@@ -5,7 +5,7 @@
import logging
from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import get_update_storages
from graphrag.index.typing.context import PipelineRunContext
@@ -34,7 +34,7 @@ async def run_workflow(
"incremental_update_merged_community_reports"
]
embedded_fields = get_embedded_fields(config)
embedded_fields = config.embed_text.names
text_embed = get_embedding_settings(config)
result = await generate_text_embeddings(
documents=final_documents_df,

View File

@@ -154,6 +154,7 @@ def read_communities(
level_col: str = "level",
entities_col: str | None = "entity_ids",
relationships_col: str | None = "relationship_ids",
text_units_col: str | None = "text_unit_ids",
covariates_col: str | None = "covariate_ids",
parent_col: str | None = "parent",
children_col: str | None = "children",
@@ -171,6 +172,7 @@ def read_communities(
level=to_str(row, level_col),
entity_ids=to_optional_list(row, entities_col, item_type=str),
relationship_ids=to_optional_list(row, relationships_col, item_type=str),
text_unit_ids=to_optional_list(row, text_units_col, item_type=str),
covariate_ids=to_optional_dict(
row, covariates_col, key_type=str, value_type=str
),

View File

@@ -194,7 +194,6 @@ def assert_text_embedding_configs(
) -> None:
assert actual.batch_size == expected.batch_size
assert actual.batch_max_tokens == expected.batch_max_tokens
assert actual.target == expected.target
assert actual.names == expected.names
assert actual.strategy == expected.strategy
assert actual.model_id == expected.model_id

View File

@@ -5,7 +5,7 @@ from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.embeddings import (
all_embeddings,
)
from graphrag.config.enums import ModelType, TextEmbeddingTarget
from graphrag.config.enums import ModelType
from graphrag.index.operations.embed_text.embed_text import TextEmbedStrategyType
from graphrag.index.workflows.generate_text_embeddings import (
run_workflow,
@@ -39,7 +39,7 @@ async def test_generate_text_embeddings():
"type": TextEmbedStrategyType.openai,
"llm": llm_settings,
}
config.embed_text.target = TextEmbeddingTarget.all
config.embed_text.names = list(all_embeddings)
config.snapshots.embeddings = True
await run_workflow(config, context)