diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py index 320d31a..758fbed 100644 --- a/olmocr/bench/tests.py +++ b/olmocr/bench/tests.py @@ -5,7 +5,7 @@ import unicodedata from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass, field from enum import Enum -from typing import List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np from bs4 import BeautifulSoup @@ -965,6 +965,45 @@ class MathTest(BasePDFTest): return False, f"No match found for {self.math} anywhere in content" +def load_single_test(data: Union[str, Dict]) -> BasePDFTest: + """ + Load a single test from a JSON line string or JSON object. + + Args: + data: Either a JSON string to parse or a dictionary containing test data. + + Returns: + A test object of the appropriate type. + + Raises: + ValidationError: If the test type is unknown or data is invalid. + json.JSONDecodeError: If the string cannot be parsed as JSON. + """ + # Handle JSON string input + if isinstance(data, str): + data = data.strip() + if not data: + raise ValueError("Empty string provided") + data = json.loads(data) + + # Process the test data + test_type = data.get("type") + if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}: + test = TextPresenceTest(**data) + elif test_type == TestType.ORDER.value: + test = TextOrderTest(**data) + elif test_type == TestType.TABLE.value: + test = TableTest(**data) + elif test_type == TestType.MATH.value: + test = MathTest(**data) + elif test_type == TestType.BASELINE.value: + test = BaselineTest(**data) + else: + raise ValidationError(f"Unknown test type: {test_type}") + + return test + + def load_tests(jsonl_file: str) -> List[BasePDFTest]: """ Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor. @@ -976,7 +1015,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: A list of test objects. """ - def process_line(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]: + def process_line_with_number(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]: """ Process a single line from the JSONL file and return a tuple of (line_number, test object). Returns None for empty lines. @@ -987,20 +1026,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: return None try: - data = json.loads(line) - test_type = data.get("type") - if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}: - test = TextPresenceTest(**data) - elif test_type == TestType.ORDER.value: - test = TextOrderTest(**data) - elif test_type == TestType.TABLE.value: - test = TableTest(**data) - elif test_type == TestType.MATH.value: - test = MathTest(**data) - elif test_type == TestType.BASELINE.value: - test = BaselineTest(**data) - else: - raise ValidationError(f"Unknown test type: {test_type}") + test = load_single_test(line) return (line_number, test) except json.JSONDecodeError as e: print(f"Error parsing JSON on line {line_number}: {e}") @@ -1021,7 +1047,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: # Use a ThreadPoolExecutor to process each line in parallel. with ThreadPoolExecutor(max_workers=min(os.cpu_count() or 1, 64)) as executor: # Submit all tasks concurrently. - futures = {executor.submit(process_line, item): item[0] for item in lines} + futures = {executor.submit(process_line_with_number, item): item[0] for item in lines} # Use tqdm to show progress as futures complete. for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"): result = future.result()