diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py index cdcaaba..e06f7d1 100644 --- a/olmocr/bench/tests.py +++ b/olmocr/bench/tests.py @@ -584,6 +584,8 @@ class TextOrderTest(BasePDFTest): raise ValidationError("Before field cannot be empty") if not self.after.strip(): raise ValidationError("After field cannot be empty") + if self.max_diffs > len(self.before) // 2 or self.max_diffs > len(self.after) // 2: + raise ValidationError("Max diffs is too large for this test, greater than 50% of the search string") def run(self, md_content: str) -> Tuple[bool, str]: md_content = normalize_text(md_content) @@ -856,6 +858,7 @@ class BaselineTest(BasePDFTest): """ max_repeats: int = 30 + check_disallowed_characters: bool = True def run(self, content: str) -> Tuple[bool, str]: if len("".join(c for c in content if c.isalnum()).strip()) == 0: @@ -884,8 +887,9 @@ class BaselineTest(BasePDFTest): r"]", flags=re.UNICODE, ) + matches = pattern.findall(content) - if matches: + if self.check_disallowed_characters and matches: return False, f"Text contains disallowed characters {matches}" return True, "" @@ -983,6 +987,8 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: test = TableTest(**data) elif test_type == TestType.MATH.value: test = MathTest(**data) + elif test_type == TestType.BASELINE.value: + test = BaselineTest(**data) else: raise ValidationError(f"Unknown test type: {test_type}") return (line_number, test)