cleaning up

David S. Batista 2025-06-04 17:36:55 +02:00
parent 20ea23409d
commit 34552f2e5d


@@ -2,42 +2,39 @@
#
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Literal, Tuple
from more_itertools import windowed
from haystack import Document, component, logging
from haystack.components.preprocessors import DocumentSplitter
from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, serialize_callable
with LazyImport("Run 'pip install hanlp'") as hanlp_import:
import hanlp
logger = logging.getLogger(__name__)
# mapping of split_by values to the character used for splitting; 'function' and 'sentence' don't split by character
_CHARACTER_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"}
hanlp_import.check()
chinese_tokenizer_coarse = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
split_sent = hanlp.load(hanlp.pretrained.eos.UD_CTB_EOS_MUL)  # load the Chinese sentence splitter
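# Illustrative note (added for clarity, not part of the original commit): the two tokenizers map a
# string to a list of word-level tokens, e.g. chinese_tokenizer_coarse("今天天气很好") would return
# something like ["今天", "天气", "很", "好"], while split_sent maps a passage to a list of sentence
# strings; chinese_sentence_split below wraps it to also return character offsets.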
@component
class ChineseDocumentspliter(DocumentSplitter):
def __init__(self, *args, particle_size: Literal["coarse", "fine"] = "coarse", **kwargs):
super(ChineseDocumentspliter, self).__init__(*args, **kwargs)
# 'coarse' represents coarse-grained Chinese word segmentation, 'fine' represents fine-grained word
# segmentation; the default is coarse-grained segmentation
self.particle_size = particle_size
# self.chinese_tokenizer_coarse = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
# self.chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
@@ -47,12 +44,12 @@ class chinese_DocumentSpliter(DocumentSplitter):
def _split_by_character(self, doc) -> List[Document]:
split_at = _CHARACTER_SPLIT_BY_MAPPING[self.split_by]
if self.language == "zh" and self.particle_size == "coarse":
units = chinese_tokenizer_coarse(doc.content)
if self.language == "zh" and self.particle_size == "fine":
units = chinese_tokenizer_fine(doc.content)
if self.language == "en":
units = doc.content.split(split_at)
# Add the delimiter back to all units except the last one
for i in range(len(units) - 1):
@@ -67,7 +64,14 @@ class chinese_DocumentSpliter(DocumentSplitter):
)
# define a function to handle Chinese sentence segmentation
@staticmethod
def chinese_sentence_split(text: str) -> list:
"""
Segmentation of Chinese text.
:param text: The Chinese text to be segmented.
:returns: A list of dictionaries, each containing a sentence and its start and end indices.
"""
# split the text into sentences
sentences = split_sent(text)
@@ -77,11 +81,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
for sentence in sentences:
start = text.find(sentence, start)
end = start + len(sentence)
results.append({"sentence": sentence + "\n", "start": start, "end": end})
start = end
return results
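# Illustrative example (added note, not part of the original commit): for
# text = "今天天气很好。我们去公园散步。" the method would return something like
# [{"sentence": "今天天气很好。\n", "start": 0, "end": 7},
#  {"sentence": "我们去公园散步。\n", "start": 7, "end": 15}]
# assuming split_sent segments the text at the two full stops.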
@@ -123,17 +123,17 @@ class chinese_DocumentSpliter(DocumentSplitter):
# chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
for sentence_idx, sentence in enumerate(sentences):
current_chunk.append(sentence)
if language == 'zh' and particle_size == "coarse":
if language == "zh" and particle_size == "coarse":
chunk_word_count += len(chinese_tokenizer_coarse(sentence))
next_sentence_word_count = (
len(chinese_tokenizer_coarse(sentences[sentence_idx + 1]))
if sentence_idx < len(sentences) - 1
else 0
)
if language == 'zh' and particle_size == "fine":
if language == "zh" and particle_size == "fine":
chunk_word_count += len(chinese_tokenizer_fine(sentence))
next_sentence_word_count = (
len(chinese_tokenizer_fine(sentences[sentence_idx + 1])) if sentence_idx < len(sentences) - 1 else 0
)
# Number of words in the current chunk plus the next sentence is larger than the split_length,
@@ -145,25 +145,25 @@ class chinese_DocumentSpliter(DocumentSplitter):
split_start_indices.append(chunk_start_idx)
# Get the number of sentences that overlap with the next chunk
num_sentences_to_keep = ChineseDocumentspliter._number_of_sentences_to_keep(
sentences=current_chunk,
split_length=split_length,
split_overlap=split_overlap,
language=language,
particle_size=particle_size,
)
# Set up information for the new chunk
if num_sentences_to_keep > 0:
# Processed sentences are the ones that are not overlapping with the next chunk
processed_sentences = current_chunk[:-num_sentences_to_keep]
chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences)
chunk_start_idx += len("".join(processed_sentences))
# Next chunk starts with the sentences that were overlapping with the previous chunk
current_chunk = current_chunk[-num_sentences_to_keep:]
chunk_word_count = sum(len(s.split()) for s in current_chunk)
else:
# Here processed_sentences is the same as current_chunk since there is no overlap
chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk)
chunk_start_idx += len("".join(current_chunk))
current_chunk = []
chunk_word_count = 0
@@ -181,18 +181,21 @@ class chinese_DocumentSpliter(DocumentSplitter):
def _split_by_nltk_sentence(self, doc: Document) -> List[Document]:
split_docs = []
if self.language == "zh":
result = ChineseDocumentspliter.chinese_sentence_split(doc.content)
if self.language == "en":
result = self.sentence_splitter.split_sentences(doc.content)  # type: ignore # None check is done in run()
units = [sentence["sentence"] for sentence in result]
if self.respect_sentence_boundary:
text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount(
sentences=units,
split_length=self.split_length,
split_overlap=self.split_overlap,
language=self.language,
particle_size=self.particle_size,
)
else:
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
elements=units,
@@ -224,8 +227,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
splits_start_idxs: List[int] = []
cur_start_idx = 0
cur_page = 1
segments = windowed(elements, n=split_length, step=split_length - split_overlap)
for seg in segments:
current_units = [unit for unit in seg if unit is not None]
@@ -248,8 +250,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
if self.split_by == "page":
num_page_breaks = len(processed_units)
else:
num_page_breaks = sum(processed_unit.count("\f") for processed_unit in processed_units)
cur_page += num_page_breaks
@@ -282,11 +283,10 @@ class chinese_DocumentSpliter(DocumentSplitter):
doc_start_idx = splits_start_idxs[i]
previous_doc = documents[i - 1]
previous_doc_start_idx = splits_start_idxs[i - 1]
self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
for d in documents:
d.content = d.content.replace(" ", "")
return documents
@staticmethod
@@ -301,26 +301,24 @@ class chinese_DocumentSpliter(DocumentSplitter):
:param previous_doc: The Document that was split before the current Document.
:param previous_doc_start_idx: The starting index of the previous Document.
"""
overlapping_range = (current_doc_start_idx - previous_doc_start_idx, len(previous_doc.content))  # type: ignore
if overlapping_range[0] < overlapping_range[1]:
overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
if current_doc.content.startswith(overlapping_str): # type: ignore
# add split overlap information to this Document regarding the previous Document
current_doc.meta["_split_overlap"].append(
{"doc_id": previous_doc.id, "range": overlapping_range})
current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
# add split overlap information to previous Document regarding this Document
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
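# Illustrative note (added for clarity, not part of the original commit): after this step each of the
# two neighbouring Documents carries an entry like {"doc_id": "<id of the other split>", "range": (start, end)}
# in meta["_split_overlap"], where "range" is the character span of the shared text inside the content of
# the Document referenced by "doc_id".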
@staticmethod
def _number_of_sentences_to_keep(
sentences: List[str], split_length: int, split_overlap: int, language: str, particle_size: str
) -> int:
"""
Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`.
@@ -339,10 +337,10 @@ class chinese_DocumentSpliter(DocumentSplitter):
# chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for sent in reversed(sentences[1:]):
if language == 'zh' and particle_size == "coarse":
if language == "zh" and particle_size == "coarse":
num_words += len(chinese_tokenizer_coarse(sent))
# num_words += len(sent.split())
if language == 'zh' and particle_size == "fine":
if language == "zh" and particle_size == "fine":
num_words += len(chinese_tokenizer_fine(sent))
# If the number of words is larger than the split_length then don't add any more sentences
if num_words > split_length:
@@ -350,5 +348,5 @@ class chinese_DocumentSpliter(DocumentSplitter):
num_sentences_to_keep += 1
if num_words > split_overlap:
break
return num_sentences_to_keep
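A minimal usage sketch (illustrative addition, not part of this commit): it assumes hanlp is installed and that the base DocumentSplitter keyword arguments shown here (split_by, split_length, split_overlap, language) are passed through unchanged by the subclass.

from haystack import Document

# Hypothetical example values; 'particle_size' selects the hanlp tokenizer granularity.
splitter = ChineseDocumentspliter(
    split_by="word",
    split_length=50,
    split_overlap=10,
    language="zh",
    particle_size="coarse",
)
docs = [Document(content="月亮是地球唯一的天然卫星。它与地球的平均距离约为384400公里。")]
result = splitter.run(documents=docs)
for d in result["documents"]:
    # each split carries the overlap metadata added by _add_split_overlap_information
    print(d.content, d.meta.get("_split_overlap"))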