Refactoring of loading tests

This commit is contained in:
Jake Poznanski 2025-08-22 18:31:37 +00:00
parent afac12b839
commit d9789947d5

View File

@ -5,7 +5,7 @@ import unicodedata
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
from bs4 import BeautifulSoup
@ -965,6 +965,45 @@ class MathTest(BasePDFTest):
return False, f"No match found for {self.math} anywhere in content"
def load_single_test(data: Union[str, Dict]) -> BasePDFTest:
"""
Load a single test from a JSON line string or JSON object.
Args:
data: Either a JSON string to parse or a dictionary containing test data.
Returns:
A test object of the appropriate type.
Raises:
ValidationError: If the test type is unknown or data is invalid.
json.JSONDecodeError: If the string cannot be parsed as JSON.
"""
# Handle JSON string input
if isinstance(data, str):
data = data.strip()
if not data:
raise ValueError("Empty string provided")
data = json.loads(data)
# Process the test data
test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data)
elif test_type == TestType.ORDER.value:
test = TextOrderTest(**data)
elif test_type == TestType.TABLE.value:
test = TableTest(**data)
elif test_type == TestType.MATH.value:
test = MathTest(**data)
elif test_type == TestType.BASELINE.value:
test = BaselineTest(**data)
else:
raise ValidationError(f"Unknown test type: {test_type}")
return test
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
"""
Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor.
@ -976,7 +1015,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
A list of test objects.
"""
def process_line(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]:
def process_line_with_number(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]:
"""
Process a single line from the JSONL file and return a tuple of (line_number, test object).
Returns None for empty lines.
@ -987,20 +1026,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
return None
try:
data = json.loads(line)
test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data)
elif test_type == TestType.ORDER.value:
test = TextOrderTest(**data)
elif test_type == TestType.TABLE.value:
test = TableTest(**data)
elif test_type == TestType.MATH.value:
test = MathTest(**data)
elif test_type == TestType.BASELINE.value:
test = BaselineTest(**data)
else:
raise ValidationError(f"Unknown test type: {test_type}")
test = load_single_test(line)
return (line_number, test)
except json.JSONDecodeError as e:
print(f"Error parsing JSON on line {line_number}: {e}")
@ -1021,7 +1047,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
# Use a ThreadPoolExecutor to process each line in parallel.
with ThreadPoolExecutor(max_workers=min(os.cpu_count() or 1, 64)) as executor:
# Submit all tasks concurrently.
futures = {executor.submit(process_line, item): item[0] for item in lines}
futures = {executor.submit(process_line_with_number, item): item[0] for item in lines}
# Use tqdm to show progress as futures complete.
for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"):
result = future.result()