from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")

def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 != 0:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if (s not in {"", "\n"})]

class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the lengths of a batch of chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata or {})
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function([separator])[0]

        docs = []
        current_doc: list[str] = []
        total = 0
        index = 0
        for d in splits:
            _len = lengths[index]
            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - the running total is still larger than the chunk overlap,
                    # - or adding the next split would still overflow the chunk
                    #   size while pieces remain
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                    ):
                        total -= self._length_function([current_doc[0]])[0] + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
            index += 1
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase  # type: ignore

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            )
        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Splitting text by looking at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        _good_splits_lengths = []  # cache the lengths of the splits
        if splits:
            _good_splits_lengths.extend(self._length_function(splits))
        return self._merge_splits(splits, _separator, _good_splits_lengths)

class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on (e.g., "#", "##"),
        # order them by length so longer prefixes are matched first
        self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += "  \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space.
                    # Both are valid conditions that sep is being used as a header.
                    len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while header_stack and header_stack[-1]["level"] >= current_header_level:
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})

        # lines_with_metadata has each line with associated header metadata;
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
            ]


# In newer Python versions (3.10+) this should be
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits

class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        separator = separators[-1]
        new_separators = []

        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        for s, s_len in zip(splits, s_lens):
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                    _good_splits_lengths = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)

        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
            final_chunks.extend(merged_text)

        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
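
# Usage sketch for RecursiveCharacterTextSplitter: separators are tried in
# order ("\n\n", then "\n", then " ", then ""), recursing with finer-grained
# separators only for pieces that still exceed chunk_size.
#
#   >>> splitter = RecursiveCharacterTextSplitter(chunk_size=30, chunk_overlap=0)
#   >>> chunks = splitter.split_text(document_text)  # `document_text` assumed defined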