Move embeddings snapshots (#1737)

* Move embedding snapshots to the workflow runner

* Semver

* Rename input tables
This commit is contained in:
Nathan Evans 2025-02-24 17:38:01 -08:00 committed by GitHub
parent e0d233fe10
commit 5dd9fc53cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 69 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Move embeddings snapshots to the workflow runner."
}

View File

@ -145,19 +145,24 @@ async def update_dataframe_outputs(
progress_logger.info("Updating Text Embeddings")
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config)
await generate_text_embeddings(
final_documents=final_documents_df,
final_relationships=merged_relationships_df,
final_text_units=merged_text_units,
final_entities=merged_entities_df,
final_community_reports=merged_community_reports,
result = await generate_text_embeddings(
documents=final_documents_df,
relationships=merged_relationships_df,
text_units=merged_text_units,
entities=merged_entities_df,
community_reports=merged_community_reports,
callbacks=callbacks,
cache=cache,
storage=output_storage,
text_embed_config=text_embed,
embedded_fields=embedded_fields,
snapshot_embeddings_enabled=config.snapshots.embeddings,
)
if config.snapshots.embeddings:
for name, table in result.items():
await write_table_to_storage(
table,
f"embeddings.{name}",
output_storage,
)
async def _update_community_reports(

View File

@ -25,7 +25,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
@ -37,114 +36,112 @@ async def run_workflow(
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
final_documents = await load_table_from_storage("documents", context.storage)
final_relationships = await load_table_from_storage(
"relationships", context.storage
)
final_text_units = await load_table_from_storage("text_units", context.storage)
final_entities = await load_table_from_storage("entities", context.storage)
final_community_reports = await load_table_from_storage(
documents = await load_table_from_storage("documents", context.storage)
relationships = await load_table_from_storage("relationships", context.storage)
text_units = await load_table_from_storage("text_units", context.storage)
entities = await load_table_from_storage("entities", context.storage)
community_reports = await load_table_from_storage(
"community_reports", context.storage
)
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config)
await generate_text_embeddings(
final_documents=final_documents,
final_relationships=final_relationships,
final_text_units=final_text_units,
final_entities=final_entities,
final_community_reports=final_community_reports,
result = await generate_text_embeddings(
documents=documents,
relationships=relationships,
text_units=text_units,
entities=entities,
community_reports=community_reports,
callbacks=callbacks,
cache=context.cache,
storage=context.storage,
text_embed_config=text_embed,
embedded_fields=embedded_fields,
snapshot_embeddings_enabled=config.snapshots.embeddings,
)
return WorkflowFunctionOutput(result=None, config=None)
if config.snapshots.embeddings:
for name, table in result.items():
await write_table_to_storage(
table,
f"embeddings.{name}",
context.storage,
)
return WorkflowFunctionOutput(result=result, config=None)
async def generate_text_embeddings(
final_documents: pd.DataFrame | None,
final_relationships: pd.DataFrame | None,
final_text_units: pd.DataFrame | None,
final_entities: pd.DataFrame | None,
final_community_reports: pd.DataFrame | None,
documents: pd.DataFrame | None,
relationships: pd.DataFrame | None,
text_units: pd.DataFrame | None,
entities: pd.DataFrame | None,
community_reports: pd.DataFrame | None,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict,
embedded_fields: set[str],
snapshot_embeddings_enabled: bool = False,
) -> None:
) -> dict[str, pd.DataFrame]:
"""All the steps to generate all embeddings."""
embedding_param_map = {
document_text_embedding: {
"data": final_documents.loc[:, ["id", "text"]]
if final_documents is not None
else None,
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
"embed_column": "text",
},
relationship_description_embedding: {
"data": final_relationships.loc[:, ["id", "description"]]
if final_relationships is not None
"data": relationships.loc[:, ["id", "description"]]
if relationships is not None
else None,
"embed_column": "description",
},
text_unit_text_embedding: {
"data": final_text_units.loc[:, ["id", "text"]]
if final_text_units is not None
"data": text_units.loc[:, ["id", "text"]]
if text_units is not None
else None,
"embed_column": "text",
},
entity_title_embedding: {
"data": final_entities.loc[:, ["id", "title"]]
if final_entities is not None
else None,
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
"embed_column": "title",
},
entity_description_embedding: {
"data": final_entities.loc[:, ["id", "title", "description"]].assign(
"data": entities.loc[:, ["id", "title", "description"]].assign(
title_description=lambda df: df["title"] + ":" + df["description"]
)
if final_entities is not None
if entities is not None
else None,
"embed_column": "title_description",
},
community_title_embedding: {
"data": final_community_reports.loc[:, ["id", "title"]]
if final_community_reports is not None
"data": community_reports.loc[:, ["id", "title"]]
if community_reports is not None
else None,
"embed_column": "title",
},
community_summary_embedding: {
"data": final_community_reports.loc[:, ["id", "summary"]]
if final_community_reports is not None
"data": community_reports.loc[:, ["id", "summary"]]
if community_reports is not None
else None,
"embed_column": "summary",
},
community_full_content_embedding: {
"data": final_community_reports.loc[:, ["id", "full_content"]]
if final_community_reports is not None
"data": community_reports.loc[:, ["id", "full_content"]]
if community_reports is not None
else None,
"embed_column": "full_content",
},
}
log.info("Creating embeddings")
outputs = {}
for field in embedded_fields:
await _run_and_snapshot_embeddings(
outputs[field] = await _run_and_snapshot_embeddings(
name=field,
callbacks=callbacks,
cache=cache,
storage=storage,
text_embed_config=text_embed_config,
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
**embedding_param_map[field],
)
return outputs
async def _run_and_snapshot_embeddings(
@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
embed_column: str,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict,
snapshot_embeddings_enabled: bool,
) -> None:
) -> pd.DataFrame:
"""All the steps to generate single embedding."""
if text_embed_config:
data["embedding"] = await embed_text(
input=data,
callbacks=callbacks,
cache=cache,
embed_column=embed_column,
embedding_name=name,
strategy=text_embed_config["strategy"],
)
data["embedding"] = await embed_text(
input=data,
callbacks=callbacks,
cache=cache,
embed_column=embed_column,
embedding_name=name,
strategy=text_embed_config["strategy"],
)
if snapshot_embeddings_enabled is True:
data = data.loc[:, ["id", "embedding"]]
await write_table_to_storage(data, f"embeddings.{name}", storage)
return data.loc[:, ["id", "embedding"]]