mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-07 20:50:29 +00:00)
Various minor updates (#1932)
* Add text unit ids to Community model
* Add graph utilities
* Turn off LCC for clustering by default
* Simplify embeddings config/flow
* Semver
This commit is contained in:
parent ee1b2db4a0
commit 36948b8d2e
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "A few fixes and enhancements for better reuse and flow."
+}
@@ -201,8 +201,7 @@ Supported embeddings names are:
 - `vector_store_id` **str** - Name of vector store definition to write to.
 - `batch_size` **int** - The maximum batch size to use.
 - `batch_max_tokens` **int** - The maximum batch # of tokens.
-- `target` **required|all|selected|none** - Determines which set of embeddings to export.
-- `names` **list[str]** - If target=selected, this should be an explicit list of the embeddings names we support.
+- `names` **list[str]** - List of the embeddings names to run (must be in supported list).

 ### extract_graph

@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Literal

+from graphrag.config.embeddings import default_embeddings
 from graphrag.config.enums import (
     AsyncType,
     AuthType,
@@ -18,7 +19,6 @@ from graphrag.config.enums import (
     NounPhraseExtractorType,
     OutputType,
     ReportingType,
-    TextEmbeddingTarget,
 )
 from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
     EN_STOP_WORDS,
@@ -147,9 +147,8 @@ class EmbedTextDefaults:
     model: str = "text-embedding-3-small"
     batch_size: int = 16
     batch_max_tokens: int = 8191
-    target = TextEmbeddingTarget.required
     model_id: str = DEFAULT_EMBEDDING_MODEL_ID
-    names: list[str] = field(default_factory=list)
+    names: list[str] = field(default_factory=lambda: default_embeddings)
     strategy: None = None
     vector_store_id: str = DEFAULT_VECTOR_STORE_ID

@@ -3,9 +3,6 @@

 """A module containing embeddings values."""

-from graphrag.config.enums import TextEmbeddingTarget
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-
 entity_title_embedding = "entity.title"
 entity_description_embedding = "entity.description"
 relationship_description_embedding = "relationship.description"
@@ -25,60 +22,11 @@ all_embeddings: set[str] = {
     community_full_content_embedding,
     text_unit_text_embedding,
 }
-required_embeddings: set[str] = {
+default_embeddings: list[str] = [
     entity_description_embedding,
     community_full_content_embedding,
     text_unit_text_embedding,
-}
-
-
-def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
-    """Get the fields to embed based on the enum or specifically selected embeddings."""
-    match settings.embed_text.target:
-        case TextEmbeddingTarget.all:
-            return all_embeddings
-        case TextEmbeddingTarget.required:
-            return required_embeddings
-        case TextEmbeddingTarget.selected:
-            return set(settings.embed_text.names)
-        case TextEmbeddingTarget.none:
-            return set()
-        case _:
-            msg = f"Unknown embeddings target: {settings.embed_text.target}"
-            raise ValueError(msg)
-
-
-def get_embedding_settings(
-    settings: GraphRagConfig,
-    vector_store_params: dict | None = None,
-) -> dict:
-    """Transform GraphRAG config into settings for workflows."""
-    # TEMP
-    embeddings_llm_settings = settings.get_language_model_config(
-        settings.embed_text.model_id
-    )
-    vector_store_settings = settings.get_vector_store_config(
-        settings.embed_text.vector_store_id
-    ).model_dump()
-
-    #
-    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
-    # settings.vector_store.base contains connection information, or may be undefined
-    # settings.vector_store.<vector_name> contains the specific settings for this embedding
-    #
-    strategy = settings.embed_text.resolved_strategy(
-        embeddings_llm_settings
-    )  # get the default strategy
-    strategy.update({
-        "vector_store": {
-            **(vector_store_params or {}),
-            **(vector_store_settings),
-        }
-    })  # update the default strategy with the vector store settings
-    # This ensures the vector store config is part of the strategy and not the global config
-    return {
-        "strategy": strategy,
-    }
+]


 def create_collection_name(
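For reference, the `target` enum dispatch removed above is replaced by plain list selection: whatever appears in `embed_text.names` (defaulting to `default_embeddings`) is what runs. A minimal sketch of that selection, assuming only the names and module paths shown in this diff (this helper is illustrative, not code from the commit):

```python
# Sketch only: filter a requested names list down to the supported set,
# mirroring the new list-driven selection.
from graphrag.config.embeddings import all_embeddings, default_embeddings


def selected_embeddings(names: list[str] | None = None) -> list[str]:
    """Keep only supported embedding names, preserving the requested order."""
    requested = names if names is not None else default_embeddings
    return [name for name in requested if name in all_embeddings]


print(selected_embeddings())  # the three defaults
print(selected_embeddings(["entity.title", "not.a.real.embedding"]))  # ["entity.title"]
```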
@@ -87,19 +87,6 @@ class ReportingType(str, Enum):
         return f'"{self.value}"'


-class TextEmbeddingTarget(str, Enum):
-    """The target to use for text embeddings."""
-
-    all = "all"
-    required = "required"
-    selected = "selected"
-    none = "none"
-
-    def __repr__(self):
-        """Get a string representation."""
-        return f'"{self.value}"'
-
-
 class ModelType(str, Enum):
     """LLMType enum class definition."""

@@ -176,3 +163,15 @@ class NounPhraseExtractorType(str, Enum):
     """Noun phrase extractor based on dependency parsing and NER using SpaCy."""
     CFG = "cfg"
     """Noun phrase extractor combining CFG-based noun-chunk extraction and NER."""
+
+
+class ModularityMetric(str, Enum):
+    """Enum for the modularity metric to use."""
+
+    Graph = "graph"
+    """Graph modularity metric."""
+
+    LCC = "lcc"
+
+    WeightedComponents = "weighted_components"
+    """Weighted components modularity metric."""
graphrag/config/get_embedding_settings.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing get_embedding_settings."""
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+
+
+def get_embedding_settings(
+    settings: GraphRagConfig,
+    vector_store_params: dict | None = None,
+) -> dict:
+    """Transform GraphRAG config into settings for workflows."""
+    # TEMP
+    embeddings_llm_settings = settings.get_language_model_config(
+        settings.embed_text.model_id
+    )
+    vector_store_settings = settings.get_vector_store_config(
+        settings.embed_text.vector_store_id
+    ).model_dump()
+
+    #
+    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
+    # settings.vector_store.base contains connection information, or may be undefined
+    # settings.vector_store.<vector_name> contains the specific settings for this embedding
+    #
+    strategy = settings.embed_text.resolved_strategy(
+        embeddings_llm_settings
+    )  # get the default strategy
+    strategy.update({
+        "vector_store": {
+            **(vector_store_params or {}),
+            **(vector_store_settings),
+        }
+    })  # update the default strategy with the vector store settings
+    # This ensures the vector store config is part of the strategy and not the global config
+    return {
+        "strategy": strategy,
+    }
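A usage sketch for the relocated helper, mirroring how the workflows in this commit call it. The `config` value is assumed to be an already-loaded `GraphRagConfig`, and the `container_name` override is an illustrative value, not from the commit:

```python
# Sketch only: build the per-workflow embedding strategy from config.
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig


def build_text_embed_settings(config: GraphRagConfig) -> dict:
    # vector_store_params entries are merged under strategy["vector_store"].
    text_embed = get_embedding_settings(
        config, vector_store_params={"container_name": "my-embeddings"}
    )
    return text_embed["strategy"]
```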
@@ -6,7 +6,6 @@
 from pydantic import BaseModel, Field

 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import TextEmbeddingTarget
 from graphrag.config.models.language_model_config import LanguageModelConfig

@@ -29,10 +28,6 @@ class TextEmbeddingConfig(BaseModel):
         description="The batch max tokens to use.",
         default=graphrag_config_defaults.embed_text.batch_max_tokens,
     )
-    target: TextEmbeddingTarget = Field(
-        description="The target to use. 'all', 'required', 'selected', or 'none'.",
-        default=graphrag_config_defaults.embed_text.target,
-    )
     names: list[str] = Field(
         description="The specific embeddings to perform.",
         default=graphrag_config_defaults.embed_text.names,
@@ -28,6 +28,9 @@ class Community(Named):
     relationship_ids: list[str] | None = None
     """List of relationship IDs related to the community (optional)."""

+    text_unit_ids: list[str] | None = None
+    """List of text unit IDs related to the community (optional)."""
+
     covariate_ids: dict[str, list[str]] | None = None
     """Dictionary of different types of covariates related to the community (optional), e.g. claims"""

@@ -50,6 +53,7 @@ class Community(Named):
         level_key: str = "level",
         entities_key: str = "entity_ids",
         relationships_key: str = "relationship_ids",
+        text_units_key: str = "text_unit_ids",
         covariates_key: str = "covariate_ids",
         parent_key: str = "parent",
         children_key: str = "children",
@@ -67,6 +71,7 @@ class Community(Named):
             short_id=d.get(short_id_key),
             entity_ids=d.get(entities_key),
             relationship_ids=d.get(relationships_key),
+            text_unit_ids=d.get(text_units_key),
             covariate_ids=d.get(covariates_key),
             attributes=d.get(attributes_key),
             size=d.get(size_key),
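An illustrative round-trip of the new field. The dictionary keys follow the `*_key` defaults visible in the hunk above; the `id`, `title`, and `level` keys, the module path, and all values are assumptions for the sketch:

```python
# Sketch only: a community row now carries text unit IDs through from_dict.
from graphrag.data_model.community import Community  # assumed module path

row = {
    "id": "c-0",
    "title": "Community 0",
    "level": "0",
    "entity_ids": ["e1", "e2"],
    "relationship_ids": ["r1"],
    "text_unit_ids": ["t1", "t2"],  # newly plumbed through by this commit
}
community = Community.from_dict(row)
print(community.text_unit_ids)  # ["t1", "t2"]
```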
@@ -15,6 +15,7 @@ from graphrag.index.operations.build_noun_graph.np_extractors.base import (
     BaseNounPhraseExtractor,
 )
 from graphrag.index.utils.derive_from_rows import derive_from_rows
+from graphrag.index.utils.graphs import calculate_pmi_edge_weights
 from graphrag.index.utils.hashing import gen_sha512_hash

@@ -127,52 +128,6 @@ def _extract_edges(
     ]
     if normalize_edge_weights:
         # use PMI weight instead of raw weight
-        grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
+        grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df)

     return grouped_edge_df
-
-
-def _calculate_pmi_edge_weights(
-    nodes_df: pd.DataFrame,
-    edges_df: pd.DataFrame,
-    node_name_col="title",
-    node_freq_col="frequency",
-    edge_weight_col="weight",
-    edge_source_col="source",
-    edge_target_col="target",
-) -> pd.DataFrame:
-    """
-    Calculate pointwise mutual information (PMI) edge weights.
-
-    pmi(x,y) = log2(p(x,y) / (p(x)p(y)))
-    p(x,y) = edge_weight(x,y) / total_edge_weights
-    p(x) = freq_occurrence(x) / total_freq_occurrences
-    """
-    copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]
-
-    total_edge_weights = edges_df[edge_weight_col].sum()
-    total_freq_occurrences = nodes_df[node_freq_col].sum()
-    copied_nodes_df["prop_occurrence"] = (
-        copied_nodes_df[node_freq_col] / total_freq_occurrences
-    )
-    copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]
-
-    edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
-    edges_df = (
-        edges_df.merge(
-            copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
-        )
-        .drop(columns=[node_name_col])
-        .rename(columns={"prop_occurrence": "source_prop"})
-    )
-    edges_df = (
-        edges_df.merge(
-            copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
-        )
-        .drop(columns=[node_name_col])
-        .rename(columns={"prop_occurrence": "target_prop"})
-    )
-    edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
-        edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
-    )
-    return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
@@ -6,6 +6,7 @@
 import logging

 import networkx as nx
+from graspologic.partition import hierarchical_leiden

 from graphrag.index.utils.stable_lcc import stable_largest_connected_component

@@ -60,9 +61,6 @@ def _compute_leiden_communities(
     seed: int | None = None,
 ) -> tuple[dict[int, dict[str, int]], dict[int, int]]:
     """Return Leiden root communities and their hierarchy mapping."""
-    # NOTE: This import is done here to reduce the initial import time of the graphrag package
-    from graspologic.partition import hierarchical_leiden
-
     if use_lcc:
         graph = stable_largest_connected_component(graph)

graphrag/index/utils/graphs.py (new file, 242 lines)
@@ -0,0 +1,242 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Collection of graph utility functions."""
+
+import logging
+from typing import cast
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from graspologic.partition import hierarchical_leiden, modularity
+from graspologic.utils import largest_connected_component
+
+from graphrag.config.enums import ModularityMetric
+
+logger = logging.getLogger(__name__)
+
+
+def calculate_root_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+) -> float:
+    """Calculate modularity of the graph's root-level clusters."""
+    hcs = hierarchical_leiden(
+        graph, max_cluster_size=max_cluster_size, random_seed=random_seed
+    )
+    root_clusters = hcs.first_level_hierarchical_clustering()
+    return modularity(graph, root_clusters)
+
+
+def calculate_leaf_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+) -> float:
+    """Calculate modularity of the graph's leaf-level clusters."""
+    hcs = hierarchical_leiden(
+        graph, max_cluster_size=max_cluster_size, random_seed=random_seed
+    )
+    leaf_clusters = hcs.final_level_hierarchical_clustering()
+    return modularity(graph, leaf_clusters)
+
+
+def calculate_graph_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+    use_root_modularity: bool = True,
+) -> float:
+    """Calculate modularity of the whole graph."""
+    if use_root_modularity:
+        return calculate_root_modularity(
+            graph, max_cluster_size=max_cluster_size, random_seed=random_seed
+        )
+    return calculate_leaf_modularity(
+        graph, max_cluster_size=max_cluster_size, random_seed=random_seed
+    )
+
+
+def calculate_lcc_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+    use_root_modularity: bool = True,
+) -> float:
+    """Calculate modularity of the largest connected component of the graph."""
+    lcc = cast("nx.Graph", largest_connected_component(graph))
+    if use_root_modularity:
+        return calculate_root_modularity(
+            lcc, max_cluster_size=max_cluster_size, random_seed=random_seed
+        )
+    return calculate_leaf_modularity(
+        lcc, max_cluster_size=max_cluster_size, random_seed=random_seed
+    )
+
+
+def calculate_weighted_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+    min_connected_component_size: int = 10,
+    use_root_modularity: bool = True,
+) -> float:
+    """
+    Calculate weighted modularity of all connected components with size greater than min_connected_component_size.
+
+    Modularity = sum(component_modularity * component_size) / total_nodes.
+    """
+    connected_components: list[set] = list(nx.connected_components(graph))
+    filtered_components = [
+        component
+        for component in connected_components
+        if len(component) > min_connected_component_size
+    ]
+    if len(filtered_components) == 0:
+        filtered_components = [graph]
+
+    total_nodes = sum(len(component) for component in filtered_components)
+    total_modularity = 0
+    for component in filtered_components:
+        if len(component) > min_connected_component_size:
+            subgraph = graph.subgraph(component)
+            if use_root_modularity:
+                modularity = calculate_root_modularity(
+                    subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed
+                )
+            else:
+                modularity = calculate_leaf_modularity(
+                    subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed
+                )
+            total_modularity += modularity * len(component) / total_nodes
+    return total_modularity
+
+
+def calculate_modularity(
+    graph: nx.Graph,
+    max_cluster_size: int = 10,
+    random_seed: int = 0xDEADBEEF,
+    use_root_modularity: bool = True,
+    modularity_metric: ModularityMetric = ModularityMetric.WeightedComponents,
+) -> float:
+    """Calculate modularity of the graph based on the modularity metric type."""
+    match modularity_metric:
+        case ModularityMetric.Graph:
+            logger.info("Calculating graph modularity")
+            return calculate_graph_modularity(
+                graph,
+                max_cluster_size=max_cluster_size,
+                random_seed=random_seed,
+                use_root_modularity=use_root_modularity,
+            )
+        case ModularityMetric.LCC:
+            logger.info("Calculating LCC modularity")
+            return calculate_lcc_modularity(
+                graph,
+                max_cluster_size=max_cluster_size,
+                random_seed=random_seed,
+                use_root_modularity=use_root_modularity,
+            )
+        case ModularityMetric.WeightedComponents:
+            logger.info("Calculating weighted-components modularity")
+            return calculate_weighted_modularity(
+                graph,
+                max_cluster_size=max_cluster_size,
+                random_seed=random_seed,
+                use_root_modularity=use_root_modularity,
+            )
+        case _:
+            msg = f"Unknown modularity metric type: {modularity_metric}"
+            raise ValueError(msg)
+
+
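The dispatcher above can be exercised on a toy graph; networkx ships one, and graspologic's `hierarchical_leiden` handles small unweighted graphs. Scores vary with the seed; this is a sketch, not a test from the commit:

```python
# Sketch only: score a toy graph with the new utility.
import networkx as nx

from graphrag.config.enums import ModularityMetric
from graphrag.index.utils.graphs import calculate_modularity

graph = nx.karate_club_graph()  # 34 nodes, one connected component
score = calculate_modularity(
    graph, modularity_metric=ModularityMetric.WeightedComponents
)
print(f"weighted-components modularity: {score:.3f}")
```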
+def calculate_pmi_edge_weights(
+    nodes_df: pd.DataFrame,
+    edges_df: pd.DataFrame,
+    node_name_col: str = "title",
+    node_freq_col: str = "frequency",
+    edge_weight_col: str = "weight",
+    edge_source_col: str = "source",
+    edge_target_col: str = "target",
+) -> pd.DataFrame:
+    """
+    Calculate pointwise mutual information (PMI) edge weights.
+
+    Uses a variant of PMI that accounts for bias towards low-frequency events.
+    pmi(x,y) = p(x,y) * log2(p(x,y) / (p(x) * p(y)))
+    p(x,y) = edge_weight(x,y) / total_edge_weights
+    p(x) = freq_occurrence(x) / total_freq_occurrences.
+
+    """
+    copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]
+
+    total_edge_weights = edges_df[edge_weight_col].sum()
+    total_freq_occurrences = nodes_df[node_freq_col].sum()
+    copied_nodes_df["prop_occurrence"] = (
+        copied_nodes_df[node_freq_col] / total_freq_occurrences
+    )
+    copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]
+
+    edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
+    edges_df = (
+        edges_df.merge(
+            copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
+        )
+        .drop(columns=[node_name_col])
+        .rename(columns={"prop_occurrence": "source_prop"})
+    )
+    edges_df = (
+        edges_df.merge(
+            copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
+        )
+        .drop(columns=[node_name_col])
+        .rename(columns={"prop_occurrence": "target_prop"})
+    )
+    edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
+        edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
+    )
+
+    return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
+
+
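A small numeric check of the PMI weighting above, with made-up frequencies and a single edge: p(a,b) = 2/2 = 1.0, p(a) = 4/10, p(b) = 6/10, so the new weight is 1.0 * log2(1.0 / 0.24) ≈ 2.06.

```python
# Sketch only: two nodes, one edge, hand-checkable PMI weight.
import pandas as pd

from graphrag.index.utils.graphs import calculate_pmi_edge_weights

nodes = pd.DataFrame({"title": ["a", "b"], "frequency": [4, 6]})
edges = pd.DataFrame({"source": ["a"], "target": ["b"], "weight": [2.0]})

weighted = calculate_pmi_edge_weights(nodes, edges)
print(weighted["weight"].iloc[0])  # ~2.059
```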
+def calculate_rrf_edge_weights(
+    nodes_df: pd.DataFrame,
+    edges_df: pd.DataFrame,
+    node_name_col="title",
+    node_freq_col="freq",
+    edge_weight_col="weight",
+    edge_source_col="source",
+    edge_target_col="target",
+    rrf_smoothing_factor: int = 60,
+) -> pd.DataFrame:
+    """Calculate reciprocal rank fusion (RRF) edge weights as a combination of PMI weight and combined freq of source and target."""
+    edges_df = calculate_pmi_edge_weights(
+        nodes_df,
+        edges_df,
+        node_name_col,
+        node_freq_col,
+        edge_weight_col,
+        edge_source_col,
+        edge_target_col,
+    )
+
+    edges_df["pmi_rank"] = edges_df[edge_weight_col].rank(method="min", ascending=False)
+    edges_df["raw_weight_rank"] = edges_df[edge_weight_col].rank(
+        method="min", ascending=False
+    )
+    edges_df[edge_weight_col] = edges_df.apply(
+        lambda x: (1 / (rrf_smoothing_factor + x["pmi_rank"]))
+        + (1 / (rrf_smoothing_factor + x["raw_weight_rank"])),
+        axis=1,
+    )
+
+    return edges_df.drop(columns=["pmi_rank", "raw_weight_rank"])
+
+
+def get_upper_threshold_by_std(data: list[float] | list[int], std_trim: float) -> float:
+    """Get upper threshold by standard deviation."""
+    mean = np.mean(data)
+    std = np.std(data)
+    return cast("float", mean + std_trim * std)
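And `get_upper_threshold_by_std` is a plain mean-plus-k-sigma cutoff; with illustrative data:

```python
# Sketch only: values above mean + 2 * std would be trimmed as outliers.
from graphrag.index.utils.graphs import get_upper_threshold_by_std

data = [1.0, 2.0, 3.0, 4.0, 100.0]
print(get_upper_threshold_by_std(data, std_trim=2.0))  # ~100.03
```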
@@ -16,11 +16,10 @@ from graphrag.config.embeddings import (
     document_text_embedding,
     entity_description_embedding,
     entity_title_embedding,
-    get_embedded_fields,
-    get_embedding_settings,
     relationship_description_embedding,
     text_unit_text_embedding,
 )
+from graphrag.config.get_embedding_settings import get_embedding_settings
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.operations.embed_text import embed_text
 from graphrag.index.typing.context import PipelineRunContext
@@ -57,7 +56,7 @@ async def run_workflow(
         "community_reports", context.storage
     )

-    embedded_fields = get_embedded_fields(config)
+    embedded_fields = config.embed_text.names
     text_embed = get_embedding_settings(config)

     output = await generate_text_embeddings(
@@ -92,7 +91,7 @@ async def generate_text_embeddings(
     callbacks: WorkflowCallbacks,
     cache: PipelineCache,
     text_embed_config: dict,
-    embedded_fields: set[str],
+    embedded_fields: list[str],
 ) -> dict[str, pd.DataFrame]:
     """All the steps to generate all embeddings."""
     embedding_param_map = {
@@ -148,20 +147,20 @@ async def generate_text_embeddings(
     outputs = {}
     for field in embedded_fields:
         if embedding_param_map[field]["data"] is None:
-            msg = f"Embedding {field} is specified but data table is not in storage."
-            raise ValueError(msg)
-
-        outputs[field] = await _run_and_snapshot_embeddings(
-            name=field,
-            callbacks=callbacks,
-            cache=cache,
-            text_embed_config=text_embed_config,
-            **embedding_param_map[field],
-        )
+            msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to be here, please check for errors earlier in the logs."
+            log.warning(msg)
+        else:
+            outputs[field] = await _run_embeddings(
+                name=field,
+                callbacks=callbacks,
+                cache=cache,
+                text_embed_config=text_embed_config,
+                **embedding_param_map[field],
+            )
     return outputs


-async def _run_and_snapshot_embeddings(
+async def _run_embeddings(
     name: str,
     data: pd.DataFrame,
     embed_column: str,
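The control-flow change above, in miniature: a missing input table now logs a warning and skips that embedding instead of failing the whole workflow. A standalone sketch with local names, not the commit's code:

```python
# Sketch only: warn-and-skip over a param map, as the workflow now does.
import logging

log = logging.getLogger(__name__)


def run_all(param_map: dict[str, dict], run_one) -> dict:
    outputs = {}
    for field, params in param_map.items():
        if params.get("data") is None:
            log.warning(
                "Embedding %s is specified but its data table is not in storage.",
                field,
            )
        else:
            outputs[field] = run_one(field, **params)
    return outputs
```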
@@ -5,7 +5,7 @@

 import logging

-from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings
+from graphrag.config.get_embedding_settings import get_embedding_settings
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.run.utils import get_update_storages
 from graphrag.index.typing.context import PipelineRunContext
@@ -34,7 +34,7 @@ async def run_workflow(
         "incremental_update_merged_community_reports"
     ]

-    embedded_fields = get_embedded_fields(config)
+    embedded_fields = config.embed_text.names
     text_embed = get_embedding_settings(config)
     result = await generate_text_embeddings(
         documents=final_documents_df,
@@ -154,6 +154,7 @@ def read_communities(
     level_col: str = "level",
     entities_col: str | None = "entity_ids",
     relationships_col: str | None = "relationship_ids",
+    text_units_col: str | None = "text_unit_ids",
    covariates_col: str | None = "covariate_ids",
     parent_col: str | None = "parent",
     children_col: str | None = "children",
@@ -171,6 +172,7 @@ def read_communities(
             level=to_str(row, level_col),
             entity_ids=to_optional_list(row, entities_col, item_type=str),
             relationship_ids=to_optional_list(row, relationships_col, item_type=str),
+            text_unit_ids=to_optional_list(row, text_units_col, item_type=str),
             covariate_ids=to_optional_dict(
                 row, covariates_col, key_type=str, value_type=str
             ),
@@ -194,7 +194,6 @@ def assert_text_embedding_configs(
 ) -> None:
     assert actual.batch_size == expected.batch_size
     assert actual.batch_max_tokens == expected.batch_max_tokens
-    assert actual.target == expected.target
     assert actual.names == expected.names
     assert actual.strategy == expected.strategy
     assert actual.model_id == expected.model_id
@@ -5,7 +5,7 @@ from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.embeddings import (
     all_embeddings,
 )
-from graphrag.config.enums import ModelType, TextEmbeddingTarget
+from graphrag.config.enums import ModelType
 from graphrag.index.operations.embed_text.embed_text import TextEmbedStrategyType
 from graphrag.index.workflows.generate_text_embeddings import (
     run_workflow,
@@ -39,7 +39,7 @@ async def test_generate_text_embeddings():
         "type": TextEmbedStrategyType.openai,
         "llm": llm_settings,
     }
-    config.embed_text.target = TextEmbeddingTarget.all
+    config.embed_text.names = list(all_embeddings)
     config.snapshots.embeddings = True

     await run_workflow(config, context)