Test validation

This commit is contained in:
Jake Poznanski 2025-04-02 14:46:07 -07:00
parent 4604b59661
commit 0d94d15341

View File

@@ -584,6 +584,8 @@ class TextOrderTest(BasePDFTest):
             raise ValidationError("Before field cannot be empty")
         if not self.after.strip():
             raise ValidationError("After field cannot be empty")
+        if self.max_diffs > len(self.before) // 2 or self.max_diffs > len(self.after) // 2:
+            raise ValidationError("Max diffs is too large for this test, greater than 50% of the search string")
 
     def run(self, md_content: str) -> Tuple[bool, str]:
         md_content = normalize_text(md_content)
@@ -856,6 +858,7 @@ class BaselineTest(BasePDFTest):
     """
 
     max_repeats: int = 30
+    check_disallowed_characters: bool = True
 
     def run(self, content: str) -> Tuple[bool, str]:
         if len("".join(c for c in content if c.isalnum()).strip()) == 0:
@@ -884,8 +887,9 @@ class BaselineTest(BasePDFTest):
             r"]",
             flags=re.UNICODE,
         )
+
         matches = pattern.findall(content)
-        if matches:
+        if self.check_disallowed_characters and matches:
             return False, f"Text contains disallowed characters {matches}"
 
         return True, ""
@@ -983,6 +987,8 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
         test = TableTest(**data)
     elif test_type == TestType.MATH.value:
         test = MathTest(**data)
+    elif test_type == TestType.BASELINE.value:
+        test = BaselineTest(**data)
     else:
         raise ValidationError(f"Unknown test type: {test_type}")
     return (line_number, test)