Move embeddings snapshots (#1737)

* Move embedding snapshots to the workflow runner

* Semver

* Rename input tables
This commit is contained in:
Nathan Evans 2025-02-24 17:38:01 -08:00 committed by GitHub
parent e0d233fe10
commit 5dd9fc53cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 69 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Move embeddings snspshots to the workflow runner."
}

View File

@ -145,19 +145,24 @@ async def update_dataframe_outputs(
progress_logger.info("Updating Text Embeddings") progress_logger.info("Updating Text Embeddings")
embedded_fields = get_embedded_fields(config) embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config) text_embed = get_embedding_settings(config)
await generate_text_embeddings( result = await generate_text_embeddings(
final_documents=final_documents_df, documents=final_documents_df,
final_relationships=merged_relationships_df, relationships=merged_relationships_df,
final_text_units=merged_text_units, text_units=merged_text_units,
final_entities=merged_entities_df, entities=merged_entities_df,
final_community_reports=merged_community_reports, community_reports=merged_community_reports,
callbacks=callbacks, callbacks=callbacks,
cache=cache, cache=cache,
storage=output_storage,
text_embed_config=text_embed, text_embed_config=text_embed,
embedded_fields=embedded_fields, embedded_fields=embedded_fields,
snapshot_embeddings_enabled=config.snapshots.embeddings,
) )
if config.snapshots.embeddings:
for name, table in result.items():
await write_table_to_storage(
table,
f"embeddings.{name}",
output_storage,
)
async def _update_community_reports( async def _update_community_reports(

View File

@ -25,7 +25,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.embed_text import embed_text from graphrag.index.operations.embed_text import embed_text
from graphrag.index.typing import WorkflowFunctionOutput from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -37,114 +36,112 @@ async def run_workflow(
callbacks: WorkflowCallbacks, callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput: ) -> WorkflowFunctionOutput:
"""All the steps to transform community reports.""" """All the steps to transform community reports."""
final_documents = await load_table_from_storage("documents", context.storage) documents = await load_table_from_storage("documents", context.storage)
final_relationships = await load_table_from_storage( relationships = await load_table_from_storage("relationships", context.storage)
"relationships", context.storage text_units = await load_table_from_storage("text_units", context.storage)
) entities = await load_table_from_storage("entities", context.storage)
final_text_units = await load_table_from_storage("text_units", context.storage) community_reports = await load_table_from_storage(
final_entities = await load_table_from_storage("entities", context.storage)
final_community_reports = await load_table_from_storage(
"community_reports", context.storage "community_reports", context.storage
) )
embedded_fields = get_embedded_fields(config) embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config) text_embed = get_embedding_settings(config)
await generate_text_embeddings( result = await generate_text_embeddings(
final_documents=final_documents, documents=documents,
final_relationships=final_relationships, relationships=relationships,
final_text_units=final_text_units, text_units=text_units,
final_entities=final_entities, entities=entities,
final_community_reports=final_community_reports, community_reports=community_reports,
callbacks=callbacks, callbacks=callbacks,
cache=context.cache, cache=context.cache,
storage=context.storage,
text_embed_config=text_embed, text_embed_config=text_embed,
embedded_fields=embedded_fields, embedded_fields=embedded_fields,
snapshot_embeddings_enabled=config.snapshots.embeddings,
) )
return WorkflowFunctionOutput(result=None, config=None) if config.snapshots.embeddings:
for name, table in result.items():
await write_table_to_storage(
table,
f"embeddings.{name}",
context.storage,
)
return WorkflowFunctionOutput(result=result, config=None)
async def generate_text_embeddings( async def generate_text_embeddings(
final_documents: pd.DataFrame | None, documents: pd.DataFrame | None,
final_relationships: pd.DataFrame | None, relationships: pd.DataFrame | None,
final_text_units: pd.DataFrame | None, text_units: pd.DataFrame | None,
final_entities: pd.DataFrame | None, entities: pd.DataFrame | None,
final_community_reports: pd.DataFrame | None, community_reports: pd.DataFrame | None,
callbacks: WorkflowCallbacks, callbacks: WorkflowCallbacks,
cache: PipelineCache, cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict, text_embed_config: dict,
embedded_fields: set[str], embedded_fields: set[str],
snapshot_embeddings_enabled: bool = False, ) -> dict[str, pd.DataFrame]:
) -> None:
"""All the steps to generate all embeddings.""" """All the steps to generate all embeddings."""
embedding_param_map = { embedding_param_map = {
document_text_embedding: { document_text_embedding: {
"data": final_documents.loc[:, ["id", "text"]] "data": documents.loc[:, ["id", "text"]] if documents is not None else None,
if final_documents is not None
else None,
"embed_column": "text", "embed_column": "text",
}, },
relationship_description_embedding: { relationship_description_embedding: {
"data": final_relationships.loc[:, ["id", "description"]] "data": relationships.loc[:, ["id", "description"]]
if final_relationships is not None if relationships is not None
else None, else None,
"embed_column": "description", "embed_column": "description",
}, },
text_unit_text_embedding: { text_unit_text_embedding: {
"data": final_text_units.loc[:, ["id", "text"]] "data": text_units.loc[:, ["id", "text"]]
if final_text_units is not None if text_units is not None
else None, else None,
"embed_column": "text", "embed_column": "text",
}, },
entity_title_embedding: { entity_title_embedding: {
"data": final_entities.loc[:, ["id", "title"]] "data": entities.loc[:, ["id", "title"]] if entities is not None else None,
if final_entities is not None
else None,
"embed_column": "title", "embed_column": "title",
}, },
entity_description_embedding: { entity_description_embedding: {
"data": final_entities.loc[:, ["id", "title", "description"]].assign( "data": entities.loc[:, ["id", "title", "description"]].assign(
title_description=lambda df: df["title"] + ":" + df["description"] title_description=lambda df: df["title"] + ":" + df["description"]
) )
if final_entities is not None if entities is not None
else None, else None,
"embed_column": "title_description", "embed_column": "title_description",
}, },
community_title_embedding: { community_title_embedding: {
"data": final_community_reports.loc[:, ["id", "title"]] "data": community_reports.loc[:, ["id", "title"]]
if final_community_reports is not None if community_reports is not None
else None, else None,
"embed_column": "title", "embed_column": "title",
}, },
community_summary_embedding: { community_summary_embedding: {
"data": final_community_reports.loc[:, ["id", "summary"]] "data": community_reports.loc[:, ["id", "summary"]]
if final_community_reports is not None if community_reports is not None
else None, else None,
"embed_column": "summary", "embed_column": "summary",
}, },
community_full_content_embedding: { community_full_content_embedding: {
"data": final_community_reports.loc[:, ["id", "full_content"]] "data": community_reports.loc[:, ["id", "full_content"]]
if final_community_reports is not None if community_reports is not None
else None, else None,
"embed_column": "full_content", "embed_column": "full_content",
}, },
} }
log.info("Creating embeddings") log.info("Creating embeddings")
outputs = {}
for field in embedded_fields: for field in embedded_fields:
await _run_and_snapshot_embeddings( outputs[field] = await _run_and_snapshot_embeddings(
name=field, name=field,
callbacks=callbacks, callbacks=callbacks,
cache=cache, cache=cache,
storage=storage,
text_embed_config=text_embed_config, text_embed_config=text_embed_config,
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
**embedding_param_map[field], **embedding_param_map[field],
) )
return outputs
async def _run_and_snapshot_embeddings( async def _run_and_snapshot_embeddings(
@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
embed_column: str, embed_column: str,
callbacks: WorkflowCallbacks, callbacks: WorkflowCallbacks,
cache: PipelineCache, cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict, text_embed_config: dict,
snapshot_embeddings_enabled: bool, ) -> pd.DataFrame:
) -> None:
"""All the steps to generate single embedding.""" """All the steps to generate single embedding."""
if text_embed_config: data["embedding"] = await embed_text(
data["embedding"] = await embed_text( input=data,
input=data, callbacks=callbacks,
callbacks=callbacks, cache=cache,
cache=cache, embed_column=embed_column,
embed_column=embed_column, embedding_name=name,
embedding_name=name, strategy=text_embed_config["strategy"],
strategy=text_embed_config["strategy"], )
)
if snapshot_embeddings_enabled is True: return data.loc[:, ["id", "embedding"]]
data = data.loc[:, ["id", "embedding"]]
await write_table_to_storage(data, f"embeddings.{name}", storage)