from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from enum import Enum
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(
        text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]


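# Illustrative note: with keep_separator=True the capture group keeps each
# delimiter attached to the piece that follows it, e.g.
#
#     _split_text_with_regex("a. b. c", r"\. ", keep_separator=True)
#     # -> ["a", ". b", ". c"]

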
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
            self,
            chunk_size: int = 4000,
            chunk_overlap: int = 200,
            length_function: Callable[[str], int] = len,
            keep_separator: bool = False,
            add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
            self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

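    # Usage sketch (illustrative; assumes a concrete subclass such as the
    # CharacterTextSplitter defined below):
    #
    #     splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0,
    #                                      add_start_index=True)
    #     docs = splitter.create_documents(["some long text ..."],
    #                                      metadatas=[{"source": "a.txt"}])
    #     # each Document carries the source metadata plus a "start_index"
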
    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                    total + _len + (separator_len if len(current_doc) > 0 else 0)
                    > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep popping pieces off the front while:
                    # - the retained total is still larger than the chunk overlap, or
                    # - adding the next piece would still overflow the chunk size
                    #   (and there is anything left to pop)
                    while total > self._chunk_overlap or (
                            total + _len + (separator_len if len(current_doc) > 0 else 0)
                            > self._chunk_size
                            and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

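    # Worked example (illustrative, with length_function=len): for a splitter
    # with chunk_size=10 and chunk_overlap=3,
    #
    #     self._merge_splits(["foo", "bar", "baz", "qux"], " ")
    #     # -> ["foo bar", "bar baz", "baz qux"]
    #
    # where the last piece of each chunk is carried over to honour the
    # 3-character overlap.
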
    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses a HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

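    # Usage sketch (illustrative; assumes the `transformers` package and a
    # concrete subclass; "gpt2" is just an example checkpoint):
    #
    #     from transformers import GPT2TokenizerFast
    #     tok = GPT2TokenizerFast.from_pretrained("gpt2")
    #     splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    #         tok, chunk_size=100, chunk_overlap=0
    #     )
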
    @classmethod
    def from_tiktoken_encoder(
            cls: type[TS],
            encoding_name: str = "gpt2",
            model_name: Optional[str] = None,
            allowed_special: Union[Literal["all"], Set[str]] = set(),
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
            **kwargs: Any,
    ) -> TS:
        """Text splitter that uses the tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to count tokens with the tiktoken encoder. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

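    # Usage sketch (illustrative; assumes the `tiktoken` package is installed):
    #
    #     splitter = TokenTextSplitter.from_tiktoken_encoder(
    #         encoding_name="cl100k_base", chunk_size=512, chunk_overlap=64
    #     )
    #     # since cls is TokenTextSplitter here, the encoding kwargs are
    #     # forwarded so that splitting and length counting share one encoder
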
    def transform_documents(
            self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
            self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)


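# Usage sketch (illustrative):
#
#     splitter = CharacterTextSplitter(separator="\n\n", chunk_size=100,
#                                      chunk_overlap=20)
#     chunks = splitter.split_text(document_text)  # document_text: str
#     # -> list[str]: paragraphs merged greedily up to ~100 characters each

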
class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
            self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on (e.g., "#", "##"),
        # order by length so that longer prefixes are checked first
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: lines of text with associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if (
                    aggregated_chunks
                    and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += "  \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[Document]:
        """Split a markdown document.

        Args:
            text: contents of the markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                        # Header with no text OR header followed by a space;
                        # both are valid signs that sep is being used as a header
                        len(stripped_line) == len(sep)
                        or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")

                        # Pop out headers of lower or same level from the stack
                        while (
                                header_stack
                                and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep):].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]

                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata;
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]


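# Usage sketch (illustrative):
#
#     md_splitter = MarkdownHeaderTextSplitter(
#         headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
#     )
#     docs = md_splitter.split_text("# Intro\n\nHello\n\n## Detail\n\nWorld")
#     # -> two Documents: "Hello" with {"Header 1": "Intro"}, and "World"
#     #    with {"Header 1": "Intro", "Header 2": "Detail"}

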
# On newer Python versions (3.10+) this could be
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits


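# Worked example (illustrative): with tokens_per_chunk=4 and chunk_overlap=1,
# token ids [0..9] are windowed as [0,1,2,3], [3,4,5,6], [6,7,8,9], [9];
# each step advances by tokens_per_chunk - chunk_overlap = 3 ids.

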
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using a model tokenizer."""

    def __init__(
            self,
            encoding_name: str = "gpt2",
            model_name: Optional[str] = None,
            allowed_special: Union[Literal["all"], Set[str]] = set(),
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
            **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)


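# Usage sketch (illustrative; assumes `tiktoken` is installed):
#
#     splitter = TokenTextSplitter(encoding_name="cl100k_base",
#                                  chunk_size=512, chunk_overlap=64)
#     chunks = splitter.split_text(document_text)
#     # each chunk decodes from a window of at most 512 token ids

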
class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    JS = "js"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
            self,
            separators: Optional[list[str]] = None,
            keep_separator: bool = True,
            **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)

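    # Usage sketch (illustrative): with the default separators, paragraphs are
    # preferred, then lines, then words, then single characters:
    #
    #     splitter = RecursiveCharacterTextSplitter(chunk_size=100,
    #                                               chunk_overlap=20)
    #     chunks = splitter.split_text(document_text)
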
    @classmethod
    def from_language(
            cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, **kwargs)

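    # Usage sketch (illustrative):
    #
    #     py_splitter = RecursiveCharacterTextSplitter.from_language(
    #         Language.PYTHON, chunk_size=256, chunk_overlap=0
    #     )
    #     chunks = py_splitter.split_text(python_source)  # python_source: str
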
    @staticmethod
    def get_separators_for_language(language: Language) -> list[str]:
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n",
                "\n-+\n",
                "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings
                # (levels 1-6, when preceded by a newline)
                "\n#{1,6} ",
                # Note the alternative (setext) syntax for headings, e.g.
                #   Heading level 2
                #   ---------------
                # is not handled here
                # End of code block
                "```\n",
                # Horizontal lines (three or more of ***, ---, or ___
                # on their own line)
                "\n\\*\\*\\*+\n",
                "\n---+\n",
                "\n___+\n",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along LaTeX sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body",
                "<div",
                "<p",
                "<br",
                "<li",
                "<h1",
                "<h2",
                "<h3",
                "<h4",
                "<h5",
                "<h6",
                "<span",
                "<table",
                "<tr",
                "<td",
                "<th",
                "<ul",
                "<ol",
                "<header",
                "<footer",
                "<nav",
                # Head-section tags
                "<head",
                "<style",
                "<script",
                "<meta",
                "<title",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ",
                "\nusing ",
                # Split along contract definitions
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                # Split along method definitions
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )