diff --git a/.semversioner/next-release/minor-20250225003631981779.json b/.semversioner/next-release/minor-20250225003631981779.json new file mode 100644 index 00000000..4a689a74 --- /dev/null +++ b/.semversioner/next-release/minor-20250225003631981779.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Move embeddings snapshots to the workflow runner." +} diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py index 3671d97e..c0fd5494 100644 --- a/graphrag/index/update/incremental_index.py +++ b/graphrag/index/update/incremental_index.py @@ -145,19 +145,24 @@ async def update_dataframe_outputs( progress_logger.info("Updating Text Embeddings") embedded_fields = get_embedded_fields(config) text_embed = get_embedding_settings(config) - await generate_text_embeddings( - final_documents=final_documents_df, - final_relationships=merged_relationships_df, - final_text_units=merged_text_units, - final_entities=merged_entities_df, - final_community_reports=merged_community_reports, + result = await generate_text_embeddings( + documents=final_documents_df, + relationships=merged_relationships_df, + text_units=merged_text_units, + entities=merged_entities_df, + community_reports=merged_community_reports, callbacks=callbacks, cache=cache, - storage=output_storage, text_embed_config=text_embed, embedded_fields=embedded_fields, - snapshot_embeddings_enabled=config.snapshots.embeddings, ) + if config.snapshots.embeddings: + for name, table in result.items(): + await write_table_to_storage( + table, + f"embeddings.{name}", + output_storage, + ) async def _update_community_reports( diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index 5c7a4a0e..0f9f2de9 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -25,7 +25,6 @@ from graphrag.config.models.graph_rag_config import 
GraphRagConfig from graphrag.index.context import PipelineRunContext from graphrag.index.operations.embed_text import embed_text from graphrag.index.typing import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage log = logging.getLogger(__name__) @@ -37,114 +36,112 @@ async def run_workflow( callbacks: WorkflowCallbacks, ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" - final_documents = await load_table_from_storage("documents", context.storage) - final_relationships = await load_table_from_storage( - "relationships", context.storage - ) - final_text_units = await load_table_from_storage("text_units", context.storage) - final_entities = await load_table_from_storage("entities", context.storage) - final_community_reports = await load_table_from_storage( + documents = await load_table_from_storage("documents", context.storage) + relationships = await load_table_from_storage("relationships", context.storage) + text_units = await load_table_from_storage("text_units", context.storage) + entities = await load_table_from_storage("entities", context.storage) + community_reports = await load_table_from_storage( "community_reports", context.storage ) embedded_fields = get_embedded_fields(config) text_embed = get_embedding_settings(config) - await generate_text_embeddings( - final_documents=final_documents, - final_relationships=final_relationships, - final_text_units=final_text_units, - final_entities=final_entities, - final_community_reports=final_community_reports, + result = await generate_text_embeddings( + documents=documents, + relationships=relationships, + text_units=text_units, + entities=entities, + community_reports=community_reports, callbacks=callbacks, cache=context.cache, - storage=context.storage, text_embed_config=text_embed, embedded_fields=embedded_fields, - 
snapshot_embeddings_enabled=config.snapshots.embeddings, ) - return WorkflowFunctionOutput(result=None, config=None) + if config.snapshots.embeddings: + for name, table in result.items(): + await write_table_to_storage( + table, + f"embeddings.{name}", + context.storage, + ) + + return WorkflowFunctionOutput(result=result, config=None) async def generate_text_embeddings( - final_documents: pd.DataFrame | None, - final_relationships: pd.DataFrame | None, - final_text_units: pd.DataFrame | None, - final_entities: pd.DataFrame | None, - final_community_reports: pd.DataFrame | None, + documents: pd.DataFrame | None, + relationships: pd.DataFrame | None, + text_units: pd.DataFrame | None, + entities: pd.DataFrame | None, + community_reports: pd.DataFrame | None, callbacks: WorkflowCallbacks, cache: PipelineCache, - storage: PipelineStorage, text_embed_config: dict, embedded_fields: set[str], - snapshot_embeddings_enabled: bool = False, -) -> None: +) -> dict[str, pd.DataFrame]: """All the steps to generate all embeddings.""" embedding_param_map = { document_text_embedding: { - "data": final_documents.loc[:, ["id", "text"]] - if final_documents is not None - else None, + "data": documents.loc[:, ["id", "text"]] if documents is not None else None, "embed_column": "text", }, relationship_description_embedding: { - "data": final_relationships.loc[:, ["id", "description"]] - if final_relationships is not None + "data": relationships.loc[:, ["id", "description"]] + if relationships is not None else None, "embed_column": "description", }, text_unit_text_embedding: { - "data": final_text_units.loc[:, ["id", "text"]] - if final_text_units is not None + "data": text_units.loc[:, ["id", "text"]] + if text_units is not None else None, "embed_column": "text", }, entity_title_embedding: { - "data": final_entities.loc[:, ["id", "title"]] - if final_entities is not None - else None, + "data": entities.loc[:, ["id", "title"]] if entities is not None else None, "embed_column": "title", 
}, entity_description_embedding: { - "data": final_entities.loc[:, ["id", "title", "description"]].assign( + "data": entities.loc[:, ["id", "title", "description"]].assign( title_description=lambda df: df["title"] + ":" + df["description"] ) - if final_entities is not None + if entities is not None else None, "embed_column": "title_description", }, community_title_embedding: { - "data": final_community_reports.loc[:, ["id", "title"]] - if final_community_reports is not None + "data": community_reports.loc[:, ["id", "title"]] + if community_reports is not None else None, "embed_column": "title", }, community_summary_embedding: { - "data": final_community_reports.loc[:, ["id", "summary"]] - if final_community_reports is not None + "data": community_reports.loc[:, ["id", "summary"]] + if community_reports is not None else None, "embed_column": "summary", }, community_full_content_embedding: { - "data": final_community_reports.loc[:, ["id", "full_content"]] - if final_community_reports is not None + "data": community_reports.loc[:, ["id", "full_content"]] + if community_reports is not None else None, "embed_column": "full_content", }, } log.info("Creating embeddings") + outputs = {} for field in embedded_fields: - await _run_and_snapshot_embeddings( + outputs[field] = await _run_and_snapshot_embeddings( name=field, callbacks=callbacks, cache=cache, - storage=storage, text_embed_config=text_embed_config, - snapshot_embeddings_enabled=snapshot_embeddings_enabled, **embedding_param_map[field], ) + return outputs async def _run_and_snapshot_embeddings( @@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings( embed_column: str, callbacks: WorkflowCallbacks, cache: PipelineCache, - storage: PipelineStorage, text_embed_config: dict, - snapshot_embeddings_enabled: bool, -) -> None: +) -> pd.DataFrame: """All the steps to generate single embedding.""" - if text_embed_config: - data["embedding"] = await embed_text( - input=data, - callbacks=callbacks, - cache=cache, - 
embed_column=embed_column, - embedding_name=name, - strategy=text_embed_config["strategy"], - ) + data["embedding"] = await embed_text( + input=data, + callbacks=callbacks, + cache=cache, + embed_column=embed_column, + embedding_name=name, + strategy=text_embed_config["strategy"], + ) - if snapshot_embeddings_enabled is True: - data = data.loc[:, ["id", "embedding"]] - await write_table_to_storage(data, f"embeddings.{name}", storage) + return data.loc[:, ["id", "embedding"]]