mirror of
https://github.com/microsoft/graphrag.git
synced 2025-11-30 00:49:55 +00:00
Test and unify text splitter functionality (#1547)
* add text_splitting unit test * change folder test text splitting * fix chunk fn * test new function * run formatter * run spell check * run semver * remove tiktoken mocked from tests * change progress ticker * fix ruff check
This commit is contained in:
parent
0e7d22bfb0
commit
2f2cfa7b70
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "unit tests for text_splitting"
|
||||
}
|
||||
@ -10,7 +10,10 @@ import tiktoken
|
||||
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.index.operations.chunk_text.typing import TextChunk
|
||||
from graphrag.index.text_splitting.text_splitting import Tokenizer
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
Tokenizer,
|
||||
split_multiple_texts_on_tokens,
|
||||
)
|
||||
from graphrag.logger.progress import ProgressTicker
|
||||
|
||||
|
||||
@ -31,7 +34,7 @@ def run_tokens(
|
||||
def decode(tokens: list[int]) -> str:
|
||||
return enc.decode(tokens)
|
||||
|
||||
return _split_text_on_tokens(
|
||||
return split_multiple_texts_on_tokens(
|
||||
input,
|
||||
Tokenizer(
|
||||
chunk_overlap=chunk_overlap,
|
||||
@ -43,44 +46,6 @@ def run_tokens(
|
||||
)
|
||||
|
||||
|
||||
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
|
||||
# So we could have better control over the chunking process
|
||||
def _split_text_on_tokens(
|
||||
texts: list[str], enc: Tokenizer, tick: ProgressTicker
|
||||
) -> list[TextChunk]:
|
||||
"""Split incoming text and return chunks."""
|
||||
result = []
|
||||
mapped_ids = []
|
||||
|
||||
for source_doc_idx, text in enumerate(texts):
|
||||
encoded = enc.encode(text)
|
||||
tick(1)
|
||||
mapped_ids.append((source_doc_idx, encoded))
|
||||
|
||||
input_ids: list[tuple[int, int]] = [
|
||||
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
|
||||
]
|
||||
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
while start_idx < len(input_ids):
|
||||
chunk_text = enc.decode([id for _, id in chunk_ids])
|
||||
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
|
||||
result.append(
|
||||
TextChunk(
|
||||
text_chunk=chunk_text,
|
||||
source_doc_indices=doc_indices,
|
||||
n_tokens=len(chunk_ids),
|
||||
)
|
||||
)
|
||||
start_idx += enc.tokens_per_chunk - enc.chunk_overlap
|
||||
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def run_sentences(
|
||||
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
|
||||
) -> Iterable[TextChunk]:
|
||||
|
||||
@ -3,19 +3,18 @@
|
||||
|
||||
"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Collection, Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.index.utils.tokens import num_tokens_from_string
|
||||
from graphrag.index.operations.chunk_text.typing import TextChunk
|
||||
from graphrag.logger.progress import ProgressTicker
|
||||
|
||||
EncodedText = list[int]
|
||||
DecodeFn = Callable[[EncodedText], str]
|
||||
@ -123,10 +122,10 @@ class TokenTextSplitter(TextSplitter):
|
||||
|
||||
def split_text(self, text: str | list[str]) -> list[str]:
|
||||
"""Split text method."""
|
||||
if cast("bool", pd.isna(text)) or text == "":
|
||||
return []
|
||||
if isinstance(text, list):
|
||||
text = " ".join(text)
|
||||
elif cast("bool", pd.isna(text)) or text == "":
|
||||
return []
|
||||
if not isinstance(text, str):
|
||||
msg = f"Attempting to split a non-string value, actual is {type(text)}"
|
||||
raise TypeError(msg)
|
||||
@ -138,108 +137,57 @@ class TokenTextSplitter(TextSplitter):
|
||||
encode=lambda text: self.encode(text),
|
||||
)
|
||||
|
||||
return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class TextListSplitterType(str, Enum):
|
||||
"""Enum for the type of the TextListSplitter."""
|
||||
|
||||
DELIMITED_STRING = "delimited_string"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
class TextListSplitter(TextSplitter):
|
||||
"""Text list splitter class definition."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int,
|
||||
splitter_type: TextListSplitterType = TextListSplitterType.JSON,
|
||||
input_delimiter: str | None = None,
|
||||
output_delimiter: str | None = None,
|
||||
model_name: str | None = None,
|
||||
encoding_name: str | None = None,
|
||||
):
|
||||
"""Initialize the TextListSplitter with a chunk size."""
|
||||
# Set the chunk overlap to 0 as we use full strings
|
||||
super().__init__(chunk_size, chunk_overlap=0)
|
||||
self._type = splitter_type
|
||||
self._input_delimiter = input_delimiter
|
||||
self._output_delimiter = output_delimiter or "\n"
|
||||
self._length_function = lambda x: num_tokens_from_string(
|
||||
x, model=model_name, encoding_name=encoding_name
|
||||
)
|
||||
|
||||
def split_text(self, text: str | list[str]) -> Iterable[str]:
|
||||
"""Split a string list into a list of strings for a given chunk size."""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
result: list[str] = []
|
||||
current_chunk: list[str] = []
|
||||
|
||||
# Add the brackets
|
||||
current_length: int = self._length_function("[]")
|
||||
|
||||
# Input should be a string list joined by a delimiter
|
||||
string_list = self._load_text_list(text)
|
||||
|
||||
if len(string_list) == 1:
|
||||
return string_list
|
||||
|
||||
for item in string_list:
|
||||
# Count the length of the item and add comma
|
||||
item_length = self._length_function(f"{item},")
|
||||
|
||||
if current_length + item_length > self._chunk_size:
|
||||
if current_chunk and len(current_chunk) > 0:
|
||||
# Add the current chunk to the result
|
||||
self._append_to_result(result, current_chunk)
|
||||
|
||||
# Start a new chunk
|
||||
current_chunk = [item]
|
||||
# Add 2 for the brackets
|
||||
current_length = item_length
|
||||
else:
|
||||
# Add the item to the current chunk
|
||||
current_chunk.append(item)
|
||||
# Add 1 for the comma
|
||||
current_length += item_length
|
||||
|
||||
# Add the last chunk to the result
|
||||
self._append_to_result(result, current_chunk)
|
||||
|
||||
return result
|
||||
|
||||
def _load_text_list(self, text: str | list[str]):
|
||||
"""Load the text list based on the type."""
|
||||
if isinstance(text, list):
|
||||
string_list = text
|
||||
elif self._type == TextListSplitterType.JSON:
|
||||
string_list = json.loads(text)
|
||||
else:
|
||||
string_list = text.split(self._input_delimiter)
|
||||
return string_list
|
||||
|
||||
def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
|
||||
"""Append the current chunk to the result."""
|
||||
if new_chunk and len(new_chunk) > 0:
|
||||
if self._type == TextListSplitterType.JSON:
|
||||
chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
|
||||
else:
|
||||
chunk_list.append(self._output_delimiter.join(new_chunk))
|
||||
|
||||
|
||||
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
"""Split incoming text and return chunks using tokenizer."""
|
||||
splits: list[str] = []
|
||||
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
"""Split a single text and return chunks using the tokenizer."""
|
||||
result = []
|
||||
input_ids = tokenizer.encode(text)
|
||||
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(tokenizer.decode(chunk_ids))
|
||||
chunk_text = tokenizer.decode(list(chunk_ids))
|
||||
result.append(chunk_text) # Append chunked text as string
|
||||
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
|
||||
# So we could have better control over the chunking process
|
||||
def split_multiple_texts_on_tokens(
|
||||
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
|
||||
) -> list[TextChunk]:
|
||||
"""Split multiple texts and return chunks with metadata using the tokenizer."""
|
||||
result = []
|
||||
mapped_ids = []
|
||||
|
||||
for source_doc_idx, text in enumerate(texts):
|
||||
encoded = tokenizer.encode(text)
|
||||
if tick:
|
||||
tick(1) # Track progress if tick callback is provided
|
||||
mapped_ids.append((source_doc_idx, encoded))
|
||||
|
||||
input_ids = [
|
||||
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
|
||||
]
|
||||
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
|
||||
while start_idx < len(input_ids):
|
||||
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
|
||||
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
|
||||
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
|
||||
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
|
||||
return result
|
||||
|
||||
2
tests/unit/indexing/text_splitting/__init__.py
Normal file
2
tests/unit/indexing/text_splitting/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
161
tests/unit/indexing/text_splitting/test_text_splitting.py
Normal file
161
tests/unit/indexing/text_splitting/test_text_splitting.py
Normal file
@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
NoopTextSplitter,
|
||||
Tokenizer,
|
||||
TokenTextSplitter,
|
||||
split_multiple_texts_on_tokens,
|
||||
split_single_text_on_tokens,
|
||||
)
|
||||
|
||||
|
||||
def test_noop_text_splitter() -> None:
|
||||
splitter = NoopTextSplitter()
|
||||
|
||||
assert list(splitter.split_text("some text")) == ["some text"]
|
||||
assert list(splitter.split_text(["some", "text"])) == ["some", "text"]
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
def encode(self, text):
|
||||
return [ord(char) for char in text]
|
||||
|
||||
def decode(self, token_ids):
|
||||
return "".join(chr(id) for id in token_ids)
|
||||
|
||||
|
||||
def test_split_text_str_empty():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
result = splitter.split_text("")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_split_text_str_bool():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
result = splitter.split_text(None) # type: ignore
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_split_text_str_int():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
with pytest.raises(TypeError):
|
||||
splitter.split_text(123) # type: ignore
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
|
||||
def test_split_text_large_input(mock_split):
|
||||
large_text = "a" * 10_000
|
||||
mock_split.return_value = ["chunk"] * 2_000
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
|
||||
result = splitter.split_text(large_text)
|
||||
|
||||
assert len(result) == 2_000, "Large input was not split correctly"
|
||||
mock_split.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.Tokenizer")
|
||||
def test_token_text_splitter(mock_tokenizer, mock_split_text):
|
||||
text = "chunk1 chunk2 chunk3"
|
||||
expected_chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
mocked_tokenizer = MagicMock()
|
||||
mock_tokenizer.return_value = mocked_tokenizer
|
||||
mock_split_text.return_value = expected_chunks
|
||||
|
||||
splitter = TokenTextSplitter()
|
||||
|
||||
splitter.split_text(["chunk1", "chunk2", "chunk3"])
|
||||
|
||||
mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)
|
||||
|
||||
|
||||
def test_encode_basic():
|
||||
splitter = TokenTextSplitter()
|
||||
result = splitter.encode("abc def")
|
||||
|
||||
assert result == [13997, 711], "Encoding failed to return expected tokens"
|
||||
|
||||
|
||||
def test_num_tokens_empty_input():
|
||||
splitter = TokenTextSplitter()
|
||||
result = splitter.num_tokens("")
|
||||
|
||||
assert result == 0, "Token count for empty input should be 0"
|
||||
|
||||
|
||||
def test_model_name():
|
||||
splitter = TokenTextSplitter(model_name="gpt-4o")
|
||||
result = splitter.encode("abc def")
|
||||
|
||||
assert result == [26682, 1056], "Encoding failed to return expected tokens"
|
||||
|
||||
|
||||
@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError)
|
||||
@mock.patch("tiktoken.get_encoding")
|
||||
def test_model_name_exception(mock_get_encoding, mock_encoding_for_model):
|
||||
mock_get_encoding.return_value = mock.MagicMock()
|
||||
|
||||
TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding")
|
||||
|
||||
mock_get_encoding.assert_called_once_with("mock_encoding")
|
||||
mock_encoding_for_model.assert_called_once_with("mock_model")
|
||||
|
||||
|
||||
def test_split_single_text_on_tokens():
|
||||
text = "This is a test text, meaning to be taken seriously by this test only."
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
tokenizer = Tokenizer(
|
||||
chunk_overlap=5,
|
||||
tokens_per_chunk=10,
|
||||
decode=mocked_tokenizer.decode,
|
||||
encode=lambda text: mocked_tokenizer.encode(text),
|
||||
)
|
||||
|
||||
expected_splits = [
|
||||
"This is a ",
|
||||
"is a test ",
|
||||
"test text,",
|
||||
"text, mean",
|
||||
" meaning t",
|
||||
"ing to be ",
|
||||
"o be taken",
|
||||
"taken seri", # cspell:disable-line
|
||||
" seriously",
|
||||
"ously by t", # cspell:disable-line
|
||||
" by this t",
|
||||
"his test o",
|
||||
"est only.",
|
||||
"nly.",
|
||||
]
|
||||
|
||||
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
assert result == expected_splits
|
||||
|
||||
|
||||
def test_split_multiple_texts_on_tokens():
|
||||
texts = [
|
||||
"This is a test text, meaning to be taken seriously by this test only.",
|
||||
"This is th second text, meaning to be taken seriously by this test only.",
|
||||
]
|
||||
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
mock_tick = MagicMock()
|
||||
tokenizer = Tokenizer(
|
||||
chunk_overlap=5,
|
||||
tokens_per_chunk=10,
|
||||
decode=mocked_tokenizer.decode,
|
||||
encode=lambda text: mocked_tokenizer.encode(text),
|
||||
)
|
||||
|
||||
split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick)
|
||||
mock_tick.assert_called()
|
||||
Loading…
x
Reference in New Issue
Block a user