diff --git a/.gitignore b/.gitignore
index 19a020e..1439958 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,8 @@ sample200_vllm/*
 sample200_sglang/*
 pdelfin_testset/*
 localworkspace/*
+math_data/*
+math_data_big/*
 gpt4otestset/*
 gpt4otestset_output/*
 pdfs/*
diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py
index 161e34e..06a383e 100644
--- a/olmocr/bench/tests.py
+++ b/olmocr/bench/tests.py
@@ -11,7 +11,9 @@ from fuzzysearch import find_near_matches
 from rapidfuzz import fuzz
 from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from olmocr.repeatdetect import RepeatDetector
+
 from .katex.render import render_equation, compare_rendered_equations
 
 
 class TestType(str, Enum):
@@ -544,10 +546,9 @@ class MathTest(BasePDFTest):
 
         return False, f"No match found for {self.math} anywhere in content, best match threshold was {best_match_score:.3f}"
 
-
 def load_tests(jsonl_file: str) -> List[BasePDFTest]:
     """
-    Load tests from a JSONL file.
+    Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor.
 
     Args:
         jsonl_file: Path to the JSONL file containing test definitions.
@@ -555,43 +556,67 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
     Returns:
         A list of test objects.
     """
-    tests: List[BasePDFTest] = []
-    unique_ids = set()
 
-    with open(jsonl_file, "r") as file:
-        for line_number, line in tqdm(enumerate(file, start=1), desc="Loading tests"):
-            line = line.strip()
-            if not line:
-                continue
+    # First, count the total number of lines for an accurate progress bar.
+    with open(jsonl_file, "r") as f:
+        total_lines = sum(1 for _ in f)
 
-            try:
-                data = json.loads(line)
-                test_type = data.get("type")
-                if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
-                    test = TextPresenceTest(**data)
-                elif test_type == TestType.ORDER.value:
-                    test = TextOrderTest(**data)
-                elif test_type == TestType.TABLE.value:
-                    test = TableTest(**data)
-                elif test_type == TestType.MATH.value:
-                    test = MathTest(**data)
-                else:
-                    raise ValidationError(f"Unknown test type: {test_type}")
+    def process_line(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]:
+        """
+        Process a single line from the JSONL file and return a tuple of (line_number, test object).
+        Returns None for empty lines.
+        """
+        line_number, line = line_tuple
+        line = line.strip()
+        if not line:
+            return None
 
-                if test.id in unique_ids:
-                    raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
-                else:
-                    unique_ids.add(test.id)
+        try:
+            data = json.loads(line)
+            test_type = data.get("type")
+            if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
+                test = TextPresenceTest(**data)
+            elif test_type == TestType.ORDER.value:
+                test = TextOrderTest(**data)
+            elif test_type == TestType.TABLE.value:
+                test = TableTest(**data)
+            elif test_type == TestType.MATH.value:
+                test = MathTest(**data)
+            else:
+                raise ValidationError(f"Unknown test type: {test_type}")
+            return (line_number, test)
+        except json.JSONDecodeError as e:
+            print(f"Error parsing JSON on line {line_number}: {e}")
+            raise
+        except (ValidationError, KeyError) as e:
+            print(f"Error on line {line_number}: {e}")
+            raise
+        except Exception as e:
+            print(f"Unexpected error on line {line_number}: {e}")
+            raise
+    tests = []
+
+    # Read all lines along with their line numbers.
+    with open(jsonl_file, "r") as f:
+        lines = list(enumerate(f, start=1))
+
+    # Use a ThreadPoolExecutor to process each line in parallel.
+    with ThreadPoolExecutor() as executor:
+        # Submit all tasks concurrently.
+        futures = {executor.submit(process_line, item): item[0] for item in lines}
+        # Use tqdm to show progress as futures complete.
+        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"):
+            result = future.result()
+            if result is not None:
+                _, test = result
                 tests.append(test)
-            except json.JSONDecodeError as e:
-                print(f"Error parsing JSON on line {line_number}: {e}")
-                raise
-            except (ValidationError, KeyError) as e:
-                print(f"Error on line {line_number}: {e}")
-                raise
-            except Exception as e:
-                print(f"Unexpected error on line {line_number}: {e}")
-                raise
+
+    # Check for duplicate test IDs after parallel processing.
+    unique_ids = set()
+    for test in tests:
+        if test.id in unique_ids:
+            raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
+        unique_ids.add(test.id)
 
     return tests
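For reference, a minimal usage sketch of the reworked loader. Only load_tests and the test .id attribute are taken from the code above; the tests.jsonl path is illustrative and not part of this change.

# Minimal usage sketch; "tests.jsonl" is an illustrative path.
from olmocr.bench.tests import load_tests

tests = load_tests("tests.jsonl")  # raises ValidationError on unknown types or duplicate ids
print(f"Loaded {len(tests)} tests")
for test in tests[:3]:
    print(test.id)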