| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  | import random | 
					
						
							| 
									
										
										
										
											2024-03-19 12:26:04 +08:00
										 |  |  |  | from collections import Counter | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  | from rag.utils import num_tokens_from_string | 
					
						
							|  |  |  |  | from . import huqie | 
					
						
							|  |  |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  | import copy | 
					
						
							| 
									
										
										
										
											2024-01-30 18:28:09 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-19 18:02:53 +08:00
										 |  |  |  | all_codecs = [ | 
					
						
							|  |  |  |  |     'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', | 
					
						
							|  |  |  |  |     'cp037', 'cp273', 'cp424', 'cp437', | 
					
						
							|  |  |  |  |     'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857', | 
					
						
							|  |  |  |  |     'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869', | 
					
						
							|  |  |  |  |     'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125', | 
					
						
							|  |  |  |  |     'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256', | 
					
						
							|  |  |  |  |     'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr', | 
					
						
							|  |  |  |  |     'gb2312', 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2', | 
					
						
							|  |  |  |  |     'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1', | 
					
						
							|  |  |  |  |     'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7', | 
					
						
							|  |  |  |  |     'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13', | 
					
						
							|  |  |  |  |     'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u', | 
					
						
							|  |  |  |  |     'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman', | 
					
						
							|  |  |  |  |     'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213', | 
					
						
							|  |  |  |  |     'utf_32', 'utf_32_be', 'utf_32_le''utf_16_be', 'utf_16_le', 'utf_7' | 
					
						
							|  |  |  |  | ] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def find_codec(blob): | 
					
						
							|  |  |  |  |     global all_codecs | 
					
						
							|  |  |  |  |     for c in all_codecs: | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             blob.decode(c) | 
					
						
							|  |  |  |  |             return c | 
					
						
							|  |  |  |  |         except Exception as e: | 
					
						
							|  |  |  |  |             pass | 
					
						
							|  |  |  |  |     return "utf-8" | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | 
 | 
					
						
# Families of heading / bullet numbering styles.
#
# Each inner list is one family. bullets_category() counts, per family, how
# many sections start with any of its patterns and picks the family with the
# most hits. Within the chosen family, title_frequency() and
# hierarchical_merge() try the patterns in list order and use the index of the
# first match as the heading level, so the order of the patterns inside a
# family is part of the behaviour.
#
# NOTE(review): several character classes contain the same ASCII character
# twice (e.g. "[\((]"); presumably one of each pair was a full-width variant
# that got normalised somewhere - confirm against the upstream file.
BULLET_PATTERN = [[
    # Family 0: "第…编/部分", "第…章", "第…节", "第…条" with Chinese or Arabic
    # numerals, then a parenthesised Chinese numeral.
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\((][零一二三四五六七八九十百]+[\))]",
], [
    # Family 1: Arabic-numeral chapters/sections followed by dotted decimal
    # numbering of increasing depth.
    # NOTE(review): "{,2}" also accepts zero digits, so e.g. a leading "."
    # matches the third pattern; and because patterns are tried in order, a
    # line like "1.2 foo" already matches the third pattern - verify that the
    # deeper dotted patterns are reachable as intended.
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    # Family 2: "第…章", "第…节", then a bare Chinese numeral followed by a
    # space or "、", then parenthesised Chinese / Arabic numerals.
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\((][零一二三四五六七八九十百]+[\))]",
    r"[\((][0-9]{,2}[\))]",
], [
    # Family 3: English "PART …", "Chapter <roman numeral>", "Section N",
    # "Article N".
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
]
]
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-27 14:57:34 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | def random_choices(arr, k): | 
					
						
							|  |  |  |  |     k = min(len(arr), k) | 
					
						
							|  |  |  |  |     return random.choices(arr, k=k) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-27 14:57:34 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-25 13:11:57 +08:00
										 |  |  |  | def not_bullet(line): | 
					
						
							|  |  |  |  |     patt = [ | 
					
						
							|  |  |  |  |         r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}" | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     return any([re.match(r, line) for r in patt]) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | def bullets_category(sections): | 
					
						
							|  |  |  |  |     global BULLET_PATTERN | 
					
						
							|  |  |  |  |     hits = [0] * len(BULLET_PATTERN) | 
					
						
							|  |  |  |  |     for i, pro in enumerate(BULLET_PATTERN): | 
					
						
							|  |  |  |  |         for sec in sections: | 
					
						
							|  |  |  |  |             for p in pro: | 
					
						
							| 
									
										
										
										
											2024-03-25 13:11:57 +08:00
										 |  |  |  |                 if re.match(p, sec) and not not_bullet(sec): | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |                     hits[i] += 1 | 
					
						
							|  |  |  |  |                     break | 
					
						
							|  |  |  |  |     maxium = 0 | 
					
						
							|  |  |  |  |     res = -1 | 
					
						
							|  |  |  |  |     for i, h in enumerate(hits): | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if h <= maxium: | 
					
						
							|  |  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |         res = i | 
					
						
							|  |  |  |  |         maxium = h | 
					
						
							|  |  |  |  |     return res | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def is_english(texts): | 
					
						
							|  |  |  |  |     eng = 0 | 
					
						
							| 
									
										
										
										
											2024-04-07 09:04:32 +08:00
										 |  |  |  |     if not texts: return False | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |     for t in texts: | 
					
						
							|  |  |  |  |         if re.match(r"[a-zA-Z]{2,}", t.strip()): | 
					
						
							|  |  |  |  |             eng += 1 | 
					
						
							|  |  |  |  |     if eng / len(texts) > 0.8: | 
					
						
							|  |  |  |  |         return True | 
					
						
							|  |  |  |  |     return False | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def tokenize(d, t, eng): | 
					
						
							|  |  |  |  |     d["content_with_weight"] = t | 
					
						
							| 
									
										
										
										
											2024-03-19 15:31:47 +08:00
										 |  |  |  |     t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t) | 
					
						
							| 
									
										
										
										
											2024-03-20 16:56:16 +08:00
										 |  |  |  |     d["content_ltks"] = huqie.qie(t) | 
					
						
							|  |  |  |  |     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 19:21:09 +08:00
										 |  |  |  | def tokenize_chunks(chunks, doc, eng, pdf_parser): | 
					
						
							|  |  |  |  |     res = [] | 
					
						
							|  |  |  |  |     # wrap up as es documents | 
					
						
							|  |  |  |  |     for ck in chunks: | 
					
						
							|  |  |  |  |         if len(ck.strip()) == 0:continue | 
					
						
							|  |  |  |  |         print("--", ck) | 
					
						
							|  |  |  |  |         d = copy.deepcopy(doc) | 
					
						
							|  |  |  |  |         if pdf_parser: | 
					
						
							|  |  |  |  |             try: | 
					
						
							|  |  |  |  |                 d["image"], poss = pdf_parser.crop(ck, need_position=True) | 
					
						
							|  |  |  |  |                 add_positions(d, poss) | 
					
						
							|  |  |  |  |                 ck = pdf_parser.remove_tag(ck) | 
					
						
							|  |  |  |  |             except NotImplementedError as e: | 
					
						
							|  |  |  |  |                 pass | 
					
						
							|  |  |  |  |         tokenize(d, ck, eng) | 
					
						
							|  |  |  |  |         res.append(d) | 
					
						
							|  |  |  |  |     return res | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  | def tokenize_table(tbls, doc, eng, batch_size=10): | 
					
						
							|  |  |  |  |     res = [] | 
					
						
							|  |  |  |  |     # add tables | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |     for (img, rows), poss in tbls: | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if not rows: | 
					
						
							|  |  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |         if isinstance(rows, str): | 
					
						
							|  |  |  |  |             d = copy.deepcopy(doc) | 
					
						
							| 
									
										
										
										
											2024-03-20 16:56:16 +08:00
										 |  |  |  |             tokenize(d, rows, eng) | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |             d["content_with_weight"] = rows | 
					
						
							| 
									
										
										
										
											2024-04-07 09:04:32 +08:00
										 |  |  |  |             if img: d["image"] = img | 
					
						
							|  |  |  |  |             if poss: add_positions(d, poss) | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |             res.append(d) | 
					
						
							|  |  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  |         de = "; " if eng else "; " | 
					
						
							|  |  |  |  |         for i in range(0, len(rows), batch_size): | 
					
						
							|  |  |  |  |             d = copy.deepcopy(doc) | 
					
						
							|  |  |  |  |             r = de.join(rows[i:i + batch_size]) | 
					
						
							|  |  |  |  |             tokenize(d, r, eng) | 
					
						
							|  |  |  |  |             d["image"] = img | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |             add_positions(d, poss) | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  |             res.append(d) | 
					
						
							|  |  |  |  |     return res | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  | def add_positions(d, poss): | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |     if not poss: | 
					
						
							|  |  |  |  |         return | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  |     d["page_num_int"] = [] | 
					
						
							|  |  |  |  |     d["position_int"] = [] | 
					
						
							|  |  |  |  |     d["top_int"] = [] | 
					
						
							|  |  |  |  |     for pn, left, right, top, bottom in poss: | 
					
						
							| 
									
										
										
										
											2024-04-10 16:00:48 +08:00
										 |  |  |  |         d["page_num_int"].append(int(pn + 1)) | 
					
						
							|  |  |  |  |         d["top_int"].append(int(top)) | 
					
						
							|  |  |  |  |         d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom))) | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | def remove_contents_table(sections, eng=False): | 
					
						
							|  |  |  |  |     i = 0 | 
					
						
							|  |  |  |  |     while i < len(sections): | 
					
						
							|  |  |  |  |         def get(i): | 
					
						
							|  |  |  |  |             nonlocal sections | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |             return (sections[i] if isinstance(sections[i], | 
					
						
							|  |  |  |  |                     type("")) else sections[i][0]).strip() | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", | 
					
						
							|  |  |  |  |                         re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)): | 
					
						
							|  |  |  |  |             i += 1 | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         sections.pop(i) | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if i >= len(sections): | 
					
						
							|  |  |  |  |             break | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |         prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) | 
					
						
							|  |  |  |  |         while not prefix: | 
					
						
							|  |  |  |  |             sections.pop(i) | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |             if i >= len(sections): | 
					
						
							|  |  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |             prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) | 
					
						
							|  |  |  |  |         sections.pop(i) | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if i >= len(sections) or not prefix: | 
					
						
							|  |  |  |  |             break | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |         for j in range(i, min(i + 128, len(sections))): | 
					
						
							|  |  |  |  |             if not re.match(prefix, get(j)): | 
					
						
							|  |  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |             for _ in range(i, j): | 
					
						
							|  |  |  |  |                 sections.pop(i) | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |             break | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def make_colon_as_title(sections): | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |     if not sections: | 
					
						
							|  |  |  |  |         return [] | 
					
						
							|  |  |  |  |     if isinstance(sections[0], type("")): | 
					
						
							|  |  |  |  |         return sections | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |     i = 0 | 
					
						
							|  |  |  |  |     while i < len(sections): | 
					
						
							|  |  |  |  |         txt, layout = sections[i] | 
					
						
							|  |  |  |  |         i += 1 | 
					
						
							|  |  |  |  |         txt = txt.split("@")[0].strip() | 
					
						
							|  |  |  |  |         if not txt: | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         if txt[-1] not in "::": | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         txt = txt[::-1] | 
					
						
							|  |  |  |  |         arr = re.split(r"([。?!!?;;]| .)", txt) | 
					
						
							|  |  |  |  |         if len(arr) < 2 or len(arr[1]) < 32: | 
					
						
							|  |  |  |  |             continue | 
					
						
							|  |  |  |  |         sections.insert(i - 1, (arr[0][::-1], "title")) | 
					
						
							|  |  |  |  |         i += 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-19 12:26:04 +08:00
										 |  |  |  | def title_frequency(bull, sections): | 
					
						
							|  |  |  |  |     bullets_size = len(BULLET_PATTERN[bull]) | 
					
						
							|  |  |  |  |     levels = [bullets_size+1 for _ in range(len(sections))] | 
					
						
							|  |  |  |  |     if not sections or bull < 0: | 
					
						
							|  |  |  |  |         return bullets_size+1, levels | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     for i, (txt, layout) in enumerate(sections): | 
					
						
							|  |  |  |  |         for j, p in enumerate(BULLET_PATTERN[bull]): | 
					
						
							| 
									
										
										
										
											2024-03-25 13:11:57 +08:00
										 |  |  |  |             if re.match(p, txt.strip()) and not not_bullet(txt): | 
					
						
							| 
									
										
										
										
											2024-03-19 12:26:04 +08:00
										 |  |  |  |                 levels[i] = j | 
					
						
							|  |  |  |  |                 break | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]): | 
					
						
							|  |  |  |  |                 levels[i] = bullets_size | 
					
						
							|  |  |  |  |     most_level = bullets_size+1 | 
					
						
							|  |  |  |  |     for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1): | 
					
						
							|  |  |  |  |         if l <= bullets_size: | 
					
						
							|  |  |  |  |             most_level = l | 
					
						
							|  |  |  |  |             break | 
					
						
							|  |  |  |  |     return most_level, levels | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def not_title(txt): | 
					
						
							|  |  |  |  |     if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt): | 
					
						
							|  |  |  |  |         return False | 
					
						
							|  |  |  |  |     if len(txt.split(" ")) > 12 or (txt.find(" ") < 0 and len(txt) >= 32): | 
					
						
							|  |  |  |  |         return True | 
					
						
							|  |  |  |  |     return re.search(r"[,;,。;!!]", txt) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
def hierarchical_merge(bull, sections, depth):
    """Group sections into chunks that carry their enclosing headings.

    Args:
        bull: index of the BULLET_PATTERN family to use; negative means none.
        sections: list of strings or of ``(text, layout)`` tuples.
        depth: how many of the lowest levels (plain text first, then the
            title level, then the deepest bullet patterns) seed a chunk.

    Returns:
        A list of chunks, each a list of section texts. ``[]`` is returned
        when *sections* is empty, *bull* is negative, or nothing was grouped.
        Otherwise the list starts from an initially empty chunk that absorbs
        leading single-section chunks, so its first element can be empty.
    """
    if not sections or bull < 0:
        return []
    # Normalise plain strings to (text, layout) tuples with an empty layout.
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    # Drop empty texts, texts of at most one character (ignoring anything
    # after "@") and texts that are only digits.
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    # One bucket of section indices per pattern, plus one for layout-detected
    # titles (index bullets_size) and one for everything else
    # (index bullets_size + 1).
    levels = [[] for _ in range(bullets_size + 2)]

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            # for/else: runs only when no bullet pattern matched.
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    # From here on only the texts are needed.
    sections = [t for t, _ in sections]

    # for s in sections: print("--", s)

    def binary_search(arr, target):
        # Return the position in ascending `arr` of the last element smaller
        # than `target`, or -1 when there is none. `target` itself is
        # expected not to occur in `arr` (each section index sits in exactly
        # one bucket); hitting it inside the loop trips the assert.
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    readed = [False] * len(sections)
    # Reverse so that index 0 is the plain-text bucket and the last index is
    # the first (outermost) bullet pattern.
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if readed[j]:
                continue
            readed[j] = True
            # Start a chunk with this section, then collect, from each later
            # bucket, the nearest entry that precedes it.
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                # NOTE(review): this compares a position inside levels[ii]
                # with a section index; possibly levels[ii][jj] was meant -
                # confirm before changing.
                if jj > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
            for ii in cks[-1]:
                readed[ii] = True

    if not cks:
        return cks

    for i in range(len(cks)):
        # Indices were collected from the seed section outwards; reverse so
        # the collected headings come first, and swap indices for texts.
        cks[i] = [sections[j] for j in cks[i][::-1]]
        # NOTE(review): debug output to stdout.
        print("--------------\n", "\n* ".join(cks[i]))

    # Merge consecutive single-section chunks while the running token count
    # stays below 218; multi-section chunks are kept as they are.
    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            # Token count ignores a trailing "@@<digits>…" tag.
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        res.append(ck)
        # Record the chunk at the threshold so the "< 218" test above fails
        # for whatever single-section chunk follows it.
        num.append(218)

    return res
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"): | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |     if not sections: | 
					
						
							|  |  |  |  |         return [] | 
					
						
							|  |  |  |  |     if isinstance(sections[0], type("")): | 
					
						
							|  |  |  |  |         sections = [(s, "") for s in sections] | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |     cks = [""] | 
					
						
							|  |  |  |  |     tk_nums = [0] | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |     def add_chunk(t, pos): | 
					
						
							|  |  |  |  |         nonlocal cks, tk_nums, delimiter | 
					
						
							|  |  |  |  |         tnum = num_tokens_from_string(t) | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if tnum < 8: | 
					
						
							|  |  |  |  |             pos = "" | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |         if tk_nums[-1] > chunk_token_num: | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |             if t.find(pos) < 0: | 
					
						
							|  |  |  |  |                 t += pos | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  |             cks.append(t) | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |             tk_nums.append(tnum) | 
					
						
							|  |  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |             if cks[-1].find(pos) < 0: | 
					
						
							|  |  |  |  |                 t += pos | 
					
						
							| 
									
										
										
										
											2024-03-01 19:48:01 +08:00
										 |  |  |  |             cks[-1] += t | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |             tk_nums[-1] += tnum | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     for sec, pos in sections: | 
					
						
							| 
									
										
										
										
											2024-03-04 17:08:35 +08:00
										 |  |  |  |         add_chunk(sec, pos) | 
					
						
							|  |  |  |  |         continue | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |         s, e = 0, 1 | 
					
						
							|  |  |  |  |         while e < len(sec): | 
					
						
							|  |  |  |  |             if sec[e] in delimiter: | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |                 add_chunk(sec[s: e + 1], pos) | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  |                 s = e + 1 | 
					
						
							|  |  |  |  |                 e = s + 1 | 
					
						
							|  |  |  |  |             else: | 
					
						
							|  |  |  |  |                 e += 1 | 
					
						
							| 
									
										
										
										
											2024-03-06 09:09:16 +08:00
										 |  |  |  |         if s < e: | 
					
						
							|  |  |  |  |             add_chunk(sec[s: e], pos) | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     return cks |