Loading tests much faster in parallel

This commit is contained in:
Jake Poznanski 2025-03-13 10:20:09 -07:00
parent 7729e5a9d7
commit 980121feea
2 changed files with 62 additions and 35 deletions

2
.gitignore vendored
View File

@ -11,6 +11,8 @@ sample200_vllm/*
sample200_sglang/*
pdelfin_testset/*
localworkspace/*
math_data/*
math_data_big/*
gpt4otestset/*
gpt4otestset_output/*
pdfs/*

View File

@ -11,7 +11,9 @@ from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from olmocr.repeatdetect import RepeatDetector
from .katex.render import render_equation, compare_rendered_equations
class TestType(str, Enum):
@ -544,10 +546,9 @@ class MathTest(BasePDFTest):
return False, f"No match found for {self.math} anywhere in content, best match threshold was {best_match_score:.3f}"
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
"""
Load tests from a JSONL file.
Load tests from a JSONL file using parallel processing with a ThreadPoolExecutor.
Args:
jsonl_file: Path to the JSONL file containing test definitions.
@ -555,43 +556,67 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
Returns:
A list of test objects.
"""
tests: List[BasePDFTest] = []
unique_ids = set()
with open(jsonl_file, "r") as file:
for line_number, line in tqdm(enumerate(file, start=1), desc="Loading tests"):
line = line.strip()
if not line:
continue
# First, count the total number of lines for an accurate progress bar.
with open(jsonl_file, "r") as f:
total_lines = sum(1 for _ in f)
try:
data = json.loads(line)
test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data)
elif test_type == TestType.ORDER.value:
test = TextOrderTest(**data)
elif test_type == TestType.TABLE.value:
test = TableTest(**data)
elif test_type == TestType.MATH.value:
test = MathTest(**data)
else:
raise ValidationError(f"Unknown test type: {test_type}")
def process_line(line_tuple: Tuple[int, str]) -> Optional[Tuple[int, BasePDFTest]]:
"""
Process a single line from the JSONL file and return a tuple of (line_number, test object).
Returns None for empty lines.
"""
line_number, line = line_tuple
line = line.strip()
if not line:
return None
if test.id in unique_ids:
raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
else:
unique_ids.add(test.id)
try:
data = json.loads(line)
test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data)
elif test_type == TestType.ORDER.value:
test = TextOrderTest(**data)
elif test_type == TestType.TABLE.value:
test = TableTest(**data)
elif test_type == TestType.MATH.value:
test = MathTest(**data)
else:
raise ValidationError(f"Unknown test type: {test_type}")
return (line_number, test)
except json.JSONDecodeError as e:
print(f"Error parsing JSON on line {line_number}: {e}")
raise
except (ValidationError, KeyError) as e:
print(f"Error on line {line_number}: {e}")
raise
except Exception as e:
print(f"Unexpected error on line {line_number}: {e}")
raise
tests = []
# Read all lines along with their line numbers.
with open(jsonl_file, "r") as f:
lines = list(enumerate(f, start=1))
# Use a ThreadPoolExecutor to process each line in parallel.
with ThreadPoolExecutor() as executor:
# Submit all tasks concurrently.
futures = {executor.submit(process_line, item): item[0] for item in lines}
# Use tqdm to show progress as futures complete.
for future in tqdm(as_completed(futures), total=len(futures), desc="Loading tests"):
result = future.result()
if result is not None:
_, test = result
tests.append(test)
except json.JSONDecodeError as e:
print(f"Error parsing JSON on line {line_number}: {e}")
raise
except (ValidationError, KeyError) as e:
print(f"Error on line {line_number}: {e}")
raise
except Exception as e:
print(f"Unexpected error on line {line_number}: {e}")
raise
# Check for duplicate test IDs after parallel processing.
unique_ids = set()
for test in tests:
if test.id in unique_ids:
raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
unique_ids.add(test.id)
return tests