mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-11 23:13:00 +00:00
Collapse create final documents (#1217)
* Collapse create_final_documents * Semver
This commit is contained in:
parent
dda4edd0fd
commit
14750f4d37
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Collapse create-final-documents."
|
||||
}
|
||||
@ -16,7 +16,6 @@ def build_steps(
|
||||
|
||||
## Dependencies
|
||||
* `workflow:create_base_documents`
|
||||
* `workflow:create_base_document_nodes`
|
||||
"""
|
||||
base_text_embed = config.get("text_embed", {})
|
||||
document_raw_content_embed_config = config.get(
|
||||
@ -25,17 +24,12 @@ def build_steps(
|
||||
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
|
||||
return [
|
||||
{
|
||||
"verb": "rename",
|
||||
"args": {"columns": {"text_units": "text_unit_ids"}},
|
||||
"verb": "create_final_documents",
|
||||
"args": {
|
||||
"columns": {"text_units": "text_unit_ids"},
|
||||
"skip_embedding": skip_raw_content_embedding,
|
||||
"text_embed": document_raw_content_embed_config,
|
||||
},
|
||||
"input": {"source": "workflow:create_base_documents"},
|
||||
},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_raw_content_embedding,
|
||||
"args": {
|
||||
"column": "raw_content",
|
||||
"to": "raw_content_embedding",
|
||||
**document_raw_content_embed_config,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
from .create_base_documents import create_base_documents
|
||||
from .create_base_text_units import create_base_text_units
|
||||
from .create_final_communities import create_final_communities
|
||||
from .create_final_documents import create_final_documents
|
||||
from .create_final_nodes import create_final_nodes
|
||||
from .create_final_relationships import (
|
||||
create_final_relationships,
|
||||
@ -16,6 +17,7 @@ __all__ = [
|
||||
"create_base_documents",
|
||||
"create_base_text_units",
|
||||
"create_final_communities",
|
||||
"create_final_documents",
|
||||
"create_final_nodes",
|
||||
"create_final_relationships",
|
||||
"create_final_text_units_pre_embedding",
|
||||
|
||||
@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final documents."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
@verb(
    name="create_final_documents",
    treats_input_tables_as_immutable=True,
)
async def create_final_documents(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    text_embed: dict,
    skip_embedding: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Transform the base documents table into the final documents table.

    Renames the ``text_units`` column to ``text_unit_ids`` and, unless
    ``skip_embedding`` is set, embeds the ``raw_content`` column into a
    ``raw_content_embedding`` column via ``text_embed_df``.

    Args:
        input: Verb input whose primary table is the documents table.
        callbacks: Verb callbacks, forwarded to ``text_embed_df``.
        cache: Pipeline cache, forwarded to ``text_embed_df``.
        text_embed: Embedding configuration. Only its ``"strategy"`` entry is
            read, and only when embedding is enabled.
        skip_embedding: When true, the embedding step is skipped entirely.
        **_kwargs: Any additional step arguments; accepted and ignored.

    Returns:
        A verb result wrapping the transformed table.

    Raises:
        KeyError: If embedding is enabled and ``text_embed`` has no
            ``"strategy"`` key.
    """
    source = cast(pd.DataFrame, input.get_input())

    # This verb is registered with treats_input_tables_as_immutable=True, so
    # rename into a new frame rather than mutating the input table in place.
    documents = source.rename(columns={"text_units": "text_unit_ids"})

    if not skip_embedding:
        documents = await text_embed_df(
            documents,
            callbacks,
            cache,
            column="raw_content",
            strategy=text_embed["strategy"],
            to="raw_content_embedding",
        )

    return create_verb_result(cast(Table, documents))
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
"""All the steps to transform final relationships."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
@ -35,7 +35,7 @@ async def create_final_relationships(
|
||||
skip_embedding: bool = False,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
"""All the steps to transform final relationships."""
|
||||
table = cast(pd.DataFrame, input.get_input())
|
||||
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||
|
||||
|
||||
66
tests/verbs/test_create_final_documents.py
Normal file
66
tests/verbs/test_create_final_documents.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.workflows.v1.create_final_documents import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
)
|
||||
|
||||
from .util import (
|
||||
compare_outputs,
|
||||
get_config_for_workflow,
|
||||
get_workflow_output,
|
||||
load_expected,
|
||||
load_input_tables,
|
||||
remove_disabled_steps,
|
||||
)
|
||||
|
||||
|
||||
async def test_create_final_documents():
    """Run the workflow with embedding skipped and compare against the stored expected output."""
    tables = load_input_tables(["workflow:create_base_documents"])
    expected_output = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    # Skip the raw-content embedding for this run.
    workflow_config["skip_raw_content_embedding"] = True

    active_steps = remove_disabled_steps(build_steps(workflow_config))
    workflow_schema = {"steps": active_steps}

    result = await get_workflow_output(tables, workflow_schema)

    compare_outputs(result, expected_output)
|
||||
|
||||
|
||||
async def test_create_final_documents_with_embeddings():
    """Run the workflow with mock embeddings and check the extra embedding column."""
    tables = load_input_tables(["workflow:create_base_documents"])
    expected_output = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_raw_content_embedding"] = False
    # The default config carries a detailed standard embed configuration; only
    # the strategy type is switched to the mock so that the rest of the
    # required parameters stay in place.
    embed_strategy = workflow_config["document_raw_content_embed"]["strategy"]
    embed_strategy["type"] = "mock"

    active_steps = remove_disabled_steps(build_steps(workflow_config))
    result = await get_workflow_output(tables, {"steps": active_steps})

    result_columns = result.columns
    assert "raw_content_embedding" in result_columns
    # Exactly one column (the embedding) is added on top of the expected output.
    assert len(result_columns) == len(expected_output.columns) + 1
    # The mock implementation returns an array of 3 floats for each embedding.
    first_embedding = result["raw_content_embedding"][0]
    assert len(first_embedding) == 3
|
||||
Loading…
x
Reference in New Issue
Block a user