Mirror of https://github.com/allenai/olmocr.git, synced 2025-06-27 04:00:02 +00:00
fixes missing OSS code for Issue #36
parent d4b902cea2
commit bd08fdb476
olmocr/eval/dolma_refine/aligners.py (Normal file, +69 lines)
@@ -0,0 +1,69 @@
from typing import Type

from sequence_align.pairwise import hirschberg, needleman_wunsch

from .registry import BaseRegistry


class AlignerRegistry(BaseRegistry[Type["BaseAligner"]]):
    """A registry for aligners."""


class BaseAligner:
    def __init__(self, *args, **kwargs):
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        raise NotImplementedError()


@AlignerRegistry.add("hirschberg")
class HirschbergAligner(BaseAligner):
    def __init__(
        self,
        match_score: float = 1.0,
        mismatch_score: float = -1.0,
        indel_score: float = -1.0,
        gap_token: str = "▓",
    ):
        self.match_score = match_score
        self.mismatch_score = mismatch_score
        self.indel_score = indel_score
        self.gap_token = gap_token
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        return hirschberg(
            gold,
            pred,
            match_score=self.match_score,
            mismatch_score=self.mismatch_score,
            indel_score=self.indel_score,
            gap=self.gap_token,
        )


@AlignerRegistry.add("needleman-wunsch")
class NeedlemanWunschAligner(BaseAligner):
    def __init__(
        self,
        match_score: float = 1.0,
        mismatch_score: float = -1.0,
        indel_score: float = -1.0,
        gap_token: str = "▓",
    ):
        self.match_score = match_score
        self.mismatch_score = mismatch_score
        self.indel_score = indel_score
        self.gap_token = gap_token
        super().__init__()

    def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
        return needleman_wunsch(
            gold,
            pred,
            match_score=self.match_score,
            mismatch_score=self.mismatch_score,
            indel_score=self.indel_score,
            gap=self.gap_token,
        )
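For reference, a minimal usage sketch of the aligner API above; the token lists are hypothetical sample data, and the sequence_align dependency (added to pyproject.toml in this commit) must be installed:

from olmocr.eval.dolma_refine.aligners import HirschbergAligner

aligner = HirschbergAligner()
gold = ["the", " ", "quick", " ", "fox"]
pred = ["the", " ", "fox"]
aligned_gold, aligned_pred = aligner.align(gold, pred)
# Both outputs have the same length; "▓" fills positions where one side
# has no counterpart, e.g. aligned_pred == ["the", " ", "▓", "▓", "fox"]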
olmocr/eval/dolma_refine/metrics.py (Normal file, +237 lines)
@@ -0,0 +1,237 @@
import bisect
from typing import Type

import regex as re
from tqdm import tqdm

from .aligners import AlignerRegistry, BaseAligner
from .registry import BaseRegistry
from .segmenters import BaseSegmenter, SegmenterRegistry


class TextMetricRegistry(BaseRegistry[Type["BaseTextMetric"]]):
    """A registry for text metrics."""


class BaseTextMetric:
    def __init__(self, *args, **kwargs):
        super().__init__()

    def compute(self, gold: str, pred: str) -> float:
        raise NotImplementedError()

    def batch_compute(self, golds: list[str], preds: list[str]) -> list[float]:
        it = tqdm(
            zip(golds, preds),
            total=min(len(golds), len(preds)),
            desc=type(self).__name__,
            unit="samples",
            unit_scale=True,
        )
        return [self.compute(gold, pred) for gold, pred in it]


class BaseTextAlignMetric(BaseTextMetric):
    def __init__(
        self,
        segmenter: str | BaseSegmenter,
        aligner: str | BaseAligner = "hirschberg",
        aligner_kwargs: dict = {},
        segmenter_kwargs: dict = {},
        gap_token: str = "▓",
        *args,
        **kwargs,
    ):
        if isinstance(segmenter, str):
            self.segmenter = SegmenterRegistry.get(segmenter)(segmenter, **segmenter_kwargs)
        else:
            self.segmenter = segmenter

        if isinstance(aligner, str):
            self.aligner = AlignerRegistry.get(aligner)(**aligner_kwargs)
        else:
            self.aligner = aligner

        self.gap_token = gap_token

    def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
        return [(seq_a_tokens, seq_b_tokens)]

    def align(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> tuple[list[str], list[str]]:
        return self.aligner.align(seq_a_tokens, seq_b_tokens)

    def tokenize(self, text: str) -> list[str]:
        # split on punctuation runs and whitespace runs, keeping the separators
        return [w for w in re.split(r"(\p{P}+|\s+)", text) if w]

    def compute(self, gold: str, pred: str) -> float:
        raise NotImplementedError()


@TextMetricRegistry.add("document_edit_similarity")
class DocumentEditSimilarity(BaseTextAlignMetric):
    def _score_aligned(self, aligned_gold_tokens: list[str], aligned_pred_tokens: list[str]) -> float:
        insertions = deletions = matches = substitutions = 0.0
        for gold_symbol, pred_symbol in zip(aligned_gold_tokens, aligned_pred_tokens):
            if gold_symbol == self.gap_token:
                insertions += 1
            elif pred_symbol == self.gap_token:
                deletions += 1
            elif gold_symbol == pred_symbol:
                matches += 1
            else:
                substitutions += 1

        if total := insertions + deletions + matches + substitutions:
            return matches / total
        return 0.0

    def compute(self, gold: str, pred: str) -> float:
        gold_tokens = self.tokenize(gold)
        pred_tokens = self.tokenize(pred)
        aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)
        return self._score_aligned(aligned_gold_tokens, aligned_pred_tokens)


def find_align_gaps(aligned_text: list[str], gap_token: str = "▓", gap_threshold: int = 3) -> list[int]:
    """Find positions where a run of at least `gap_threshold` consecutive gap tokens ends."""
    consecutive_gaps_counter = 0
    above_threshold_locs: list[int] = []

    for aligned_pos, symbol in enumerate(aligned_text):
        if symbol == gap_token:
            consecutive_gaps_counter += 1
        else:
            consecutive_gaps_counter = 0

        if consecutive_gaps_counter >= gap_threshold:
            above_threshold_locs.append(aligned_pos)
            consecutive_gaps_counter = 0

    return above_threshold_locs


def make_unaligned_text(tokens: list[str], gap_token: str = "▓") -> str:
    return "".join(symbol for symbol in tokens if symbol != gap_token)


def find_sentences(
    tokens: list[str],
    sentences: list[str],
    gap_token: str = "▓",
):
    """Map each sentence back to a (start, end) token span in the aligned sequence."""
    original_text = ""
    original: list[int] = []
    original_to_aligned: list[int] = []

    # original[i] is the end offset of the i-th non-gap token in the
    # reconstructed text; original_to_aligned[i] is its aligned position
    for i, token in enumerate(tokens):
        if token != gap_token:
            original_text += token
            original.append(len(original_text))
            original_to_aligned.append(i)

    matches: list[tuple[int, int]] = []
    for sentence in sentences:
        start_pos = original_text.find(sentence)
        if start_pos < 0:
            continue

        end_pos = start_pos + len(sentence)
        start_token = original_to_aligned[bisect.bisect_left(original, start_pos)]
        end_token = original_to_aligned[min(bisect.bisect_right(original, end_pos), len(original) - 1)]
        matches.append((start_token, end_token))

    return matches


def merge_spans(spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
    if not spans:
        return []

    # Sort spans based on start position
    sorted_spans = sorted(spans, key=lambda x: x[0])

    merged = [sorted_spans[0]]

    for current in sorted_spans[1:]:
        last = merged[-1]

        # If current span overlaps with last merged span, update the end of last span
        if current[0] <= last[1]:
            merged[-1] = (last[0], max(last[1], current[1]))
        else:
            merged.append(current)

    return merged


def make_sentences_around_gaps(sent_locs: list[tuple[int, int]], gaps_locs: list[int], window: int):
    sent_start_only = [start for start, _ in sent_locs]

    sentences_with_gaps = []

    # collect all sentences that are around the gaps
    for gap in gaps_locs:
        start_idx = bisect.bisect_left(sent_start_only, gap)
        fwd_window = max(0, start_idx - window)
        bwd_window = min(len(sent_locs) - 1, start_idx + window)
        sentences_with_gaps.append((sent_locs[fwd_window][0], sent_locs[bwd_window][-1]))

    # merge overlapping sentences
    sentences_with_gaps = merge_spans(sentences_with_gaps)

    return sentences_with_gaps


@TextMetricRegistry.add("paragraph_edit_similarity")
class ParagraphEditSimilarity(DocumentEditSimilarity):
    def __init__(
        self,
        segmenter: str | BaseSegmenter,
        aligner: str | BaseAligner = "hirschberg",
        aligner_kwargs: dict = {},
        segmenter_kwargs: dict = {},
        gap_token: str = "▓",
        gap_threshold: int = 3,
        sent_window: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__(
            segmenter=segmenter,
            aligner=aligner,
            aligner_kwargs=aligner_kwargs,
            segmenter_kwargs=segmenter_kwargs,
            gap_token=gap_token,
        )
        self.gap_threshold = gap_threshold
        self.sent_window = sent_window

    def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
        all_spans = []

        for seq_tokens in (seq_a_tokens, seq_b_tokens):
            text = make_unaligned_text(tokens=seq_tokens, gap_token=self.gap_token)
            sentences = self.segmenter.segment(text)

            sent_locs = find_sentences(tokens=seq_tokens, sentences=sentences, gap_token=self.gap_token)
            gaps_locs = find_align_gaps(aligned_text=seq_tokens, gap_token=self.gap_token, gap_threshold=self.gap_threshold)

            sentences_with_gaps = make_sentences_around_gaps(
                sent_locs=sent_locs, gaps_locs=gaps_locs, window=self.sent_window
            )
            all_spans.extend(sentences_with_gaps)

        return [(seq_a_tokens[start:end], seq_b_tokens[start:end]) for start, end in merge_spans(all_spans)]

    def compute(self, gold: str, pred: str) -> float:
        gold_tokens = self.tokenize(gold)
        pred_tokens = self.tokenize(pred)
        aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)

        scores = []
        for gold_segment, pred_segment in self.segment(aligned_gold_tokens, aligned_pred_tokens):
            score = self._score_aligned(gold_segment, pred_segment)
            scores.append(score)

        return sum(scores) / len(scores) if scores else 1.0
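For reference, a minimal usage sketch of the metrics above; the gold/pred strings are hypothetical sample data, and spacy plus sequence_align must be installed:

from olmocr.eval.dolma_refine.metrics import DocumentEditSimilarity

metric = DocumentEditSimilarity(segmenter="spacy", aligner="hirschberg")
score = metric.compute(
    gold="The quick brown fox jumps over the lazy dog.",
    pred="The quick brown fox jumps over the dog.",
)
# score is the fraction of aligned token positions that match exactly:
# 1.0 means identical texts, values near 0.0 mean heavy divergence
print(f"document_edit_similarity: {score:.3f}")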
olmocr/eval/dolma_refine/registry.py (Normal file, +122 lines)
@@ -0,0 +1,122 @@
import re
from typing import (
    Callable,
    Dict,
    Generator,
    Generic,
    Literal,
    Optional,
    Tuple,
    Type,
    TypeVar,
    overload,
)

T = TypeVar("T")
R = TypeVar("R")


class BaseRegistry(Generic[T]):
    """A registry for objects."""

    _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
    _registry_storage: Dict[str, Tuple[T, Optional[str]]]

    @classmethod
    def _add_to_registry_of_registries(cls) -> None:
        name = cls.__name__
        if name not in cls._registry_of_registries:
            cls._registry_of_registries[name] = cls

    @classmethod
    def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
        """Yield all registries in the registry of registries."""
        yield from sorted(cls._registry_of_registries.items())

    @classmethod
    def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
        if not hasattr(cls, "_registry_storage"):
            cls._registry_storage = {}
        return cls._registry_storage  # pyright: ignore

    @classmethod
    def items(cls) -> Generator[Tuple[str, T], None, None]:
        """Yield all items in the registry."""
        yield from sorted((n, t) for (n, (t, _)) in cls._get_storage().items())

    @classmethod
    def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
        """Yield all items in the registry with their descriptions."""
        yield from sorted((n, t, d) for (n, (t, d)) in cls._get_storage().items())

    @classmethod
    def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
        """Add a class to the registry."""

        # Add the registry to the registry of registries
        cls._add_to_registry_of_registries()

        def _add(
            inner_self: T,
            inner_name: str = name,
            inner_desc: Optional[str] = desc,
            inner_cls: Type[BaseRegistry] = cls,
        ) -> T:
            """Add a tagger to the registry using tagger_name as the name."""

            existing = inner_cls.get(inner_name, raise_on_missing=False)

            if existing and existing != inner_self:
                if inner_self.__module__ == "__main__":
                    # tolerate re-registration when running as a script
                    return inner_self

                raise ValueError(f"Tagger {inner_name} already exists")
            inner_cls._get_storage()[inner_name] = (inner_self, inner_desc)
            return inner_self

        return _add  # type: ignore

    @classmethod
    def remove(cls, name: str) -> bool:
        """Remove a tagger from the registry."""
        if name in cls._get_storage():
            cls._get_storage().pop(name)
            return True
        return False

    @classmethod
    def has(cls, name: str) -> bool:
        """Check if a tagger exists in the registry."""
        return name in cls._get_storage()

    @overload
    @classmethod
    def get(cls, name: str) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...

    @classmethod
    def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
        """Get a tagger from the registry; raise ValueError if it doesn't exist."""

        # registered names are treated as regex patterns and matched against `name`
        matches = [registered for registered in cls._get_storage() if re.match(registered, name)]

        if len(matches) > 1:
            raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")

        elif len(matches) == 0:
            if raise_on_missing:
                tagger_names = ", ".join([tn for tn, _ in cls.items()])
                raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
            return None

        else:
            name = matches[0]
            t, _ = cls._get_storage()[name]
            return t
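For reference, a minimal sketch of how BaseRegistry is meant to be subclassed and used; the ToolRegistry, BaseTool, and Reverser names are hypothetical, for illustration only:

from typing import Type

from olmocr.eval.dolma_refine.registry import BaseRegistry


class ToolRegistry(BaseRegistry[Type["BaseTool"]]):
    """A registry for hypothetical tools."""


class BaseTool:
    def run(self, text: str) -> str:
        raise NotImplementedError()


@ToolRegistry.add("reverser", desc="reverses its input")
class Reverser(BaseTool):
    def run(self, text: str) -> str:
        return text[::-1]


tool_cls = ToolRegistry.get("reverser")  # look up the class by name
print(tool_cls().run("hello"))           # -> "olleh"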
olmocr/eval/dolma_refine/segmenters.py (Normal file, +32 lines)
@@ -0,0 +1,32 @@
from typing import Type

from spacy.lang.en import English

from .registry import BaseRegistry


class SegmenterRegistry(BaseRegistry[Type["BaseSegmenter"]]):
    """A registry for segmenters."""


class BaseSegmenter:
    def __init__(self, segmenter_name_or_path: str, *args, **kwargs):
        super().__init__()

    def segment(self, text: str) -> list[str]:
        raise NotImplementedError()


@SegmenterRegistry.add("spacy")
class SpacySegmenter(BaseSegmenter):
    def __init__(self, segmenter_name_or_path: str, *args, **kwargs):
        assert segmenter_name_or_path == "spacy", "Only 'spacy' segmenter is supported"
        # blank English pipeline with a rule-based sentencizer
        self.nlp = English()
        self.nlp.add_pipe("sentencizer")

    def segment(self, text: str) -> list[str]:
        return [sent.text_with_ws for sent in self.nlp(text).sents]
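For reference, a minimal usage sketch of the segmenter above; the input string is hypothetical sample data, and spacy must be installed:

from olmocr.eval.dolma_refine.segmenters import SpacySegmenter

segmenter = SpacySegmenter("spacy")
print(segmenter.segment("This is one sentence. Here is another."))
# -> ["This is one sentence. ", "Here is another."]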
@@ -1,8 +1,4 @@
 # This script will build a set of scores for the accuracy of a given pdf conversion tactic against a gold dataset
-#
-# You might need to pip install git+https://github.com/allenai/refine.git@soldni/eval-m
-# in order to use some of the existing aligner scoring that was developed as part
-# of the refiner pipeline
 import argparse
 import hashlib
 import json
@@ -17,9 +13,9 @@ from typing import Dict, List, Optional

 import boto3
 import zstandard
-from dolma_refine.evaluate.aligners import HirschbergAligner
-from dolma_refine.evaluate.metrics import DocumentEditSimilarity
-from dolma_refine.evaluate.segmenters import SpacySegmenter
+from .dolma_refine.aligners import HirschbergAligner
+from .dolma_refine.metrics import DocumentEditSimilarity
+from .dolma_refine.segmenters import SpacySegmenter
 from smart_open import register_compressor, smart_open
 from tqdm import tqdm
@@ -39,6 +39,7 @@ dependencies = [
     "transformers>=4.46.2",
     "fuzzysearch",
     "rapidfuzz",
+    "sequence_align",
     "beaker-py",
 ]
 license = {file = "LICENSE"}
@@ -72,7 +73,8 @@ dev = [
     "necessary",
     "peft",
     "datasets",
-    "omegaconf"
+    "omegaconf",
+    "spacy",
 ]