| 
									
										
										
										
											2023-05-16 12:57:25 +08:00
										 |  |  | """Functionality for splitting text.""" | 
					
						
							|  |  |  | from __future__ import annotations | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  | from typing import Any, List, Optional, cast | 
					
						
							| 
									
										
										
										
											2023-05-16 12:57:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  | from core.model_manager import ModelInstance | 
					
						
							|  |  |  | from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  | from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, | 
					
						
							|  |  |  |                                      TokenTextSplitter, Type, Union) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-16 12:57:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  | class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |         This class is used to implement from_gpt2_encoder, to prevent using of tiktoken | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  |     def from_encoder( | 
					
						
							|  |  |  |             cls: Type[TS], | 
					
						
							|  |  |  |             embedding_model_instance: Optional[ModelInstance], | 
					
						
							|  |  |  |             allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), | 
					
						
							|  |  |  |             disallowed_special: Union[Literal["all"], Collection[str]] = "all", | 
					
						
							|  |  |  |             **kwargs: Any, | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  |     ): | 
					
						
							|  |  |  |         def _token_encoder(text: str) -> int: | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  |             if embedding_model_instance: | 
					
						
							|  |  |  |                 embedding_model_type_instance = embedding_model_instance.model_type_instance | 
					
						
							|  |  |  |                 embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | 
					
						
							|  |  |  |                 return embedding_model_type_instance.get_num_tokens( | 
					
						
							|  |  |  |                     model=embedding_model_instance.model, | 
					
						
							|  |  |  |                     credentials=embedding_model_instance.credentials, | 
					
						
							|  |  |  |                     texts=[text] | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 return GPT2Tokenizer.get_num_tokens(text) | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if issubclass(cls, TokenTextSplitter): | 
					
						
							|  |  |  |             extra_kwargs = { | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  |                 "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  |                 "allowed_special": allowed_special, | 
					
						
							|  |  |  |                 "disallowed_special": disallowed_special, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             kwargs = {**kwargs, **extra_kwargs} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return cls(length_function=_token_encoder, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-03 13:02:56 +08:00
										 |  |  | class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): | 
					
						
							| 
									
										
										
										
											2023-05-16 12:57:25 +08:00
										 |  |  |     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): | 
					
						
							|  |  |  |         """Create a new TextSplitter.""" | 
					
						
							|  |  |  |         super().__init__(**kwargs) | 
					
						
							|  |  |  |         self._fixed_separator = fixed_separator | 
					
						
							|  |  |  |         self._separators = separators or ["\n\n", "\n", " ", ""] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def split_text(self, text: str) -> List[str]: | 
					
						
							|  |  |  |         """Split incoming text and return chunks.""" | 
					
						
							|  |  |  |         if self._fixed_separator: | 
					
						
							|  |  |  |             chunks = text.split(self._fixed_separator) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             chunks = list(text) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         final_chunks = [] | 
					
						
							|  |  |  |         for chunk in chunks: | 
					
						
							|  |  |  |             if self._length_function(chunk) > self._chunk_size: | 
					
						
							|  |  |  |                 final_chunks.extend(self.recursive_split_text(chunk)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 final_chunks.append(chunk) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return final_chunks | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def recursive_split_text(self, text: str) -> List[str]: | 
					
						
							|  |  |  |         """Split incoming text and return chunks.""" | 
					
						
							|  |  |  |         final_chunks = [] | 
					
						
							|  |  |  |         # Get appropriate separator to use | 
					
						
							|  |  |  |         separator = self._separators[-1] | 
					
						
							|  |  |  |         for _s in self._separators: | 
					
						
							|  |  |  |             if _s == "": | 
					
						
							|  |  |  |                 separator = _s | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |             if _s in text: | 
					
						
							|  |  |  |                 separator = _s | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |         # Now that we have the separator, split the text | 
					
						
							|  |  |  |         if separator: | 
					
						
							|  |  |  |             splits = text.split(separator) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             splits = list(text) | 
					
						
							|  |  |  |         # Now go merging things, recursively splitting longer texts. | 
					
						
							|  |  |  |         _good_splits = [] | 
					
						
							|  |  |  |         for s in splits: | 
					
						
							|  |  |  |             if self._length_function(s) < self._chunk_size: | 
					
						
							|  |  |  |                 _good_splits.append(s) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 if _good_splits: | 
					
						
							|  |  |  |                     merged_text = self._merge_splits(_good_splits, separator) | 
					
						
							|  |  |  |                     final_chunks.extend(merged_text) | 
					
						
							|  |  |  |                     _good_splits = [] | 
					
						
							|  |  |  |                 other_info = self.recursive_split_text(s) | 
					
						
							|  |  |  |                 final_chunks.extend(other_info) | 
					
						
							|  |  |  |         if _good_splits: | 
					
						
							|  |  |  |             merged_text = self._merge_splits(_good_splits, separator) | 
					
						
							|  |  |  |             final_chunks.extend(merged_text) | 
					
						
							| 
									
										
										
										
											2024-01-12 18:45:34 +08:00
										 |  |  |         return final_chunks |