Refactoring of loading tests

This commit is contained in:
Jake Poznanski 2025-08-22 18:31:37 +00:00
parent afac12b839
commit d9789947d5

View File

@ -5,7 +5,7 @@ import unicodedata
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from enum import Enum from enum import Enum
from typing import List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -965,29 +965,28 @@ class MathTest(BasePDFTest):
return False, f"No match found for {self.math} anywhere in content" return False, f"No match found for {self.math} anywhere in content"
def load_tests(jsonl_file: str) -> List[BasePDFTest]: def load_single_test(data: Union[str, Dict]) -> BasePDFTest:
""" """
Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor. Load a single test from a JSON line string or JSON object.
Args: Args:
jsonl_file: Path to the JSONL file containing test definitions. data: Either a JSON string to parse or a dictionary containing test data.
Returns: Returns:
A list of test objects. A test object of the appropriate type.
"""
def process_line(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]: Raises:
ValidationError: If the test type is unknown or data is invalid.
json.JSONDecodeError: If the string cannot be parsed as JSON.
""" """
Process a single line from the JSONL file and return a tuple of (line_number, test object). # Handle JSON string input
Returns None for empty lines. if isinstance(data, str):
""" data = data.strip()
line_number, line = line_tuple if not data:
line = line.strip() raise ValueError("Empty string provided")
if not line: data = json.loads(data)
return None
try: # Process the test data
data = json.loads(line)
test_type = data.get("type") test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}: if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data) test = TextPresenceTest(**data)
@ -1001,6 +1000,33 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
test = BaselineTest(**data) test = BaselineTest(**data)
else: else:
raise ValidationError(f"Unknown test type: {test_type}") raise ValidationError(f"Unknown test type: {test_type}")
return test
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
"""
Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor.
Args:
jsonl_file: Path to the JSONL file containing test definitions.
Returns:
A list of test objects.
"""
def process_line_with_number(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]:
"""
Process a single line from the JSONL file and return a tuple of (line_number, test object).
Returns None for empty lines.
"""
line_number, line = line_tuple
line = line.strip()
if not line:
return None
try:
test = load_single_test(line)
return (line_number, test) return (line_number, test)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"Error parsing JSON on line {line_number}: {e}") print(f"Error parsing JSON on line {line_number}: {e}")
@ -1021,7 +1047,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
# Use a ThreadPoolExecutor to process each line in parallel. # Use a ThreadPoolExecutor to process each line in parallel.
with ThreadPoolExecutor(max_workers=min(os.cpu_count() or 1, 64)) as executor: with ThreadPoolExecutor(max_workers=min(os.cpu_count() or 1, 64)) as executor:
# Submit all tasks concurrently. # Submit all tasks concurrently.
futures = {executor.submit(process_line, item): item[0] for item in lines} futures = {executor.submit(process_line_with_number, item): item[0] for item in lines}
# Use tqdm to show progress as futures complete. # Use tqdm to show progress as futures complete.
for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"): for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"):
result = future.result() result = future.result()