mirror of https://github.com/deepset-ai/haystack.git
cleaning up

parent 20ea23409d
commit 34552f2e5d
@@ -2,42 +2,39 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from haystack.components.preprocessors import DocumentSplitter
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
-from haystack.lazy_imports import LazyImport
+from typing import Any, Dict, List, Literal, Tuple
 
 from more_itertools import windowed
 
-from haystack import Document, component, logging
-from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports
-from haystack.core.serialization import default_from_dict, default_to_dict
-from haystack.utils import deserialize_callable, serialize_callable
-with LazyImport("Run 'pip install hanlp'") as hanlp:
-    import hanlp
+from haystack import Document, component, logging
+from haystack.components.preprocessors import DocumentSplitter
+from haystack.lazy_imports import LazyImport
+
+with LazyImport("Run 'pip install hanlp'") as hanlp_import:
+    import hanlp
 
 
 logger = logging.getLogger(__name__)
 
 # mapping of split by character, 'function' and 'sentence' don't split by character
-_CHARACTER_SPLIT_BY_MAPPING = {
-    "page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"}
-chinese_tokenizer_coarse = hanlp.load(
-    hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
+_CHARACTER_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"}
+
+hanlp_import.check()
+
+chinese_tokenizer_coarse = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
 chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
-# load the Chinese sentence splitter
-split_sent = hanlp.load(hanlp.pretrained.eos.UD_CTB_EOS_MUL)
+split_sent = hanlp.load(hanlp.pretrained.eos.UD_CTB_EOS_MUL)  # load the Chinese sentence splitter
 
 
 @component
-class chinese_DocumentSpliter(DocumentSplitter):
+class ChineseDocumentspliter(DocumentSplitter):
     def __init__(self, *args, particle_size: Literal["coarse", "fine"] = "coarse", **kwargs):
-        super(chinese_DocumentSpliter, self).__init__(*args, **kwargs)
+        super(ChineseDocumentspliter, self).__init__(*args, **kwargs)
 
-        # 'coarse' means coarse-grained Chinese word segmentation, 'fine' means fine-grained segmentation; the default is coarse-grained
-        # 'coarse' represents coarse granularity Chinese word segmentation, 'fine' represents fine granularity word segmentation, default is coarse granularity word segmentation
+        # 'coarse' represents coarse granularity Chinese word segmentation, 'fine' represents fine granularity word
+        # segmentation, default is coarse granularity word segmentation
        self.particle_size = particle_size
        # self.chinese_tokenizer_coarse = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
        # self.chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
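Note on the import block above: hanlp is only imported inside the LazyImport guard, and the new hanlp_import.check() call turns a missing dependency into an actionable error before any model is loaded. A minimal standalone sketch of the same pattern, reusing the model name from this hunk (the example sentence and the print are illustrative only, and running it downloads the HanLP model):

    from haystack.lazy_imports import LazyImport

    with LazyImport("Run 'pip install hanlp'") as hanlp_import:
        import hanlp

    hanlp_import.check()  # raises with the "pip install hanlp" hint if the import failed

    # coarse-grained Chinese tokenizer, loaded at module level exactly as in the diff
    tok = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
    print(tok("我们一起去北京看长城。"))  # prints a list of Chinese word tokens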
@@ -47,12 +44,12 @@ class chinese_DocumentSpliter(DocumentSplitter):
 
     def _split_by_character(self, doc) -> List[Document]:
         split_at = _CHARACTER_SPLIT_BY_MAPPING[self.split_by]
-        if self.language == 'zh' and self.particle_size == "coarse":
+        if self.language == "zh" and self.particle_size == "coarse":
             units = chinese_tokenizer_coarse(doc.content)
-        if self.language == 'zh' and self.particle_size == "fine":
+        if self.language == "zh" and self.particle_size == "fine":
             units = chinese_tokenizer_fine(doc.content)
-        if self.language == 'en':
+        if self.language == "en":
             units = doc.content.split(split_at)
         # Add the delimiter back to all units except the last one
         for i in range(len(units) - 1):
@@ -67,7 +64,14 @@ class chinese_DocumentSpliter(DocumentSplitter):
         )
 
-    # define a function to handle Chinese sentence segmentation
-    def chinese_sentence_split(self, text: str) -> list:
+    @staticmethod
+    def chinese_sentence_split(text: str) -> list:
+        """
+        Segmentation of Chinese text.
+
+        :param text: The Chinese text to be segmented.
+        :returns: A list of dictionaries, each containing a sentence and its start and end indices.
+        """
         # split into sentences
         sentences = split_sent(text)
 
@@ -77,11 +81,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
         for sentence in sentences:
             start = text.find(sentence, start)
             end = start + len(sentence)
-            results.append({
-                'sentence': sentence + '\n',
-                'start': start,
-                'end': end
-            })
+            results.append({"sentence": sentence + "\n", "start": start, "end": end})
             start = end
 
         return results
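The one-line results.append(...) above is easier to follow next to the offset bookkeeping around it: each sentence is located in the original text with str.find, and the search resumes at the previous end so identical sentences are not matched twice. A self-contained sketch of that bookkeeping, with a hard-coded sentence list standing in for the HanLP sentence splitter (split_sent):

    def sentence_offsets(text: str, sentences: list) -> list:
        # mirrors the loop in chinese_sentence_split, minus the model call
        results = []
        start = 0
        for sentence in sentences:
            start = text.find(sentence, start)
            end = start + len(sentence)
            results.append({"sentence": sentence + "\n", "start": start, "end": end})
            start = end
        return results

    text = "今天天气很好。我们去公园散步。"
    print(sentence_offsets(text, ["今天天气很好。", "我们去公园散步。"]))
    # [{'sentence': '今天天气很好。\n', 'start': 0, 'end': 7},
    #  {'sentence': '我们去公园散步。\n', 'start': 7, 'end': 15}]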
@@ -123,17 +123,17 @@ class chinese_DocumentSpliter(DocumentSplitter):
         # chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
         for sentence_idx, sentence in enumerate(sentences):
             current_chunk.append(sentence)
-            if language == 'zh' and particle_size == "coarse":
+            if language == "zh" and particle_size == "coarse":
                 chunk_word_count += len(chinese_tokenizer_coarse(sentence))
                 next_sentence_word_count = (
-                    len(chinese_tokenizer_coarse(
-                        sentences[sentence_idx + 1])) if sentence_idx < len(sentences) - 1 else 0
+                    len(chinese_tokenizer_coarse(sentences[sentence_idx + 1]))
+                    if sentence_idx < len(sentences) - 1
+                    else 0
                 )
-            if language == 'zh' and particle_size == "fine":
+            if language == "zh" and particle_size == "fine":
                 chunk_word_count += len(chinese_tokenizer_fine(sentence))
                 next_sentence_word_count = (
-                    len(chinese_tokenizer_fine(
-                        sentences[sentence_idx + 1])) if sentence_idx < len(sentences) - 1 else 0
+                    len(chinese_tokenizer_fine(sentences[sentence_idx + 1])) if sentence_idx < len(sentences) - 1 else 0
                 )
 
             # Number of words in the current chunk plus the next sentence is larger than the split_length,
@@ -145,25 +145,25 @@ class chinese_DocumentSpliter(DocumentSplitter):
                 split_start_indices.append(chunk_start_idx)
 
                 # Get the number of sentences that overlap with the next chunk
-                num_sentences_to_keep = chinese_DocumentSpliter._number_of_sentences_to_keep(
-                    sentences=current_chunk, split_length=split_length, split_overlap=split_overlap, language=language, particle_size=particle_size
+                num_sentences_to_keep = ChineseDocumentspliter._number_of_sentences_to_keep(
+                    sentences=current_chunk,
+                    split_length=split_length,
+                    split_overlap=split_overlap,
+                    language=language,
+                    particle_size=particle_size,
                 )
                 # Set up information for the new chunk
                 if num_sentences_to_keep > 0:
                     # Processed sentences are the ones that are not overlapping with the next chunk
-                    processed_sentences = current_chunk[:-
-                                                        num_sentences_to_keep]
-                    chunk_starting_page_number += sum(sent.count("\f")
-                                                      for sent in processed_sentences)
+                    processed_sentences = current_chunk[:-num_sentences_to_keep]
+                    chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences)
                     chunk_start_idx += len("".join(processed_sentences))
                     # Next chunk starts with the sentences that were overlapping with the previous chunk
                     current_chunk = current_chunk[-num_sentences_to_keep:]
-                    chunk_word_count = sum(len(s.split())
-                                           for s in current_chunk)
+                    chunk_word_count = sum(len(s.split()) for s in current_chunk)
                 else:
                     # Here processed_sentences is the same as current_chunk since there is no overlap
-                    chunk_starting_page_number += sum(sent.count("\f")
-                                                      for sent in current_chunk)
+                    chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk)
                     chunk_start_idx += len("".join(current_chunk))
                     current_chunk = []
                     chunk_word_count = 0
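A toy walk-through of the num_sentences_to_keep > 0 branch reformatted above (the sentence list and the keep-count are invented; in the commit the count comes from _number_of_sentences_to_keep):

    current_chunk = ["第一句。", "第二句。", "第三句。"]  # sentences in the chunk just closed
    num_sentences_to_keep = 1                             # overlap carried into the next chunk
    chunk_start_idx = 0

    processed_sentences = current_chunk[:-num_sentences_to_keep]  # not carried over
    chunk_start_idx += len("".join(processed_sentences))          # next chunk starts after them
    current_chunk = current_chunk[-num_sentences_to_keep:]        # seed of the next chunk

    print(processed_sentences)  # ['第一句。', '第二句。']
    print(chunk_start_idx)      # 8 -> two sentences of four characters each
    print(current_chunk)        # ['第三句。']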
@@ -181,18 +181,21 @@ class chinese_DocumentSpliter(DocumentSplitter):
     def _split_by_nltk_sentence(self, doc: Document) -> List[Document]:
         split_docs = []
 
-        if self.language == 'zh':
-            result = self.chinese_sentence_split(doc.content)
-        if self.language == 'en':
-            result = self.sentence_splitter.split_sentences(
-                doc.content)  # type: ignore # None check is done in run()
+        if self.language == "zh":
+            result = ChineseDocumentspliter.chinese_sentence_split(doc.content)
+        if self.language == "en":
+            result = self.sentence_splitter.split_sentences(doc.content)  # type: ignore # None check is done in run()
 
         units = [sentence["sentence"] for sentence in result]
 
         if self.respect_sentence_boundary:
             text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount(
-                sentences=units, split_length=self.split_length, split_overlap=self.split_overlap, language=self.language,
-                particle_size=self.particle_size)
+                sentences=units,
+                split_length=self.split_length,
+                split_overlap=self.split_overlap,
+                language=self.language,
+                particle_size=self.particle_size,
+            )
         else:
             text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
                 elements=units,
@@ -224,8 +227,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
         splits_start_idxs: List[int] = []
         cur_start_idx = 0
         cur_page = 1
-        segments = windowed(elements, n=split_length,
-                            step=split_length - split_overlap)
+        segments = windowed(elements, n=split_length, step=split_length - split_overlap)
 
         for seg in segments:
             current_units = [unit for unit in seg if unit is not None]
@@ -248,8 +250,7 @@ class chinese_DocumentSpliter(DocumentSplitter):
             if self.split_by == "page":
                 num_page_breaks = len(processed_units)
             else:
-                num_page_breaks = sum(processed_unit.count("\f")
-                                      for processed_unit in processed_units)
+                num_page_breaks = sum(processed_unit.count("\f") for processed_unit in processed_units)
 
             cur_page += num_page_breaks
 
@@ -282,11 +283,10 @@ class chinese_DocumentSpliter(DocumentSplitter):
             doc_start_idx = splits_start_idxs[i]
             previous_doc = documents[i - 1]
             previous_doc_start_idx = splits_start_idxs[i - 1]
-            self._add_split_overlap_information(
-                doc, doc_start_idx, previous_doc, previous_doc_start_idx)
+            self._add_split_overlap_information(doc, doc_start_idx, previous_doc, previous_doc_start_idx)
 
         for d in documents:
-            d.content=d.content.replace(" ","")
+            d.content = d.content.replace(" ", "")
         return documents
 
     @staticmethod
@@ -301,26 +301,24 @@ class chinese_DocumentSpliter(DocumentSplitter):
         :param previous_doc: The Document that was split before the current Document.
         :param previous_doc_start_idx: The starting index of the previous Document.
         """
-        overlapping_range = (current_doc_start_idx - previous_doc_start_idx,
-                             len(previous_doc.content))  # type: ignore
+        overlapping_range = (current_doc_start_idx - previous_doc_start_idx, len(previous_doc.content))  # type: ignore
 
         if overlapping_range[0] < overlapping_range[1]:
-            # type: ignore
-            overlapping_str = previous_doc.content[overlapping_range[0]: overlapping_range[1]]
+            overlapping_str = previous_doc.content[overlapping_range[0] : overlapping_range[1]]
 
             if current_doc.content.startswith(overlapping_str):  # type: ignore
                 # add split overlap information to this Document regarding the previous Document
-                current_doc.meta["_split_overlap"].append(
-                    {"doc_id": previous_doc.id, "range": overlapping_range})
+                current_doc.meta["_split_overlap"].append({"doc_id": previous_doc.id, "range": overlapping_range})
 
                 # add split overlap information to previous Document regarding this Document
-                overlapping_range = (
-                    0, overlapping_range[1] - overlapping_range[0])
-                previous_doc.meta["_split_overlap"].append(
-                    {"doc_id": current_doc.id, "range": overlapping_range})
+                overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
+                previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
 
     @staticmethod
-    def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int, language: str, particle_size: str) -> int:
+    def _number_of_sentences_to_keep(
+        sentences: List[str], split_length: int, split_overlap: int, language: str, particle_size: str
+    ) -> int:
         """
         Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`.
 
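The two _split_overlap ranges recorded above are easier to see with plain strings and concrete numbers (values invented; no Document objects involved):

    previous_content = "ABCDEFGHIJ"  # content of the previous split
    previous_doc_start_idx = 0
    current_doc_start_idx = 7        # the current split starts 7 characters later

    # slice of the previous split that reappears at the start of the current one
    overlapping_range = (current_doc_start_idx - previous_doc_start_idx, len(previous_content))
    overlapping_str = previous_content[overlapping_range[0] : overlapping_range[1]]
    print(overlapping_range, overlapping_str)  # (7, 10) HIJ

    # the same overlap expressed relative to the current split (stored on the previous Document)
    overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
    print(overlapping_range)  # (0, 3)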
@@ -339,10 +337,10 @@ class chinese_DocumentSpliter(DocumentSplitter):
         # chinese_tokenizer_fine = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
         # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
         for sent in reversed(sentences[1:]):
-            if language == 'zh' and particle_size == "coarse":
+            if language == "zh" and particle_size == "coarse":
                 num_words += len(chinese_tokenizer_coarse(sent))
                 # num_words += len(sent.split())
-            if language == 'zh' and particle_size == "fine":
+            if language == "zh" and particle_size == "fine":
                 num_words += len(chinese_tokenizer_fine(sent))
             # If the number of words is larger than the split_length then don't add any more sentences
             if num_words > split_length:
@@ -350,5 +348,5 @@ class chinese_DocumentSpliter(DocumentSplitter):
             num_sentences_to_keep += 1
             if num_words > split_overlap:
                 break
-        return num_sentences_to_keep
+        return num_sentences_to_keep
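End-to-end, the renamed component is used like any other DocumentSplitter. A hedged usage sketch: the class name and particle_size come from this commit, while the import path and the remaining keyword arguments are assumptions based on the attributes the diff references (split_by, split_length, split_overlap, language, respect_sentence_boundary); hanlp must be installed:

    from haystack import Document
    # hypothetical module path for the file changed in this commit
    from chinese_document_spliter import ChineseDocumentspliter

    splitter = ChineseDocumentspliter(
        split_by="word",
        split_length=200,
        split_overlap=20,
        language="zh",
        particle_size="coarse",
    )
    result = splitter.run(documents=[Document(content="这是一个需要被切分的中文长文档。")])
    for doc in result["documents"]:
        print(doc.meta.get("split_id"), doc.content)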