mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Move embeddings snapshots (#1737)
* Move embedding snapshots to the workflow runner * Semver * Rename input tables
This commit is contained in:
parent
e0d233fe10
commit
5dd9fc53cd
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "minor",
|
||||||
|
"description": "Move embeddings snapshots to the workflow runner."
|
||||||
|
}
|
@ -145,19 +145,24 @@ async def update_dataframe_outputs(
|
|||||||
progress_logger.info("Updating Text Embeddings")
|
progress_logger.info("Updating Text Embeddings")
|
||||||
embedded_fields = get_embedded_fields(config)
|
embedded_fields = get_embedded_fields(config)
|
||||||
text_embed = get_embedding_settings(config)
|
text_embed = get_embedding_settings(config)
|
||||||
await generate_text_embeddings(
|
result = await generate_text_embeddings(
|
||||||
final_documents=final_documents_df,
|
documents=final_documents_df,
|
||||||
final_relationships=merged_relationships_df,
|
relationships=merged_relationships_df,
|
||||||
final_text_units=merged_text_units,
|
text_units=merged_text_units,
|
||||||
final_entities=merged_entities_df,
|
entities=merged_entities_df,
|
||||||
final_community_reports=merged_community_reports,
|
community_reports=merged_community_reports,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
storage=output_storage,
|
|
||||||
text_embed_config=text_embed,
|
text_embed_config=text_embed,
|
||||||
embedded_fields=embedded_fields,
|
embedded_fields=embedded_fields,
|
||||||
snapshot_embeddings_enabled=config.snapshots.embeddings,
|
|
||||||
)
|
)
|
||||||
|
if config.snapshots.embeddings:
|
||||||
|
for name, table in result.items():
|
||||||
|
await write_table_to_storage(
|
||||||
|
table,
|
||||||
|
f"embeddings.{name}",
|
||||||
|
output_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _update_community_reports(
|
async def _update_community_reports(
|
||||||
|
@ -25,7 +25,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
|
|||||||
from graphrag.index.context import PipelineRunContext
|
from graphrag.index.context import PipelineRunContext
|
||||||
from graphrag.index.operations.embed_text import embed_text
|
from graphrag.index.operations.embed_text import embed_text
|
||||||
from graphrag.index.typing import WorkflowFunctionOutput
|
from graphrag.index.typing import WorkflowFunctionOutput
|
||||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
|
||||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -37,114 +36,112 @@ async def run_workflow(
|
|||||||
callbacks: WorkflowCallbacks,
|
callbacks: WorkflowCallbacks,
|
||||||
) -> WorkflowFunctionOutput:
|
) -> WorkflowFunctionOutput:
|
||||||
"""All the steps to transform community reports."""
|
"""All the steps to transform community reports."""
|
||||||
final_documents = await load_table_from_storage("documents", context.storage)
|
documents = await load_table_from_storage("documents", context.storage)
|
||||||
final_relationships = await load_table_from_storage(
|
relationships = await load_table_from_storage("relationships", context.storage)
|
||||||
"relationships", context.storage
|
text_units = await load_table_from_storage("text_units", context.storage)
|
||||||
)
|
entities = await load_table_from_storage("entities", context.storage)
|
||||||
final_text_units = await load_table_from_storage("text_units", context.storage)
|
community_reports = await load_table_from_storage(
|
||||||
final_entities = await load_table_from_storage("entities", context.storage)
|
|
||||||
final_community_reports = await load_table_from_storage(
|
|
||||||
"community_reports", context.storage
|
"community_reports", context.storage
|
||||||
)
|
)
|
||||||
|
|
||||||
embedded_fields = get_embedded_fields(config)
|
embedded_fields = get_embedded_fields(config)
|
||||||
text_embed = get_embedding_settings(config)
|
text_embed = get_embedding_settings(config)
|
||||||
|
|
||||||
await generate_text_embeddings(
|
result = await generate_text_embeddings(
|
||||||
final_documents=final_documents,
|
documents=documents,
|
||||||
final_relationships=final_relationships,
|
relationships=relationships,
|
||||||
final_text_units=final_text_units,
|
text_units=text_units,
|
||||||
final_entities=final_entities,
|
entities=entities,
|
||||||
final_community_reports=final_community_reports,
|
community_reports=community_reports,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
cache=context.cache,
|
cache=context.cache,
|
||||||
storage=context.storage,
|
|
||||||
text_embed_config=text_embed,
|
text_embed_config=text_embed,
|
||||||
embedded_fields=embedded_fields,
|
embedded_fields=embedded_fields,
|
||||||
snapshot_embeddings_enabled=config.snapshots.embeddings,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return WorkflowFunctionOutput(result=None, config=None)
|
if config.snapshots.embeddings:
|
||||||
|
for name, table in result.items():
|
||||||
|
await write_table_to_storage(
|
||||||
|
table,
|
||||||
|
f"embeddings.{name}",
|
||||||
|
context.storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkflowFunctionOutput(result=result, config=None)
|
||||||
|
|
||||||
|
|
||||||
async def generate_text_embeddings(
|
async def generate_text_embeddings(
|
||||||
final_documents: pd.DataFrame | None,
|
documents: pd.DataFrame | None,
|
||||||
final_relationships: pd.DataFrame | None,
|
relationships: pd.DataFrame | None,
|
||||||
final_text_units: pd.DataFrame | None,
|
text_units: pd.DataFrame | None,
|
||||||
final_entities: pd.DataFrame | None,
|
entities: pd.DataFrame | None,
|
||||||
final_community_reports: pd.DataFrame | None,
|
community_reports: pd.DataFrame | None,
|
||||||
callbacks: WorkflowCallbacks,
|
callbacks: WorkflowCallbacks,
|
||||||
cache: PipelineCache,
|
cache: PipelineCache,
|
||||||
storage: PipelineStorage,
|
|
||||||
text_embed_config: dict,
|
text_embed_config: dict,
|
||||||
embedded_fields: set[str],
|
embedded_fields: set[str],
|
||||||
snapshot_embeddings_enabled: bool = False,
|
) -> dict[str, pd.DataFrame]:
|
||||||
) -> None:
|
|
||||||
"""All the steps to generate all embeddings."""
|
"""All the steps to generate all embeddings."""
|
||||||
embedding_param_map = {
|
embedding_param_map = {
|
||||||
document_text_embedding: {
|
document_text_embedding: {
|
||||||
"data": final_documents.loc[:, ["id", "text"]]
|
"data": documents.loc[:, ["id", "text"]] if documents is not None else None,
|
||||||
if final_documents is not None
|
|
||||||
else None,
|
|
||||||
"embed_column": "text",
|
"embed_column": "text",
|
||||||
},
|
},
|
||||||
relationship_description_embedding: {
|
relationship_description_embedding: {
|
||||||
"data": final_relationships.loc[:, ["id", "description"]]
|
"data": relationships.loc[:, ["id", "description"]]
|
||||||
if final_relationships is not None
|
if relationships is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "description",
|
"embed_column": "description",
|
||||||
},
|
},
|
||||||
text_unit_text_embedding: {
|
text_unit_text_embedding: {
|
||||||
"data": final_text_units.loc[:, ["id", "text"]]
|
"data": text_units.loc[:, ["id", "text"]]
|
||||||
if final_text_units is not None
|
if text_units is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "text",
|
"embed_column": "text",
|
||||||
},
|
},
|
||||||
entity_title_embedding: {
|
entity_title_embedding: {
|
||||||
"data": final_entities.loc[:, ["id", "title"]]
|
"data": entities.loc[:, ["id", "title"]] if entities is not None else None,
|
||||||
if final_entities is not None
|
|
||||||
else None,
|
|
||||||
"embed_column": "title",
|
"embed_column": "title",
|
||||||
},
|
},
|
||||||
entity_description_embedding: {
|
entity_description_embedding: {
|
||||||
"data": final_entities.loc[:, ["id", "title", "description"]].assign(
|
"data": entities.loc[:, ["id", "title", "description"]].assign(
|
||||||
title_description=lambda df: df["title"] + ":" + df["description"]
|
title_description=lambda df: df["title"] + ":" + df["description"]
|
||||||
)
|
)
|
||||||
if final_entities is not None
|
if entities is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "title_description",
|
"embed_column": "title_description",
|
||||||
},
|
},
|
||||||
community_title_embedding: {
|
community_title_embedding: {
|
||||||
"data": final_community_reports.loc[:, ["id", "title"]]
|
"data": community_reports.loc[:, ["id", "title"]]
|
||||||
if final_community_reports is not None
|
if community_reports is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "title",
|
"embed_column": "title",
|
||||||
},
|
},
|
||||||
community_summary_embedding: {
|
community_summary_embedding: {
|
||||||
"data": final_community_reports.loc[:, ["id", "summary"]]
|
"data": community_reports.loc[:, ["id", "summary"]]
|
||||||
if final_community_reports is not None
|
if community_reports is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "summary",
|
"embed_column": "summary",
|
||||||
},
|
},
|
||||||
community_full_content_embedding: {
|
community_full_content_embedding: {
|
||||||
"data": final_community_reports.loc[:, ["id", "full_content"]]
|
"data": community_reports.loc[:, ["id", "full_content"]]
|
||||||
if final_community_reports is not None
|
if community_reports is not None
|
||||||
else None,
|
else None,
|
||||||
"embed_column": "full_content",
|
"embed_column": "full_content",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.info("Creating embeddings")
|
log.info("Creating embeddings")
|
||||||
|
outputs = {}
|
||||||
for field in embedded_fields:
|
for field in embedded_fields:
|
||||||
await _run_and_snapshot_embeddings(
|
outputs[field] = await _run_and_snapshot_embeddings(
|
||||||
name=field,
|
name=field,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
storage=storage,
|
|
||||||
text_embed_config=text_embed_config,
|
text_embed_config=text_embed_config,
|
||||||
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
|
|
||||||
**embedding_param_map[field],
|
**embedding_param_map[field],
|
||||||
)
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
async def _run_and_snapshot_embeddings(
|
async def _run_and_snapshot_embeddings(
|
||||||
@ -153,21 +150,16 @@ async def _run_and_snapshot_embeddings(
|
|||||||
embed_column: str,
|
embed_column: str,
|
||||||
callbacks: WorkflowCallbacks,
|
callbacks: WorkflowCallbacks,
|
||||||
cache: PipelineCache,
|
cache: PipelineCache,
|
||||||
storage: PipelineStorage,
|
|
||||||
text_embed_config: dict,
|
text_embed_config: dict,
|
||||||
snapshot_embeddings_enabled: bool,
|
) -> pd.DataFrame:
|
||||||
) -> None:
|
|
||||||
"""All the steps to generate single embedding."""
|
"""All the steps to generate single embedding."""
|
||||||
if text_embed_config:
|
data["embedding"] = await embed_text(
|
||||||
data["embedding"] = await embed_text(
|
input=data,
|
||||||
input=data,
|
callbacks=callbacks,
|
||||||
callbacks=callbacks,
|
cache=cache,
|
||||||
cache=cache,
|
embed_column=embed_column,
|
||||||
embed_column=embed_column,
|
embedding_name=name,
|
||||||
embedding_name=name,
|
strategy=text_embed_config["strategy"],
|
||||||
strategy=text_embed_config["strategy"],
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if snapshot_embeddings_enabled is True:
|
return data.loc[:, ["id", "embedding"]]
|
||||||
data = data.loc[:, ["id", "embedding"]]
|
|
||||||
await write_table_to_storage(data, f"embeddings.{name}", storage)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user