from __future__ import annotations import copy import logging import re from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterable, Sequence, Set from dataclasses import dataclass from typing import ( Any, Literal, Optional, TypeVar, Union, ) from core.rag.models.document import BaseDocumentTransformer, Document logger = logging.getLogger(__name__) TS = TypeVar("TS", bound="TextSplitter") def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. _splits = re.split(f"({re.escape(separator)})", text) splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)] if len(_splits) % 2 != 0: splits += _splits[-1:] else: splits = re.split(separator, text) else: splits = list(text) return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( self, chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x], keep_separator: bool = False, add_start_index: bool = False, ) -> None: """Create a new TextSplitter. Args: chunk_size: Maximum size of chunks to return chunk_overlap: Overlap in characters between chunks length_function: Function that measures the length of given chunks keep_separator: Whether to keep the separator in the chunks add_start_index: If `True`, includes chunk's start index in metadata """ if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._length_function = length_function self._keep_separator = keep_separator self._add_start_index = add_start_index @abstractmethod def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] for i, text in enumerate(texts): index = -1 for chunk in self.split_text(text): metadata = copy.deepcopy(_metadatas[i]) if self._add_start_index: index = text.find(chunk, index + 1) metadata["start_index"] = index new_doc = Document(page_content=chunk, metadata=metadata) documents.append(new_doc) return documents def split_documents(self, documents: Iterable[Document]) -> list[Document]: """Split documents.""" texts, metadatas = [], [] for doc in documents: texts.append(doc.page_content) metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: text = separator.join(docs) text = text.strip() if text == "": return None else: return text def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. separator_len = self._length_function([separator])[0] docs = [] current_doc: list[str] = [] total = 0 index = 0 for d in splits: _len = lengths[index] if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) # Keep on popping if: # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): total -= self._length_function([current_doc[0]])[0] + ( separator_len if len(current_doc) > 1 else 0 ) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) return docs @classmethod def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: """Text splitter that uses HuggingFace tokenizer to count length.""" try: from transformers import PreTrainedTokenizerBase # type: ignore if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( "Could not import transformers python package. Please install it with `pip install transformers`." ) return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs) def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: chunk_overlap: int tokens_per_chunk: int decode: Callable[[list[int]], str] encode: Callable[[str], list[int]] def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: """Split incoming text and return chunks using tokenizer.""" splits: list[str] = [] input_ids = tokenizer.encode(text) start_idx = 0 cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] while start_idx < len(input_ids): splits.append(tokenizer.decode(chunk_ids)) start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] return splits class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( self, encoding_name: str = "gpt2", model_name: Optional[str] = None, allowed_special: Union[Literal["all"], Set[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) try: import tiktoken except ImportError: raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for TokenTextSplitter. " "Please install it with `pip install tiktoken`." ) if model_name is not None: enc = tiktoken.encoding_for_model(model_name) else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc self._allowed_special = allowed_special self._disallowed_special = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: return self._tokenizer.encode( _text, allowed_special=self._allowed_special, disallowed_special=self._disallowed_special, ) tokenizer = Tokenizer( chunk_overlap=self._chunk_overlap, tokens_per_chunk=self._chunk_size, decode=self._tokenizer.decode, encode=_encode, ) return split_text_on_tokens(text=text, tokenizer=tokenizer) class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively look at characters. Recursively tries to split by different characters to find one that works. """ def __init__( self, separators: Optional[list[str]] = None, keep_separator: bool = True, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] def _split_text(self, text: str, separators: list[str]) -> list[str]: final_chunks = [] separator = separators[-1] new_separators = [] for i, _s in enumerate(separators): if _s == "": separator = _s break if re.search(_s, text): separator = _s new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) for s, s_len in zip(splits, s_lens): if s_len < self._chunk_size: _good_splits.append(s) _good_splits_lengths.append(s_len) else: if _good_splits: merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] _good_splits_lengths = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) return final_chunks def split_text(self, text: str) -> list[str]: return self._split_text(text, self._separators)