Mirror of https://github.com/microsoft/graphrag.git (synced 2025-12-02 18:10:16 +00:00)
Collapse create base text units (#1178)
* Collapse non-attribute verbs
* Include document_column_attributes in collapse
* Remove merge_override verb
* Semver
* Setup initial test and config
* Collapse create_base_text_units
* Semver
* Spelling
* Fix smoke tests
* Address PR comments

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent be7d3eb189
commit 1755afbdec
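At a glance, this change swaps the chain of generic datashaper verbs that the create_base_text_units workflow used to emit for one dedicated subflow verb that does the same work with direct pandas calls. A rough before/after sketch of the emitted step names, taken from the removed and added lines in the diff below (not executable pipeline code):

```python
# Step names only, as they appear in the workflow diff further down.
old_steps = [
    "orderby", "zip", "aggregate_override", "chunk", "select",
    "unroll", "rename", "genid", "unzip", "copy", "filter",
]
new_steps = ["create_base_text_units"]  # single collapsed subflow verb
```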
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Collapse create_base_text_units."
}
@@ -100,6 +100,7 @@ aembed
 dedupe
 dropna
 dtypes
+notna
 
 # LLM Terms
 AOAI
@@ -16,7 +16,7 @@ def genid(
     input: VerbInput,
     to: str,
     method: str = "md5_hash",
-    hash: list[str] = [],  # noqa A002
+    hash: list[str] | None = None,  # noqa A002
     **_kwargs: dict,
 ) -> TableContainer:
     """
@@ -52,15 +52,29 @@ def genid(
     """
     data = cast(pd.DataFrame, input.source.table)
 
-    if method == "md5_hash":
-        if len(hash) == 0:
-            msg = 'Must specify the "hash" columns to use md5_hash method'
-            raise ValueError(msg)
-        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
-    elif method == "increment":
-        data[to] = data.index + 1
-    else:
-        msg = f"Unknown method {method}"
-        raise ValueError(msg)
+    output = genid_df(data, to, method, hash)
 
-    return TableContainer(table=data)
+    return TableContainer(table=output)
+
+
+def genid_df(
+    input: pd.DataFrame,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] | None = None,  # noqa A002
+):
+    """Generate a unique id for each row in the tabular data."""
+    data = input
+    match method:
+        case "md5_hash":
+            if not hash:
+                msg = 'Must specify the "hash" columns to use md5_hash method'
+                raise ValueError(msg)
+            data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+        case "increment":
+            data[to] = data.index + 1
+        case _:
+            msg = f"Unknown method {method}"
+            raise ValueError(msg)
+
+    return data
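The new genid_df helper (in graphrag.index.verbs.genid, per the import added later in this diff) can now be called directly on a DataFrame without the datashaper verb wrapper. A minimal, hypothetical usage sketch with invented sample data:

```python
import pandas as pd

from graphrag.index.verbs.genid import genid_df

rows = pd.DataFrame({"chunk": ["first text unit", "second text unit"]})

# md5_hash derives a stable id from the values of the listed columns
with_ids = genid_df(rows, to="chunk_id", method="md5_hash", hash=["chunk"])

# increment simply numbers rows 1..N
numbered = genid_df(rows, to="row_id", method="increment")
```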
@@ -85,9 +85,24 @@ def chunk(
         type: sentence
     ```
     """
-    if strategy is None:
-        strategy = {}
-    output = cast(pd.DataFrame, input.get_input())
+    input_table = cast(pd.DataFrame, input.get_input())
+
+    output = chunk_df(input_table, column, to, callbacks, strategy)
+
+    return TableContainer(table=output)
+
+
+def chunk_df(
+    input: pd.DataFrame,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+) -> pd.DataFrame:
+    """Chunk a piece of text into smaller pieces."""
+    output = input
+    if strategy is None:
+        strategy = {}
     strategy_name = strategy.get("type", ChunkStrategyType.tokens)
     strategy_config = {**strategy}
     strategy_exec = load_strategy(strategy_name)
@@ -102,7 +117,7 @@ def chunk(
             ),
             axis=1,
         )
-    return TableContainer(table=output)
+    return output
 
 
 def run_strategy(
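Likewise, the extracted chunk_df helper (in graphrag.index.verbs.text.chunk.text_chunk, per the import added later in this diff) can be driven without the verb machinery. A hypothetical sketch with made-up data; it assumes datashaper exposes NoopVerbCallbacks for the callbacks argument and that the sentence strategy's dependencies (nltk) are available:

```python
import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed helper for a no-op callback sink

from graphrag.index.verbs.text.chunk.text_chunk import chunk_df

frame = pd.DataFrame({"texts": [["One sentence here. Another sentence there."]]})

chunked = chunk_df(
    frame,
    column="texts",
    to="chunks",
    callbacks=NoopVerbCallbacks(),
    strategy={"type": "sentence"},  # illustrative; omitting this falls back to the tokens strategy
)
print(chunked["chunks"])
```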
@@ -22,91 +22,16 @@ def build_steps(
     chunk_column_name = config.get("chunk_column", "chunk")
     chunk_by_columns = config.get("chunk_by", []) or []
     n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
+    text_chunk = config.get("text_chunk", {})
     return [
         {
-            "verb": "orderby",
+            "verb": "create_base_text_units",
             "args": {
-                "orders": [
-                    # sort for reproducibility
-                    {"column": "id", "direction": "asc"},
-                ]
+                "chunk_column_name": chunk_column_name,
+                "n_tokens_column_name": n_tokens_column_name,
+                "chunk_by_columns": chunk_by_columns,
+                **text_chunk,
             },
             "input": {"source": DEFAULT_INPUT_NAME},
         },
-        {
-            "verb": "zip",
-            "args": {
-                # Pack the document ids with the text
-                # So when we unpack the chunks, we can restore the document id
-                "columns": ["id", "text"],
-                "to": "text_with_ids",
-            },
-        },
-        {
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
-                "aggregations": [
-                    {
-                        "column": "text_with_ids",
-                        "operation": "array_agg",
-                        "to": "texts",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "chunk",
-            "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [*chunk_by_columns, "chunks"],
-            },
-        },
-        {
-            "verb": "unroll",
-            "args": {
-                "column": "chunks",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunks": chunk_column_name,
-                }
-            },
-        },
-        {
-            "verb": "genid",
-            "args": {
-                # Generate a unique id for each chunk
-                "to": "chunk_id",
-                "method": "md5_hash",
-                "hash": [chunk_column_name],
-            },
-        },
-        {
-            "verb": "unzip",
-            "args": {
-                "column": chunk_column_name,
-                "to": ["document_ids", chunk_column_name, n_tokens_column_name],
-            },
-        },
-        {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
-        {
-            # ELIMINATE EMPTY CHUNKS
-            "verb": "filter",
-            "args": {
-                "column": chunk_column_name,
-                "criteria": [
-                    {
-                        "type": "value",
-                        "operator": "is not empty",
-                    }
-                ],
-            },
-        },
     ]
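To make the collapsed shape concrete, here is a hypothetical call to the rewritten build_steps with an illustrative config. It assumes the workflow config behaves like a plain mapping, as the config.get calls above suggest; the specific option values are invented:

```python
from graphrag.index.workflows.v1.create_base_text_units import build_steps

steps = build_steps({
    "chunk_by": [],
    "text_chunk": {"strategy": {"type": "tokens", "chunk_size": 1200, "chunk_overlap": 100}},
})

# Roughly expected result: a single step that drives the new subflow verb.
# [
#     {
#         "verb": "create_base_text_units",
#         "args": {
#             "chunk_column_name": "chunk",
#             "n_tokens_column_name": "n_tokens",
#             "chunk_by_columns": [],
#             "strategy": {"type": "tokens", "chunk_size": 1200, "chunk_overlap": 100},
#         },
#         "input": {"source": DEFAULT_INPUT_NAME},
#     },
# ]
```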
@@ -4,6 +4,7 @@
 """The Indexing Engine workflows -> subflows package root."""
 
 from .create_base_documents import create_base_documents
+from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_nodes import create_final_nodes
 from .create_final_relationships_post_embedding import (
@@ -16,6 +17,7 @@ from .create_final_text_units_pre_embedding import create_final_text_units_pre_e
 
 __all__ = [
     "create_base_documents",
+    "create_base_text_units",
     "create_final_communities",
     "create_final_nodes",
     "create_final_relationships_post_embedding",
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform base text_units."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df


@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
    input: VerbInput,
    callbacks: VerbCallbacks,
    chunk_column_name: str,
    n_tokens_column_name: str,
    chunk_by_columns: list[str],
    strategy: dict[str, Any] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform base text_units."""
    table = cast(pd.DataFrame, input.get_input())

    sort = table.sort_values(by=["id"], ascending=[True])

    sort["text_with_ids"] = list(
        zip(*[sort[col] for col in ["id", "text"]], strict=True)
    )

    aggregated = aggregate_df(
        sort,
        groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
        aggregations=[
            {
                "column": "text_with_ids",
                "operation": "array_agg",
                "to": "texts",
            }
        ],
    )

    chunked = chunk_df(
        aggregated,
        column="texts",
        to="chunks",
        callbacks=callbacks,
        strategy=strategy,
    )

    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
    chunked = chunked.explode("chunks")
    chunked.rename(
        columns={
            "chunks": chunk_column_name,
        },
        inplace=True,
    )

    chunked = genid_df(
        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
    )

    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
        chunked[chunk_column_name].tolist(), index=chunked.index
    )
    chunked["id"] = chunked["chunk_id"]

    filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)

    return create_verb_result(
        cast(
            Table,
            filtered,
        )
    )
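The new subflow threads document ids through chunking by zipping them into tuples, exploding the chunk lists, and unpacking the tuples back into columns. A self-contained, hypothetical pandas sketch of that pattern, with invented sample data rather than repo code:

```python
import pandas as pd

docs = pd.DataFrame({"id": ["d1", "d2"], "text": ["alpha beta", "gamma"]})

# pack (id, text) tuples so document ids survive aggregation and chunking
docs["text_with_ids"] = list(zip(docs["id"], docs["text"], strict=True))

# pretend the chunker produced ([document_ids], chunk_text, n_tokens) tuples per row
chunked = pd.DataFrame(
    {"chunks": [[(["d1"], "alpha", 1), (["d1"], "beta", 1)], [(["d2"], "gamma", 1)]]}
).explode("chunks")

# unpack each tuple into columns, mirroring the subflow's final unzip step
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
    chunked["chunks"].tolist(), index=chunked.index
)
print(chunked[["document_ids", "chunk", "n_tokens"]])
```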
tests/fixtures/min-csv/config.json (vendored, 2 lines changed)
@@ -7,7 +7,7 @@
         1,
         2000
       ],
-      "subworkflows": 11,
+      "subworkflows": 1,
       "max_runtime": 10
     },
     "create_base_extracted_entities": {
tests/fixtures/text/config.json (vendored, 2 lines changed)
@@ -7,7 +7,7 @@
         1,
         2000
       ],
-      "subworkflows": 11,
+      "subworkflows": 1,
       "max_runtime": 10
     },
     "create_base_extracted_entities": {
tests/verbs/test_create_base_text_units.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_base_text_units import (
    build_steps,
    workflow_name,
)

from .util import (
    compare_outputs,
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
)


async def test_create_base_text_units():
    input_tables = load_input_tables(inputs=[])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    # test data was created with 4o, so we need to match the encoding for chunks to be identical
    config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    compare_outputs(actual, expected)
@@ -31,6 +31,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
         # remove the workflow: prefix if it exists, because that is not part of the actual table filename
         name = input.replace("workflow:", "")
         input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")
 
     return input_tables
 
 
@@ -42,6 +43,7 @@ def load_expected(output: str) -> pd.DataFrame:
 
 def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
     """Instantiates the bare minimum config to get a default workflow config for testing."""
     config = create_graphrag_config()
     print(config)
     pipeline_config = create_pipeline_config(config)
     print(pipeline_config.workflows)
     result = next(conf for conf in pipeline_config.workflows if conf.name == name)