mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-13 07:51:34 +00:00
Collapse create final documents (#1217)
* Collapse create_final_documents * Semver
This commit is contained in:
parent
dda4edd0fd
commit
14750f4d37
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "Collapse create-final-documents."
|
||||||
|
}
|
||||||
@ -16,7 +16,6 @@ def build_steps(
|
|||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
* `workflow:create_base_documents`
|
* `workflow:create_base_documents`
|
||||||
* `workflow:create_base_document_nodes`
|
|
||||||
"""
|
"""
|
||||||
base_text_embed = config.get("text_embed", {})
|
base_text_embed = config.get("text_embed", {})
|
||||||
document_raw_content_embed_config = config.get(
|
document_raw_content_embed_config = config.get(
|
||||||
@ -25,17 +24,12 @@ def build_steps(
|
|||||||
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
|
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"verb": "rename",
|
"verb": "create_final_documents",
|
||||||
"args": {"columns": {"text_units": "text_unit_ids"}},
|
"args": {
|
||||||
|
"columns": {"text_units": "text_unit_ids"},
|
||||||
|
"skip_embedding": skip_raw_content_embedding,
|
||||||
|
"text_embed": document_raw_content_embed_config,
|
||||||
|
},
|
||||||
"input": {"source": "workflow:create_base_documents"},
|
"input": {"source": "workflow:create_base_documents"},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"verb": "text_embed",
|
|
||||||
"enabled": not skip_raw_content_embedding,
|
|
||||||
"args": {
|
|
||||||
"column": "raw_content",
|
|
||||||
"to": "raw_content_embedding",
|
|
||||||
**document_raw_content_embed_config,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
from .create_base_documents import create_base_documents
|
from .create_base_documents import create_base_documents
|
||||||
from .create_base_text_units import create_base_text_units
|
from .create_base_text_units import create_base_text_units
|
||||||
from .create_final_communities import create_final_communities
|
from .create_final_communities import create_final_communities
|
||||||
|
from .create_final_documents import create_final_documents
|
||||||
from .create_final_nodes import create_final_nodes
|
from .create_final_nodes import create_final_nodes
|
||||||
from .create_final_relationships import (
|
from .create_final_relationships import (
|
||||||
create_final_relationships,
|
create_final_relationships,
|
||||||
@ -16,6 +17,7 @@ __all__ = [
|
|||||||
"create_base_documents",
|
"create_base_documents",
|
||||||
"create_base_text_units",
|
"create_base_text_units",
|
||||||
"create_final_communities",
|
"create_final_communities",
|
||||||
|
"create_final_documents",
|
||||||
"create_final_nodes",
|
"create_final_nodes",
|
||||||
"create_final_relationships",
|
"create_final_relationships",
|
||||||
"create_final_text_units_pre_embedding",
|
"create_final_text_units_pre_embedding",
|
||||||
|
|||||||
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
|
||||||
|
"""All the steps to transform final documents."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from datashaper import (
|
||||||
|
Table,
|
||||||
|
VerbCallbacks,
|
||||||
|
VerbInput,
|
||||||
|
verb,
|
||||||
|
)
|
||||||
|
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||||
|
|
||||||
|
from graphrag.index.cache import PipelineCache
|
||||||
|
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||||
|
|
||||||
|
|
||||||
|
@verb(
|
||||||
|
name="create_final_documents",
|
||||||
|
treats_input_tables_as_immutable=True,
|
||||||
|
)
|
||||||
|
async def create_final_documents(
|
||||||
|
input: VerbInput,
|
||||||
|
callbacks: VerbCallbacks,
|
||||||
|
cache: PipelineCache,
|
||||||
|
text_embed: dict,
|
||||||
|
skip_embedding: bool = False,
|
||||||
|
**_kwargs: dict,
|
||||||
|
) -> VerbResult:
|
||||||
|
"""All the steps to transform final documents."""
|
||||||
|
source = cast(pd.DataFrame, input.get_input())
|
||||||
|
|
||||||
|
source.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
|
||||||
|
|
||||||
|
if not skip_embedding:
|
||||||
|
source = await text_embed_df(
|
||||||
|
source,
|
||||||
|
callbacks,
|
||||||
|
cache,
|
||||||
|
column="raw_content",
|
||||||
|
strategy=text_embed["strategy"],
|
||||||
|
to="raw_content_embedding",
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_verb_result(cast(Table, source))
|
||||||
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2024 Microsoft Corporation.
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
# Licensed under the MIT License
|
# Licensed under the MIT License
|
||||||
|
|
||||||
"""All the steps to transform final relationships before they are embedded."""
|
"""All the steps to transform final relationships."""
|
||||||
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ async def create_final_relationships(
|
|||||||
skip_embedding: bool = False,
|
skip_embedding: bool = False,
|
||||||
**_kwargs: dict,
|
**_kwargs: dict,
|
||||||
) -> VerbResult:
|
) -> VerbResult:
|
||||||
"""All the steps to transform final relationships before they are embedded."""
|
"""All the steps to transform final relationships."""
|
||||||
table = cast(pd.DataFrame, input.get_input())
|
table = cast(pd.DataFrame, input.get_input())
|
||||||
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||||
|
|
||||||
|
|||||||
66
tests/verbs/test_create_final_documents.py
Normal file
66
tests/verbs/test_create_final_documents.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
|
||||||
|
from graphrag.index.workflows.v1.create_final_documents import (
|
||||||
|
build_steps,
|
||||||
|
workflow_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .util import (
|
||||||
|
compare_outputs,
|
||||||
|
get_config_for_workflow,
|
||||||
|
get_workflow_output,
|
||||||
|
load_expected,
|
||||||
|
load_input_tables,
|
||||||
|
remove_disabled_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_final_documents():
|
||||||
|
input_tables = load_input_tables([
|
||||||
|
"workflow:create_base_documents",
|
||||||
|
])
|
||||||
|
expected = load_expected(workflow_name)
|
||||||
|
|
||||||
|
config = get_config_for_workflow(workflow_name)
|
||||||
|
|
||||||
|
config["skip_raw_content_embedding"] = True
|
||||||
|
|
||||||
|
steps = remove_disabled_steps(build_steps(config))
|
||||||
|
|
||||||
|
actual = await get_workflow_output(
|
||||||
|
input_tables,
|
||||||
|
{
|
||||||
|
"steps": steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
compare_outputs(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_final_documents_with_embeddings():
|
||||||
|
input_tables = load_input_tables([
|
||||||
|
"workflow:create_base_documents",
|
||||||
|
])
|
||||||
|
expected = load_expected(workflow_name)
|
||||||
|
|
||||||
|
config = get_config_for_workflow(workflow_name)
|
||||||
|
|
||||||
|
config["skip_raw_content_embedding"] = False
|
||||||
|
# default config has a detailed standard embed config
|
||||||
|
# just override the strategy to mock so the rest of the required parameters are in place
|
||||||
|
config["document_raw_content_embed"]["strategy"]["type"] = "mock"
|
||||||
|
|
||||||
|
steps = remove_disabled_steps(build_steps(config))
|
||||||
|
|
||||||
|
actual = await get_workflow_output(
|
||||||
|
input_tables,
|
||||||
|
{
|
||||||
|
"steps": steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "raw_content_embedding" in actual.columns
|
||||||
|
assert len(actual.columns) == len(expected.columns) + 1
|
||||||
|
# the mock impl returns an array of 3 floats for each embedding
|
||||||
|
assert len(actual["raw_content_embedding"][0]) == 3
|
||||||
Loading…
x
Reference in New Issue
Block a user