Test and unify text splitter functionality (#1547)

* add text_splitting unit test

* change folder test text splitting

* fix chunk fn

* test new function

* run formatter

* run spell check

* run semver

* remove tiktoken mocked from tests

* change progress ticker

* fix ruff check
This commit is contained in:
Dayenne Souza 2025-01-13 18:42:44 -03:00 committed by GitHub
parent 0e7d22bfb0
commit 2f2cfa7b70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 220 additions and 140 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "unit tests for text_splitting"
}

View File

@ -10,7 +10,10 @@ import tiktoken
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import Tokenizer
from graphrag.index.text_splitting.text_splitting import (
Tokenizer,
split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
@ -31,7 +34,7 @@ def run_tokens(
def decode(tokens: list[int]) -> str:
return enc.decode(tokens)
return _split_text_on_tokens(
return split_multiple_texts_on_tokens(
input,
Tokenizer(
chunk_overlap=chunk_overlap,
@ -43,44 +46,6 @@ def run_tokens(
)
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def _split_text_on_tokens(
texts: list[str], enc: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
"""Split incoming text and return chunks."""
result = []
mapped_ids = []
for source_doc_idx, text in enumerate(texts):
encoded = enc.encode(text)
tick(1)
mapped_ids.append((source_doc_idx, encoded))
input_ids: list[tuple[int, int]] = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]
start_idx = 0
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = enc.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(
TextChunk(
text_chunk=chunk_text,
source_doc_indices=doc_indices,
n_tokens=len(chunk_ids),
)
)
start_idx += enc.tokens_per_chunk - enc.chunk_overlap
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
def run_sentences(
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:

View File

@ -3,19 +3,18 @@
"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, cast
import pandas as pd
import tiktoken
import graphrag.config.defaults as defs
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@ -123,10 +122,10 @@ class TokenTextSplitter(TextSplitter):
def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""
if cast("bool", pd.isna(text)) or text == "":
return []
if isinstance(text, list):
text = " ".join(text)
elif cast("bool", pd.isna(text)) or text == "":
return []
if not isinstance(text, str):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
@ -138,108 +137,57 @@ class TokenTextSplitter(TextSplitter):
encode=lambda text: self.encode(text),
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
class TextListSplitterType(str, Enum):
"""Enum for the type of the TextListSplitter."""
DELIMITED_STRING = "delimited_string"
JSON = "json"
class TextListSplitter(TextSplitter):
"""Text list splitter class definition."""
def __init__(
self,
chunk_size: int,
splitter_type: TextListSplitterType = TextListSplitterType.JSON,
input_delimiter: str | None = None,
output_delimiter: str | None = None,
model_name: str | None = None,
encoding_name: str | None = None,
):
"""Initialize the TextListSplitter with a chunk size."""
# Set the chunk overlap to 0 as we use full strings
super().__init__(chunk_size, chunk_overlap=0)
self._type = splitter_type
self._input_delimiter = input_delimiter
self._output_delimiter = output_delimiter or "\n"
self._length_function = lambda x: num_tokens_from_string(
x, model=model_name, encoding_name=encoding_name
)
def split_text(self, text: str | list[str]) -> Iterable[str]:
"""Split a string list into a list of strings for a given chunk size."""
if not text:
return []
result: list[str] = []
current_chunk: list[str] = []
# Add the brackets
current_length: int = self._length_function("[]")
# Input should be a string list joined by a delimiter
string_list = self._load_text_list(text)
if len(string_list) == 1:
return string_list
for item in string_list:
# Count the length of the item and add comma
item_length = self._length_function(f"{item},")
if current_length + item_length > self._chunk_size:
if current_chunk and len(current_chunk) > 0:
# Add the current chunk to the result
self._append_to_result(result, current_chunk)
# Start a new chunk
current_chunk = [item]
# Add 2 for the brackets
current_length = item_length
else:
# Add the item to the current chunk
current_chunk.append(item)
# Add 1 for the comma
current_length += item_length
# Add the last chunk to the result
self._append_to_result(result, current_chunk)
return result
def _load_text_list(self, text: str | list[str]):
"""Load the text list based on the type."""
if isinstance(text, list):
string_list = text
elif self._type == TextListSplitterType.JSON:
string_list = json.loads(text)
else:
string_list = text.split(self._input_delimiter)
return string_list
def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
"""Append the current chunk to the result."""
if new_chunk and len(new_chunk) > 0:
if self._type == TextListSplitterType.JSON:
chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
else:
chunk_list.append(self._output_delimiter.join(new_chunk))
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
"""Split incoming text and return chunks using tokenizer."""
splits: list[str] = []
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits
return result
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []
mapped_ids = []
for source_doc_idx, text in enumerate(texts):
encoded = tokenizer.encode(text)
if tick:
tick(1) # Track progress if tick callback is provided
mapped_ids.append((source_doc_idx, encoded))
input_ids = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result

View File

@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

View File

@ -0,0 +1,161 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from unittest import mock
from unittest.mock import MagicMock
import pytest
from graphrag.index.text_splitting.text_splitting import (
NoopTextSplitter,
Tokenizer,
TokenTextSplitter,
split_multiple_texts_on_tokens,
split_single_text_on_tokens,
)
def test_noop_text_splitter() -> None:
splitter = NoopTextSplitter()
assert list(splitter.split_text("some text")) == ["some text"]
assert list(splitter.split_text(["some", "text"])) == ["some", "text"]
class MockTokenizer:
def encode(self, text):
return [ord(char) for char in text]
def decode(self, token_ids):
return "".join(chr(id) for id in token_ids)
def test_split_text_str_empty():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text("")
assert result == []
def test_split_text_str_bool():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text(None) # type: ignore
assert result == []
def test_split_text_str_int():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
with pytest.raises(TypeError):
splitter.split_text(123) # type: ignore
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
def test_split_text_large_input(mock_split):
large_text = "a" * 10_000
mock_split.return_value = ["chunk"] * 2_000
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text(large_text)
assert len(result) == 2_000, "Large input was not split correctly"
mock_split.assert_called_once()
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
@mock.patch("graphrag.index.text_splitting.text_splitting.Tokenizer")
def test_token_text_splitter(mock_tokenizer, mock_split_text):
text = "chunk1 chunk2 chunk3"
expected_chunks = ["chunk1", "chunk2", "chunk3"]
mocked_tokenizer = MagicMock()
mock_tokenizer.return_value = mocked_tokenizer
mock_split_text.return_value = expected_chunks
splitter = TokenTextSplitter()
splitter.split_text(["chunk1", "chunk2", "chunk3"])
mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)
def test_encode_basic():
splitter = TokenTextSplitter()
result = splitter.encode("abc def")
assert result == [13997, 711], "Encoding failed to return expected tokens"
def test_num_tokens_empty_input():
splitter = TokenTextSplitter()
result = splitter.num_tokens("")
assert result == 0, "Token count for empty input should be 0"
def test_model_name():
splitter = TokenTextSplitter(model_name="gpt-4o")
result = splitter.encode("abc def")
assert result == [26682, 1056], "Encoding failed to return expected tokens"
@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError)
@mock.patch("tiktoken.get_encoding")
def test_model_name_exception(mock_get_encoding, mock_encoding_for_model):
mock_get_encoding.return_value = mock.MagicMock()
TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding")
mock_get_encoding.assert_called_once_with("mock_encoding")
mock_encoding_for_model.assert_called_once_with("mock_model")
def test_split_single_text_on_tokens():
text = "This is a test text, meaning to be taken seriously by this test only."
mocked_tokenizer = MockTokenizer()
tokenizer = Tokenizer(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
encode=lambda text: mocked_tokenizer.encode(text),
)
expected_splits = [
"This is a ",
"is a test ",
"test text,",
"text, mean",
" meaning t",
"ing to be ",
"o be taken",
"taken seri", # cspell:disable-line
" seriously",
"ously by t", # cspell:disable-line
" by this t",
"his test o",
"est only.",
"nly.",
]
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
assert result == expected_splits
def test_split_multiple_texts_on_tokens():
texts = [
"This is a test text, meaning to be taken seriously by this test only.",
"This is th second text, meaning to be taken seriously by this test only.",
]
mocked_tokenizer = MockTokenizer()
mock_tick = MagicMock()
tokenizer = Tokenizer(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
encode=lambda text: mocked_tokenizer.encode(text),
)
split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick)
mock_tick.assert_called()