diff --git a/.semversioner/next-release/patch-20240925221303861991.json b/.semversioner/next-release/patch-20240925221303861991.json new file mode 100644 index 00000000..2b34b9ec --- /dev/null +++ b/.semversioner/next-release/patch-20240925221303861991.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Collapse create-final-documents." +} diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index d09ce001..bf889b2c 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -16,7 +16,6 @@ def build_steps( ## Dependencies * `workflow:create_base_documents` - * `workflow:create_base_document_nodes` """ base_text_embed = config.get("text_embed", {}) document_raw_content_embed_config = config.get( @@ -25,17 +24,12 @@ def build_steps( skip_raw_content_embedding = config.get("skip_raw_content_embedding", False) return [ { - "verb": "rename", - "args": {"columns": {"text_units": "text_unit_ids"}}, + "verb": "create_final_documents", + "args": { + "columns": {"text_units": "text_unit_ids"}, + "skip_embedding": skip_raw_content_embedding, + "text_embed": document_raw_content_embed_config, + }, "input": {"source": "workflow:create_base_documents"}, }, - { - "verb": "text_embed", - "enabled": not skip_raw_content_embedding, - "args": { - "column": "raw_content", - "to": "raw_content_embedding", - **document_raw_content_embed_config, - }, - }, ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index ead84cfd..57e9e215 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -6,6 +6,7 @@ from .create_base_documents import create_base_documents from .create_base_text_units import create_base_text_units from .create_final_communities import create_final_communities +from .create_final_documents import create_final_documents from .create_final_nodes import create_final_nodes from .create_final_relationships import ( create_final_relationships, @@ -16,6 +17,7 @@ __all__ = [ "create_base_documents", "create_base_text_units", "create_final_communities", + "create_final_documents", "create_final_nodes", "create_final_relationships", "create_final_text_units_pre_embedding", diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py new file mode 100644 index 00000000..cb8accaa --- /dev/null +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to transform final documents.""" + +from typing import cast + +import pandas as pd +from datashaper import ( + Table, + VerbCallbacks, + VerbInput, + verb, +) +from datashaper.table_store.types import VerbResult, create_verb_result + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.text.embed.text_embed import text_embed_df + + +@verb( + name="create_final_documents", + treats_input_tables_as_immutable=True, +) +async def create_final_documents( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict, + skip_embedding: bool = False, + **_kwargs: dict, +) -> VerbResult: + """All the steps to transform final documents.""" + source = cast(pd.DataFrame, input.get_input()) + + source.rename(columns={"text_units": "text_unit_ids"}, inplace=True) + + if not skip_embedding: + source = await text_embed_df( + source, + callbacks, + cache, + column="raw_content", + strategy=text_embed["strategy"], + to="raw_content_embedding", + ) + + return create_verb_result(cast(Table, source)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py index 02fc7948..d7d7fa95 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_relationships.py +++ b/graphrag/index/workflows/v1/subflows/create_final_relationships.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""All the steps to transform final relationships before they are embedded.""" +"""All the steps to transform final relationships.""" from typing import Any, cast @@ -35,7 +35,7 @@ async def create_final_relationships( skip_embedding: bool = False, **_kwargs: dict, ) -> VerbResult: - """All the steps to transform final relationships before they are embedded.""" + """All the steps to transform final relationships.""" table = cast(pd.DataFrame, input.get_input()) nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py new file mode 100644 index 00000000..7092b963 --- /dev/null +++ b/tests/verbs/test_create_final_documents.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from graphrag.index.workflows.v1.create_final_documents import ( + build_steps, + workflow_name, +) + +from .util import ( + compare_outputs, + get_config_for_workflow, + get_workflow_output, + load_expected, + load_input_tables, + remove_disabled_steps, +) + + +async def test_create_final_documents(): + input_tables = load_input_tables([ + "workflow:create_base_documents", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_raw_content_embedding"] = True + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + compare_outputs(actual, expected) + + +async def test_create_final_documents_with_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_documents", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_raw_content_embedding"] = False + # default config has a detailed standard embed config + # just override the strategy to mock so the rest of the required parameters are in place + config["document_raw_content_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "raw_content_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + 1 + # the mock impl returns an array of 3 floats for each embedding + assert len(actual["raw_content_embedding"][0]) == 3