mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Move embeddings snapshots (#1737)
* Move embedding snapshots to the workflow runner * Semver * Rename input tables
This commit is contained in:
parent
e0d233fe10
commit
5dd9fc53cd
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "minor",
|
||||
"description": "Move embeddings snspshots to the workflow runner."
|
||||
}
|
@ -145,19 +145,24 @@ async def update_dataframe_outputs(
|
||||
progress_logger.info("Updating Text Embeddings")
|
||||
embedded_fields = get_embedded_fields(config)
|
||||
text_embed = get_embedding_settings(config)
|
||||
await generate_text_embeddings(
|
||||
final_documents=final_documents_df,
|
||||
final_relationships=merged_relationships_df,
|
||||
final_text_units=merged_text_units,
|
||||
final_entities=merged_entities_df,
|
||||
final_community_reports=merged_community_reports,
|
||||
result = await generate_text_embeddings(
|
||||
documents=final_documents_df,
|
||||
relationships=merged_relationships_df,
|
||||
text_units=merged_text_units,
|
||||
entities=merged_entities_df,
|
||||
community_reports=merged_community_reports,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
storage=output_storage,
|
||||
text_embed_config=text_embed,
|
||||
embedded_fields=embedded_fields,
|
||||
snapshot_embeddings_enabled=config.snapshots.embeddings,
|
||||
)
|
||||
if config.snapshots.embeddings:
|
||||
for name, table in result.items():
|
||||
await write_table_to_storage(
|
||||
table,
|
||||
f"embeddings.{name}",
|
||||
output_storage,
|
||||
)
|
||||
|
||||
|
||||
async def _update_community_reports(
|
||||
|
@ -25,7 +25,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.operations.embed_text import embed_text
|
||||
from graphrag.index.typing import WorkflowFunctionOutput
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -37,114 +36,112 @@ async def run_workflow(
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> WorkflowFunctionOutput:
|
||||
"""All the steps to transform community reports."""
|
||||
final_documents = await load_table_from_storage("documents", context.storage)
|
||||
final_relationships = await load_table_from_storage(
|
||||
"relationships", context.storage
|
||||
)
|
||||
final_text_units = await load_table_from_storage("text_units", context.storage)
|
||||
final_entities = await load_table_from_storage("entities", context.storage)
|
||||
final_community_reports = await load_table_from_storage(
|
||||
documents = await load_table_from_storage("documents", context.storage)
|
||||
relationships = await load_table_from_storage("relationships", context.storage)
|
||||
text_units = await load_table_from_storage("text_units", context.storage)
|
||||
entities = await load_table_from_storage("entities", context.storage)
|
||||
community_reports = await load_table_from_storage(
|
||||
"community_reports", context.storage
|
||||
)
|
||||
|
||||
embedded_fields = get_embedded_fields(config)
|
||||
text_embed = get_embedding_settings(config)
|
||||
|
||||
await generate_text_embeddings(
|
||||
final_documents=final_documents,
|
||||
final_relationships=final_relationships,
|
||||
final_text_units=final_text_units,
|
||||
final_entities=final_entities,
|
||||
final_community_reports=final_community_reports,
|
||||
result = await generate_text_embeddings(
|
||||
documents=documents,
|
||||
relationships=relationships,
|
||||
text_units=text_units,
|
||||
entities=entities,
|
||||
community_reports=community_reports,
|
||||
callbacks=callbacks,
|
||||
cache=context.cache,
|
||||
storage=context.storage,
|
||||
text_embed_config=text_embed,
|
||||
embedded_fields=embedded_fields,
|
||||
snapshot_embeddings_enabled=config.snapshots.embeddings,
|
||||
)
|
||||
|
||||
return WorkflowFunctionOutput(result=None, config=None)
|
||||
if config.snapshots.embeddings:
|
||||
for name, table in result.items():
|
||||
await write_table_to_storage(
|
||||
table,
|
||||
f"embeddings.{name}",
|
||||
context.storage,
|
||||
)
|
||||
|
||||
return WorkflowFunctionOutput(result=result, config=None)
|
||||
|
||||
|
||||
async def generate_text_embeddings(
|
||||
final_documents: pd.DataFrame | None,
|
||||
final_relationships: pd.DataFrame | None,
|
||||
final_text_units: pd.DataFrame | None,
|
||||
final_entities: pd.DataFrame | None,
|
||||
final_community_reports: pd.DataFrame | None,
|
||||
documents: pd.DataFrame | None,
|
||||
relationships: pd.DataFrame | None,
|
||||
text_units: pd.DataFrame | None,
|
||||
entities: pd.DataFrame | None,
|
||||
community_reports: pd.DataFrame | None,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
storage: PipelineStorage,
|
||||
text_embed_config: dict,
|
||||
embedded_fields: set[str],
|
||||
snapshot_embeddings_enabled: bool = False,
|
||||
) -> None:
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""All the steps to generate all embeddings."""
|
||||
embedding_param_map = {
|
||||
document_text_embedding: {
|
||||
"data": final_documents.loc[:, ["id", "text"]]
|
||||
if final_documents is not None
|
||||
else None,
|
||||
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
|
||||
"embed_column": "text",
|
||||
},
|
||||
relationship_description_embedding: {
|
||||
"data": final_relationships.loc[:, ["id", "description"]]
|
||||
if final_relationships is not None
|
||||
"data": relationships.loc[:, ["id", "description"]]
|
||||
if relationships is not None
|
||||
else None,
|
||||
"embed_column": "description",
|
||||
},
|
||||
text_unit_text_embedding: {
|
||||
"data": final_text_units.loc[:, ["id", "text"]]
|
||||
if final_text_units is not None
|
||||
"data": text_units.loc[:, ["id", "text"]]
|
||||
if text_units is not None
|
||||
else None,
|
||||
"embed_column": "text",
|
||||
},
|
||||
entity_title_embedding: {
|
||||
"data": final_entities.loc[:, ["id", "title"]]
|
||||
if final_entities is not None
|
||||
else None,
|
||||
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
|
||||
"embed_column": "title",
|
||||
},
|
||||
entity_description_embedding: {
|
||||
"data": final_entities.loc[:, ["id", "title", "description"]].assign(
|
||||
"data": entities.loc[:, ["id", "title", "description"]].assign(
|
||||
title_description=lambda df: df["title"] + ":" + df["description"]
|
||||
)
|
||||
if final_entities is not None
|
||||
if entities is not None
|
||||
else None,
|
||||
"embed_column": "title_description",
|
||||
},
|
||||
community_title_embedding: {
|
||||
"data": final_community_reports.loc[:, ["id", "title"]]
|
||||
if final_community_reports is not None
|
||||
"data": community_reports.loc[:, ["id", "title"]]
|
||||
if community_reports is not None
|
||||
else None,
|
||||
"embed_column": "title",
|
||||
},
|
||||
community_summary_embedding: {
|
||||
"data": final_community_reports.loc[:, ["id", "summary"]]
|
||||
if final_community_reports is not None
|
||||
"data": community_reports.loc[:, ["id", "summary"]]
|
||||
if community_reports is not None
|
||||
else None,
|
||||
"embed_column": "summary",
|
||||
},
|
||||
community_full_content_embedding: {
|
||||
"data": final_community_reports.loc[:, ["id", "full_content"]]
|
||||
if final_community_reports is not None
|
||||
"data": community_reports.loc[:, ["id", "full_content"]]
|
||||
if community_reports is not None
|
||||
else None,
|
||||
"embed_column": "full_content",
|
||||
},
|
||||
}
|
||||
|
||||
log.info("Creating embeddings")
|
||||
outputs = {}
|
||||
for field in embedded_fields:
|
||||
await _run_and_snapshot_embeddings(
|
||||
outputs[field] = await _run_and_snapshot_embeddings(
|
||||
name=field,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
storage=storage,
|
||||
text_embed_config=text_embed_config,
|
||||
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
|
||||
**embedding_param_map[field],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
async def _run_and_snapshot_embeddings(
|
||||
@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
|
||||
embed_column: str,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
storage: PipelineStorage,
|
||||
text_embed_config: dict,
|
||||
snapshot_embeddings_enabled: bool,
|
||||
) -> None:
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to generate single embedding."""
|
||||
if text_embed_config:
|
||||
data["embedding"] = await embed_text(
|
||||
input=data,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
embed_column=embed_column,
|
||||
embedding_name=name,
|
||||
strategy=text_embed_config["strategy"],
|
||||
)
|
||||
data["embedding"] = await embed_text(
|
||||
input=data,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
embed_column=embed_column,
|
||||
embedding_name=name,
|
||||
strategy=text_embed_config["strategy"],
|
||||
)
|
||||
|
||||
if snapshot_embeddings_enabled is True:
|
||||
data = data.loc[:, ["id", "embedding"]]
|
||||
await write_table_to_storage(data, f"embeddings.{name}", storage)
|
||||
return data.loc[:, ["id", "embedding"]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user