From 36948b8d2e57bf19a857aec7b02cffa4c8b7d53d Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 16 May 2025 14:48:53 -0700 Subject: [PATCH] Various minor updates (#1932) * Add text unit ids to Community model * Add graph utilities * Turn off LCC for clustering by default * Simplify embeddings config/flow * Semver --- .../patch-20250515212234042330.json | 4 + docs/config/yaml.md | 3 +- graphrag/config/defaults.py | 5 +- graphrag/config/embeddings.py | 56 +--- graphrag/config/enums.py | 25 +- graphrag/config/get_embedding_settings.py | 39 +++ .../config/models/text_embedding_config.py | 5 - graphrag/data_model/community.py | 5 + .../build_noun_graph/build_noun_graph.py | 49 +--- graphrag/index/operations/cluster_graph.py | 4 +- graphrag/index/utils/graphs.py | 242 ++++++++++++++++++ .../workflows/generate_text_embeddings.py | 29 +-- .../index/workflows/update_text_embeddings.py | 4 +- graphrag/query/input/loaders/dfs.py | 2 + tests/unit/config/utils.py | 1 - tests/verbs/test_generate_text_embeddings.py | 4 +- 16 files changed, 330 insertions(+), 147 deletions(-) create mode 100644 .semversioner/next-release/patch-20250515212234042330.json create mode 100644 graphrag/config/get_embedding_settings.py create mode 100644 graphrag/index/utils/graphs.py diff --git a/.semversioner/next-release/patch-20250515212234042330.json b/.semversioner/next-release/patch-20250515212234042330.json new file mode 100644 index 00000000..d0199eb5 --- /dev/null +++ b/.semversioner/next-release/patch-20250515212234042330.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "A few fixes and enhancements for better reuse and flow." +} diff --git a/docs/config/yaml.md b/docs/config/yaml.md index 171bc7ab..6e578b19 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -201,8 +201,7 @@ Supported embeddings names are: - `vector_store_id` **str** - Name of vector store definition to write to. - `batch_size` **int** - The maximum batch size to use. - `batch_max_tokens` **int** - The maximum batch # of tokens. -- `target` **required|all|selected|none** - Determines which set of embeddings to export. -- `names` **list[str]** - If target=selected, this should be an explicit list of the embeddings names we support. +- `names` **list[str]** - List of the embeddings names to run (must be in supported list). ### extract_graph diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index e2d9a6bc..0379f110 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Literal +from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( AsyncType, AuthType, @@ -18,7 +19,6 @@ from graphrag.config.enums import ( NounPhraseExtractorType, OutputType, ReportingType, - TextEmbeddingTarget, ) from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( EN_STOP_WORDS, @@ -147,9 +147,8 @@ class EmbedTextDefaults: model: str = "text-embedding-3-small" batch_size: int = 16 batch_max_tokens: int = 8191 - target = TextEmbeddingTarget.required model_id: str = DEFAULT_EMBEDDING_MODEL_ID - names: list[str] = field(default_factory=list) + names: list[str] = field(default_factory=lambda: default_embeddings) strategy: None = None vector_store_id: str = DEFAULT_VECTOR_STORE_ID diff --git a/graphrag/config/embeddings.py b/graphrag/config/embeddings.py index 73d85a03..4865da55 100644 --- a/graphrag/config/embeddings.py +++ b/graphrag/config/embeddings.py @@ -3,9 +3,6 @@ """A module containing embeddings values.""" -from graphrag.config.enums import TextEmbeddingTarget -from graphrag.config.models.graph_rag_config import GraphRagConfig - entity_title_embedding = "entity.title" entity_description_embedding = "entity.description" relationship_description_embedding = "relationship.description" @@ -25,60 +22,11 @@ all_embeddings: set[str] = { community_full_content_embedding, text_unit_text_embedding, } -required_embeddings: set[str] = { +default_embeddings: list[str] = [ entity_description_embedding, community_full_content_embedding, text_unit_text_embedding, -} - - -def get_embedded_fields(settings: GraphRagConfig) -> set[str]: - """Get the fields to embed based on the enum or specifically selected embeddings.""" - match settings.embed_text.target: - case TextEmbeddingTarget.all: - return all_embeddings - case TextEmbeddingTarget.required: - return required_embeddings - case TextEmbeddingTarget.selected: - return set(settings.embed_text.names) - case TextEmbeddingTarget.none: - return set() - case _: - msg = f"Unknown embeddings target: {settings.embed_text.target}" - raise ValueError(msg) - - -def get_embedding_settings( - settings: GraphRagConfig, - vector_store_params: dict | None = None, -) -> dict: - """Transform GraphRAG config into settings for workflows.""" - # TEMP - embeddings_llm_settings = settings.get_language_model_config( - settings.embed_text.model_id - ) - vector_store_settings = settings.get_vector_store_config( - settings.embed_text.vector_store_id - ).model_dump() - - # - # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. - # settings.vector_store.base contains connection information, or may be undefined - # settings.vector_store. contains the specific settings for this embedding - # - strategy = settings.embed_text.resolved_strategy( - embeddings_llm_settings - ) # get the default strategy - strategy.update({ - "vector_store": { - **(vector_store_params or {}), - **(vector_store_settings), - } - }) # update the default strategy with the vector store settings - # This ensures the vector store config is part of the strategy and not the global config - return { - "strategy": strategy, - } +] def create_collection_name( diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index e382bf1a..f3efdbd2 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -87,19 +87,6 @@ class ReportingType(str, Enum): return f'"{self.value}"' -class TextEmbeddingTarget(str, Enum): - """The target to use for text embeddings.""" - - all = "all" - required = "required" - selected = "selected" - none = "none" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - class ModelType(str, Enum): """LLMType enum class definition.""" @@ -176,3 +163,15 @@ class NounPhraseExtractorType(str, Enum): """Noun phrase extractor based on dependency parsing and NER using SpaCy.""" CFG = "cfg" """Noun phrase extractor combining CFG-based noun-chunk extraction and NER.""" + + +class ModularityMetric(str, Enum): + """Enum for the modularity metric to use.""" + + Graph = "graph" + """Graph modularity metric.""" + + LCC = "lcc" + + WeightedComponents = "weighted_components" + """Weighted components modularity metric.""" diff --git a/graphrag/config/get_embedding_settings.py b/graphrag/config/get_embedding_settings.py new file mode 100644 index 00000000..e15f1ddc --- /dev/null +++ b/graphrag/config/get_embedding_settings.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing get_embedding_settings.""" + +from graphrag.config.models.graph_rag_config import GraphRagConfig + + +def get_embedding_settings( + settings: GraphRagConfig, + vector_store_params: dict | None = None, +) -> dict: + """Transform GraphRAG config into settings for workflows.""" + # TEMP + embeddings_llm_settings = settings.get_language_model_config( + settings.embed_text.model_id + ) + vector_store_settings = settings.get_vector_store_config( + settings.embed_text.vector_store_id + ).model_dump() + + # + # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. + # settings.vector_store.base contains connection information, or may be undefined + # settings.vector_store. contains the specific settings for this embedding + # + strategy = settings.embed_text.resolved_strategy( + embeddings_llm_settings + ) # get the default strategy + strategy.update({ + "vector_store": { + **(vector_store_params or {}), + **(vector_store_settings), + } + }) # update the default strategy with the vector store settings + # This ensures the vector store config is part of the strategy and not the global config + return { + "strategy": strategy, + } diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/text_embedding_config.py index 51e655fa..be066a57 100644 --- a/graphrag/config/models/text_embedding_config.py +++ b/graphrag/config/models/text_embedding_config.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Field from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import TextEmbeddingTarget from graphrag.config.models.language_model_config import LanguageModelConfig @@ -29,10 +28,6 @@ class TextEmbeddingConfig(BaseModel): description="The batch max tokens to use.", default=graphrag_config_defaults.embed_text.batch_max_tokens, ) - target: TextEmbeddingTarget = Field( - description="The target to use. 'all', 'required', 'selected', or 'none'.", - default=graphrag_config_defaults.embed_text.target, - ) names: list[str] = Field( description="The specific embeddings to perform.", default=graphrag_config_defaults.embed_text.names, diff --git a/graphrag/data_model/community.py b/graphrag/data_model/community.py index 34173931..b7016cce 100644 --- a/graphrag/data_model/community.py +++ b/graphrag/data_model/community.py @@ -28,6 +28,9 @@ class Community(Named): relationship_ids: list[str] | None = None """List of relationship IDs related to the community (optional).""" + text_unit_ids: list[str] | None = None + """List of text unit IDs related to the community (optional).""" + covariate_ids: dict[str, list[str]] | None = None """Dictionary of different types of covariates related to the community (optional), e.g. claims""" @@ -50,6 +53,7 @@ class Community(Named): level_key: str = "level", entities_key: str = "entity_ids", relationships_key: str = "relationship_ids", + text_units_key: str = "text_unit_ids", covariates_key: str = "covariate_ids", parent_key: str = "parent", children_key: str = "children", @@ -67,6 +71,7 @@ class Community(Named): short_id=d.get(short_id_key), entity_ids=d.get(entities_key), relationship_ids=d.get(relationships_key), + text_unit_ids=d.get(text_units_key), covariate_ids=d.get(covariates_key), attributes=d.get(attributes_key), size=d.get(size_key), diff --git a/graphrag/index/operations/build_noun_graph/build_noun_graph.py b/graphrag/index/operations/build_noun_graph/build_noun_graph.py index 5036ee9d..0e45e351 100644 --- a/graphrag/index/operations/build_noun_graph/build_noun_graph.py +++ b/graphrag/index/operations/build_noun_graph/build_noun_graph.py @@ -15,6 +15,7 @@ from graphrag.index.operations.build_noun_graph.np_extractors.base import ( BaseNounPhraseExtractor, ) from graphrag.index.utils.derive_from_rows import derive_from_rows +from graphrag.index.utils.graphs import calculate_pmi_edge_weights from graphrag.index.utils.hashing import gen_sha512_hash @@ -127,52 +128,6 @@ def _extract_edges( ] if normalize_edge_weights: # use PMI weight instead of raw weight - grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df) + grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df) return grouped_edge_df - - -def _calculate_pmi_edge_weights( - nodes_df: pd.DataFrame, - edges_df: pd.DataFrame, - node_name_col="title", - node_freq_col="frequency", - edge_weight_col="weight", - edge_source_col="source", - edge_target_col="target", -) -> pd.DataFrame: - """ - Calculate pointwise mutual information (PMI) edge weights. - - pmi(x,y) = log2(p(x,y) / (p(x)p(y))) - p(x,y) = edge_weight(x,y) / total_edge_weights - p(x) = freq_occurrence(x) / total_freq_occurrences - """ - copied_nodes_df = nodes_df[[node_name_col, node_freq_col]] - - total_edge_weights = edges_df[edge_weight_col].sum() - total_freq_occurrences = nodes_df[node_freq_col].sum() - copied_nodes_df["prop_occurrence"] = ( - copied_nodes_df[node_freq_col] / total_freq_occurrences - ) - copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]] - - edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights - edges_df = ( - edges_df.merge( - copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left" - ) - .drop(columns=[node_name_col]) - .rename(columns={"prop_occurrence": "source_prop"}) - ) - edges_df = ( - edges_df.merge( - copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left" - ) - .drop(columns=[node_name_col]) - .rename(columns={"prop_occurrence": "target_prop"}) - ) - edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2( - edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"]) - ) - return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"]) diff --git a/graphrag/index/operations/cluster_graph.py b/graphrag/index/operations/cluster_graph.py index 28ee3507..b74d807a 100644 --- a/graphrag/index/operations/cluster_graph.py +++ b/graphrag/index/operations/cluster_graph.py @@ -6,6 +6,7 @@ import logging import networkx as nx +from graspologic.partition import hierarchical_leiden from graphrag.index.utils.stable_lcc import stable_largest_connected_component @@ -60,9 +61,6 @@ def _compute_leiden_communities( seed: int | None = None, ) -> tuple[dict[int, dict[str, int]], dict[int, int]]: """Return Leiden root communities and their hierarchy mapping.""" - # NOTE: This import is done here to reduce the initial import time of the graphrag package - from graspologic.partition import hierarchical_leiden - if use_lcc: graph = stable_largest_connected_component(graph) diff --git a/graphrag/index/utils/graphs.py b/graphrag/index/utils/graphs.py new file mode 100644 index 00000000..6c06e05c --- /dev/null +++ b/graphrag/index/utils/graphs.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Collection of graph utility functions.""" + +import logging +from typing import cast + +import networkx as nx +import numpy as np +import pandas as pd +from graspologic.partition import hierarchical_leiden, modularity +from graspologic.utils import largest_connected_component + +from graphrag.config.enums import ModularityMetric + +logger = logging.getLogger(__name__) + + +def calculate_root_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, +) -> float: + """Calculate distance between the modularity of the graph's root clusters and the target modularity.""" + hcs = hierarchical_leiden( + graph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + root_clusters = hcs.first_level_hierarchical_clustering() + return modularity(graph, root_clusters) + + +def calculate_leaf_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, +) -> float: + """Calculate distance between the modularity of the graph's leaf clusters and the target modularity.""" + hcs = hierarchical_leiden( + graph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + leaf_clusters = hcs.final_level_hierarchical_clustering() + return modularity(graph, leaf_clusters) + + +def calculate_graph_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, + use_root_modularity: bool = True, +) -> float: + """Calculate modularity of the whole graph.""" + if use_root_modularity: + return calculate_root_modularity( + graph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + return calculate_leaf_modularity( + graph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + + +def calculate_lcc_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, + use_root_modularity: bool = True, +) -> float: + """Calculate modularity of the largest connected component of the graph.""" + lcc = cast("nx.Graph", largest_connected_component(graph)) + if use_root_modularity: + return calculate_root_modularity( + lcc, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + return calculate_leaf_modularity( + lcc, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + + +def calculate_weighted_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, + min_connected_component_size: int = 10, + use_root_modularity: bool = True, +) -> float: + """ + Calculate weighted modularity of all connected components with size greater than min_connected_component_size. + + Modularity = sum(component_modularity * component_size) / total_nodes. + """ + connected_components: list[set] = list(nx.connected_components(graph)) + filtered_components = [ + component + for component in connected_components + if len(component) > min_connected_component_size + ] + if len(filtered_components) == 0: + filtered_components = [graph] + + total_nodes = sum(len(component) for component in filtered_components) + total_modularity = 0 + for component in filtered_components: + if len(component) > min_connected_component_size: + subgraph = graph.subgraph(component) + if use_root_modularity: + modularity = calculate_root_modularity( + subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + else: + modularity = calculate_leaf_modularity( + subgraph, max_cluster_size=max_cluster_size, random_seed=random_seed + ) + total_modularity += modularity * len(component) / total_nodes + return total_modularity + + +def calculate_modularity( + graph: nx.Graph, + max_cluster_size: int = 10, + random_seed: int = 0xDEADBEEF, + use_root_modularity: bool = True, + modularity_metric: ModularityMetric = ModularityMetric.WeightedComponents, +) -> float: + """Calculate modularity of the graph based on the modularity metric type.""" + match modularity_metric: + case ModularityMetric.Graph: + logger.info("Calculating graph modularity") + return calculate_graph_modularity( + graph, + max_cluster_size=max_cluster_size, + random_seed=random_seed, + use_root_modularity=use_root_modularity, + ) + case ModularityMetric.LCC: + logger.info("Calculating LCC modularity") + return calculate_lcc_modularity( + graph, + max_cluster_size=max_cluster_size, + random_seed=random_seed, + use_root_modularity=use_root_modularity, + ) + case ModularityMetric.WeightedComponents: + logger.info("Calculating weighted-components modularity") + return calculate_weighted_modularity( + graph, + max_cluster_size=max_cluster_size, + random_seed=random_seed, + use_root_modularity=use_root_modularity, + ) + case _: + msg = f"Unknown modularity metric type: {modularity_metric}" + raise ValueError(msg) + + +def calculate_pmi_edge_weights( + nodes_df: pd.DataFrame, + edges_df: pd.DataFrame, + node_name_col: str = "title", + node_freq_col: str = "frequency", + edge_weight_col: str = "weight", + edge_source_col: str = "source", + edge_target_col: str = "target", +) -> pd.DataFrame: + """ + Calculate pointwise mutual information (PMI) edge weights. + + Uses a variant of PMI that accounts for bias towards low-frequency events. + pmi(x,y) = p(x,y) * log2(p(x,y)/ (p(x)*p(y)) + p(x,y) = edge_weight(x,y) / total_edge_weights + p(x) = freq_occurrence(x) / total_freq_occurrences. + + """ + copied_nodes_df = nodes_df[[node_name_col, node_freq_col]] + + total_edge_weights = edges_df[edge_weight_col].sum() + total_freq_occurrences = nodes_df[node_freq_col].sum() + copied_nodes_df["prop_occurrence"] = ( + copied_nodes_df[node_freq_col] / total_freq_occurrences + ) + copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]] + + edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights + edges_df = ( + edges_df.merge( + copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left" + ) + .drop(columns=[node_name_col]) + .rename(columns={"prop_occurrence": "source_prop"}) + ) + edges_df = ( + edges_df.merge( + copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left" + ) + .drop(columns=[node_name_col]) + .rename(columns={"prop_occurrence": "target_prop"}) + ) + edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2( + edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"]) + ) + + return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"]) + + +def calculate_rrf_edge_weights( + nodes_df: pd.DataFrame, + edges_df: pd.DataFrame, + node_name_col="title", + node_freq_col="freq", + edge_weight_col="weight", + edge_source_col="source", + edge_target_col="target", + rrf_smoothing_factor: int = 60, +) -> pd.DataFrame: + """Calculate reciprocal rank fusion (RRF) edge weights as a combination of PMI weight and combined freq of source and target.""" + edges_df = calculate_pmi_edge_weights( + nodes_df, + edges_df, + node_name_col, + node_freq_col, + edge_weight_col, + edge_source_col, + edge_target_col, + ) + + edges_df["pmi_rank"] = edges_df[edge_weight_col].rank(method="min", ascending=False) + edges_df["raw_weight_rank"] = edges_df[edge_weight_col].rank( + method="min", ascending=False + ) + edges_df[edge_weight_col] = edges_df.apply( + lambda x: (1 / (rrf_smoothing_factor + x["pmi_rank"])) + + (1 / (rrf_smoothing_factor + x["raw_weight_rank"])), + axis=1, + ) + + return edges_df.drop(columns=["pmi_rank", "raw_weight_rank"]) + + +def get_upper_threshold_by_std(data: list[float] | list[int], std_trim: float) -> float: + """Get upper threshold by standard deviation.""" + mean = np.mean(data) + std = np.std(data) + return cast("float", mean + std_trim * std) diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index b4790f55..f10fcbdb 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -16,11 +16,10 @@ from graphrag.config.embeddings import ( document_text_embedding, entity_description_embedding, entity_title_embedding, - get_embedded_fields, - get_embedding_settings, relationship_description_embedding, text_unit_text_embedding, ) +from graphrag.config.get_embedding_settings import get_embedding_settings from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext @@ -57,7 +56,7 @@ async def run_workflow( "community_reports", context.storage ) - embedded_fields = get_embedded_fields(config) + embedded_fields = config.embed_text.names text_embed = get_embedding_settings(config) output = await generate_text_embeddings( @@ -92,7 +91,7 @@ async def generate_text_embeddings( callbacks: WorkflowCallbacks, cache: PipelineCache, text_embed_config: dict, - embedded_fields: set[str], + embedded_fields: list[str], ) -> dict[str, pd.DataFrame]: """All the steps to generate all embeddings.""" embedding_param_map = { @@ -148,20 +147,20 @@ async def generate_text_embeddings( outputs = {} for field in embedded_fields: if embedding_param_map[field]["data"] is None: - msg = f"Embedding {field} is specified but data table is not in storage." - raise ValueError(msg) - - outputs[field] = await _run_and_snapshot_embeddings( - name=field, - callbacks=callbacks, - cache=cache, - text_embed_config=text_embed_config, - **embedding_param_map[field], - ) + msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to me here, please check for errors earlier in the logs." + log.warning(msg) + else: + outputs[field] = await _run_embeddings( + name=field, + callbacks=callbacks, + cache=cache, + text_embed_config=text_embed_config, + **embedding_param_map[field], + ) return outputs -async def _run_and_snapshot_embeddings( +async def _run_embeddings( name: str, data: pd.DataFrame, embed_column: str, diff --git a/graphrag/index/workflows/update_text_embeddings.py b/graphrag/index/workflows/update_text_embeddings.py index c20fb1bf..ed57de86 100644 --- a/graphrag/index/workflows/update_text_embeddings.py +++ b/graphrag/index/workflows/update_text_embeddings.py @@ -5,7 +5,7 @@ import logging -from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings +from graphrag.config.get_embedding_settings import get_embedding_settings from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext @@ -34,7 +34,7 @@ async def run_workflow( "incremental_update_merged_community_reports" ] - embedded_fields = get_embedded_fields(config) + embedded_fields = config.embed_text.names text_embed = get_embedding_settings(config) result = await generate_text_embeddings( documents=final_documents_df, diff --git a/graphrag/query/input/loaders/dfs.py b/graphrag/query/input/loaders/dfs.py index 9dde2d34..7182090c 100644 --- a/graphrag/query/input/loaders/dfs.py +++ b/graphrag/query/input/loaders/dfs.py @@ -154,6 +154,7 @@ def read_communities( level_col: str = "level", entities_col: str | None = "entity_ids", relationships_col: str | None = "relationship_ids", + text_units_col: str | None = "text_unit_ids", covariates_col: str | None = "covariate_ids", parent_col: str | None = "parent", children_col: str | None = "children", @@ -171,6 +172,7 @@ def read_communities( level=to_str(row, level_col), entity_ids=to_optional_list(row, entities_col, item_type=str), relationship_ids=to_optional_list(row, relationships_col, item_type=str), + text_unit_ids=to_optional_list(row, text_units_col, item_type=str), covariate_ids=to_optional_dict( row, covariates_col, key_type=str, value_type=str ), diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index e079cf95..4aea91a0 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -194,7 +194,6 @@ def assert_text_embedding_configs( ) -> None: assert actual.batch_size == expected.batch_size assert actual.batch_max_tokens == expected.batch_max_tokens - assert actual.target == expected.target assert actual.names == expected.names assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 511dbf3b..788ddfc9 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -5,7 +5,7 @@ from graphrag.config.create_graphrag_config import create_graphrag_config from graphrag.config.embeddings import ( all_embeddings, ) -from graphrag.config.enums import ModelType, TextEmbeddingTarget +from graphrag.config.enums import ModelType from graphrag.index.operations.embed_text.embed_text import TextEmbedStrategyType from graphrag.index.workflows.generate_text_embeddings import ( run_workflow, @@ -39,7 +39,7 @@ async def test_generate_text_embeddings(): "type": TextEmbedStrategyType.openai, "llm": llm_settings, } - config.embed_text.target = TextEmbeddingTarget.all + config.embed_text.names = list(all_embeddings) config.snapshots.embeddings = True await run_workflow(config, context)