import bisect
from typing import Type

import regex as re
from tqdm import tqdm

from .aligners import AlignerRegistry, BaseAligner
from .registry import BaseRegistry
from .segmenters import BaseSegmenter, SegmenterRegistry


class TextMetricRegistry(BaseRegistry[Type["BaseTextMetric"]]):
    """A registry for text metrics."""


class BaseTextMetric:
    def __init__(self, *args, **kwargs):
        super().__init__()

    def compute(self, gold: str, pred: str) -> float:
        raise NotImplementedError()

    def batch_compute(self, golds: list[str], preds: list[str]) -> list[float]:
        it = tqdm(
            zip(golds, preds),
            total=min(len(golds), len(preds)),
            desc=type(self).__name__,
            unit="samples",
            unit_scale=True,
        )
        return [self.compute(gold, pred) for gold, pred in it]

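# Illustrative sketch (not part of the original module): for any concrete
# subclass, batch_compute simply maps compute() over paired gold/pred lists
# behind a tqdm progress bar, returning one score per (gold, pred) pair, e.g.
#     DocumentEditSimilarity(segmenter=...).batch_compute(["a b"], ["a b"]) -> [1.0]
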
class BaseTextAlignMetric(BaseTextMetric):
    def __init__(
        self,
        segmenter: str | BaseSegmenter,
        aligner: str | BaseAligner = "hirschberg",
        aligner_kwargs: dict = {},
        segmenter_kwargs: dict = {},
        gap_token: str = "▓",
        *args,
        **kwargs,
    ):
        if isinstance(segmenter, str):
            self.segmenter = SegmenterRegistry.get(segmenter)(segmenter, **segmenter_kwargs)
        else:
            self.segmenter = segmenter

        if isinstance(aligner, str):
            self.aligner = AlignerRegistry.get(aligner)(aligner, **aligner_kwargs)
        else:
            self.aligner = aligner

        self.gap_token = gap_token

    def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
        return [(seq_a_tokens, seq_b_tokens)]

    def align(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> tuple[list[str], list[str]]:
        return self.aligner.align(seq_a_tokens, seq_b_tokens)

    def tokenize(self, text: str) -> list[str]:
        return [w for w in re.split(r"(\p{P}+|\s+)", text) if w]

    def compute(self, gold: str, pred: str) -> float:
        raise NotImplementedError()

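# Illustrative sketch (not part of the original module): tokenize() splits on
# runs of punctuation or whitespace while keeping those runs as tokens, so the
# aligner sees words and separators alike, e.g.
#     metric.tokenize("Hi, there") -> ["Hi", ",", " ", "there"]
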
@TextMetricRegistry.add("document_edit_similarity")
class DocumentEditSimilarity(BaseTextAlignMetric):
    def _score_aligned(self, aligned_gold_tokens: list[str], aligned_pred_tokens: list[str]) -> float:
        insertions = deletions = matches = substitutions = 0.0
        for gold_symbol, pred_symbol in zip(aligned_gold_tokens, aligned_pred_tokens):
            if gold_symbol == self.gap_token:
                insertions += 1
            elif pred_symbol == self.gap_token:
                deletions += 1
            elif gold_symbol == pred_symbol:
                matches += 1
            else:
                substitutions += 1

        if total := insertions + deletions + matches + substitutions:
            return matches / total
        return 0.0

    def compute(self, gold: str, pred: str) -> float:
        gold_tokens = self.tokenize(gold)
        pred_tokens = self.tokenize(pred)
        aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)
        return self._score_aligned(aligned_gold_tokens, aligned_pred_tokens)

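# Illustrative usage sketch ("spacy" is a hypothetical segmenter name; any name
# registered in SegmenterRegistry works):
#     metric = TextMetricRegistry.get("document_edit_similarity")(segmenter="spacy")
#     metric.compute("The quick brown fox", "The quick brown fax")
# The score is the fraction of aligned token pairs that match exactly, so
# identical documents score 1.0 and fully disjoint documents score 0.0.
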
def find_align_gaps(aligned_text: list[str], gap_token: str = "▓", gap_threshold: int = 3) -> list[int]:
    consecutive_gaps_counter = 0
    above_threshold_locs: list[int] = []

    for aligned_pos, symbol in enumerate(aligned_text):
        if symbol == gap_token:
            consecutive_gaps_counter += 1
        else:
            consecutive_gaps_counter = 0

        if consecutive_gaps_counter >= gap_threshold:
            above_threshold_locs.append(aligned_pos)
            consecutive_gaps_counter = 0

    return above_threshold_locs

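# Illustrative sketch (not part of the original module): with the defaults
# gap_token="▓" and gap_threshold=3,
#     find_align_gaps(["a", "▓", "▓", "▓", "b"]) -> [3]
# i.e. the aligned position of the third consecutive gap; the counter resets
# after each hit, so a run of six gaps reports two positions.
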
def make_unaligned_text(tokens: list[str], gap_token: str = "▓") -> str:
    return "".join(symbol for symbol in tokens if symbol != gap_token)

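# Illustrative sketch (not part of the original module):
#     make_unaligned_text(["Hi", "▓", "!"]) -> "Hi!"
# i.e. the original text with every alignment gap removed.
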
def find_sentences(
    tokens: list[str],
    sentences: list[str],
    gap_token: str = "▓",
) -> list[tuple[int, int]]:
    matches: list[tuple[int, int]] = []

    original_text = ""
    original: list[int] = []
    original_to_aligned: list[int] = []

    for i, token in enumerate(tokens):
        if token != gap_token:
            original_text += token
            original.append(len(original_text))
            original_to_aligned.append(i)

    for sentence in sentences:
        start_pos = original_text.find(sentence)
        if start_pos < 0:
            continue

        end_pos = start_pos + len(sentence)
        start_token = original_to_aligned[bisect.bisect_left(original, start_pos)]
        end_token = original_to_aligned[min(bisect.bisect_right(original, end_pos), len(original) - 1)]
        matches.append((start_token, end_token))

    return matches

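# Illustrative sketch (not part of the original module): find_sentences maps
# each sentence string back to (start, end) indices in the aligned token
# sequence by rebuilding the unaligned text and bisecting over cumulative
# character offsets, e.g. with tokens ["Hi", "▓", "!"] the sentence "Hi!"
# yields the aligned span (0, 2).
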
def merge_spans(spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
    if not spans:
        return []

    # Sort spans based on start position
    sorted_spans = sorted(spans, key=lambda x: x[0])

    merged = [sorted_spans[0]]

    for current in sorted_spans[1:]:
        last = merged[-1]

        # If current span overlaps with last merged span, update the end of last span
        if current[0] <= last[1]:
            merged[-1] = (last[0], max(last[1], current[1]))
        else:
            merged.append(current)

    return merged

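# Illustrative sketch (not part of the original module): merge_spans collapses
# overlapping (start, end) spans after sorting by start position, e.g.
#     merge_spans([(0, 5), (3, 8), (10, 12)]) -> [(0, 8), (10, 12)]
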
def make_sentences_around_gaps(sent_locs: list[tuple[int, int]], gaps_locs: list[int], window: int) -> list[tuple[int, int]]:
    sent_start_only = [start for start, _ in sent_locs]

    sentences_with_gaps = []

    # collect all sentences within `window` sentences before and after each gap;
    # fwd_window/bwd_window index the earliest and latest sentences kept
    for gap in gaps_locs:
        start_idx = bisect.bisect_left(sent_start_only, gap)
        fwd_window = max(0, start_idx - window)
        bwd_window = min(len(sent_locs) - 1, start_idx + window)
        sentences_with_gaps.append((sent_locs[fwd_window][0], sent_locs[bwd_window][-1]))

    # merge overlapping sentences
    sentences_with_gaps = merge_spans(sentences_with_gaps)

    return sentences_with_gaps

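# Illustrative sketch (not part of the original module): with sentence spans
# [(0, 4), (5, 9), (10, 14)], a gap at aligned position 6, and window=1, the
# window keeps one sentence on either side of the insertion point nearest
# the gap:
#     make_sentences_around_gaps([(0, 4), (5, 9), (10, 14)], [6], 1) -> [(5, 14)]
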
@TextMetricRegistry.add("paragraph_edit_similarity")
class ParagraphEditSimilarity(DocumentEditSimilarity):
    def __init__(
        self,
        segmenter: str | BaseSegmenter,
        aligner: str | BaseAligner = "hirschberg",
        aligner_kwargs: dict = {},
        segmenter_kwargs: dict = {},
        gap_token: str = "▓",
        gap_threshold: int = 3,
        sent_window: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__(
            segmenter=segmenter,
            aligner=aligner,
            aligner_kwargs=aligner_kwargs,
            segmenter_kwargs=segmenter_kwargs,
            gap_token=gap_token,
        )
        self.gap_threshold = gap_threshold
        self.sent_window = sent_window

    def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
        all_spans = []

        for seq_tokens in (seq_a_tokens, seq_b_tokens):
            text = make_unaligned_text(tokens=seq_tokens, gap_token=self.gap_token)
            sentences = self.segmenter.segment(text)

            sent_locs = find_sentences(tokens=seq_tokens, sentences=sentences, gap_token=self.gap_token)
            # use the configured threshold instead of a hardcoded value
            gaps_locs = find_align_gaps(aligned_text=seq_tokens, gap_token=self.gap_token, gap_threshold=self.gap_threshold)

            sentences_with_gaps = make_sentences_around_gaps(sent_locs=sent_locs, gaps_locs=gaps_locs, window=self.sent_window)
            all_spans.extend(sentences_with_gaps)

        return [(seq_a_tokens[start:end], seq_b_tokens[start:end]) for start, end in merge_spans(all_spans)]

    def compute(self, gold: str, pred: str) -> float:
        gold_tokens = self.tokenize(gold)
        pred_tokens = self.tokenize(pred)
        aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)

        scores = []
        for gold_segment, pred_segment in self.segment(aligned_gold_tokens, aligned_pred_tokens):
            score = self._score_aligned(gold_segment, pred_segment)
            scores.append(score)

        return sum(scores) / len(scores) if scores else 1.0

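# Illustrative usage sketch ("spacy" is a hypothetical segmenter name; any name
# registered in SegmenterRegistry works):
#     metric = TextMetricRegistry.get("paragraph_edit_similarity")(segmenter="spacy", sent_window=2)
#     score = metric.compute(gold_text, pred_text)
# compute() aligns the two documents once, rescores only the sentence windows
# around gap runs at or above gap_threshold, and averages those per-segment
# scores; a pair with no such gap runs scores 1.0.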