Mirror of https://github.com/allenai/olmocr.git, synced 2025-12-27 23:24:59 +00:00
Faster eval script

parent 931f48c3d1
commit 3245990216
@@ -13,7 +13,7 @@ import sys
 import argparse
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, Dict
 from tqdm import tqdm
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from pathlib import Path
@@ -124,34 +124,63 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
 # Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
 # returns map from "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
 # to the gold standard text
-def load_gold_data(gold_data_path: str) -> dict:
+def load_gold_data(gold_data_path: str, max_workers: int = 8) -> dict:
+    """
+    Load gold data from JSONL files in a multithreaded manner.
+
+    Args:
+        gold_data_path (str): Path to the directory containing JSONL files.
+        max_workers (int, optional): Maximum number of threads to use. Defaults to 8.
+
+    Returns:
+        dict: A dictionary containing gold data entries.
+    """
     if not os.path.exists(CACHE_DIR):
         os.makedirs(CACHE_DIR)
 
-    gold_data = {}
-
-    gold_jsonl_files = list_jsonl_files(gold_data_path)
+    gold_data: Dict[str, str] = {}
+    total_errors = 0
+    total_overruns = 0
 
-    gold_errors = 0
-    gold_overruns = 0
+    gold_jsonl_files: List[str] = list_jsonl_files(gold_data_path)
+
+    def process_file(path: str) -> tuple:
+        """Process a single JSONL file and return its data and error counts."""
+        file_gold_data = {}
+        file_errors = 0
+        file_overruns = 0
 
-    for path in gold_jsonl_files:
         # Load the JSON file
         with smart_open(path, 'r') as f:
             for line in f:
                 data = json.loads(line)
                 data = normalize_json_entry(data)
 
                 if data.error is not None:
-                    gold_errors += 1
+                    file_errors += 1
                 elif data.finish_reason != "stop":
-                    gold_overruns += 1
+                    file_overruns += 1
                 else:
-                    gold_data[data.goldkey] = data.text
+                    file_gold_data[data.goldkey] = data.text
 
+        return file_gold_data, file_errors, file_overruns
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all file processing tasks
+        futures = [executor.submit(process_file, path) for path in gold_jsonl_files]
+
+        # Gather results as they complete
+        for future in as_completed(futures):
+            try:
+                file_gold_data, file_errors, file_overruns = future.result()
+                gold_data.update(file_gold_data)
+                total_errors += file_errors
+                total_overruns += file_overruns
+            except Exception as e:
+                print(f"Error processing a file: {e}")
 
     print(f"Loaded {len(gold_data):,} gold data entries for comparison")
-    print(f"Gold processing errors: {gold_errors}")
-    print(f"Gold overrun errors: {gold_overruns}")
+    print(f"Gold processing errors: {total_errors}")
+    print(f"Gold overrun errors: {total_overruns}")
    print("-----------------------------------------------------------")
 
     return gold_data
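The heart of this hunk is a fan-out/fan-in pattern: each worker thread parses one JSONL file into a private dict, and only the main thread merges per-file results into the shared gold_data dict, so no locking is needed. Below is a minimal, self-contained sketch of the same pattern; the function name load_jsonl_parallel is hypothetical, plain open() stands in for smart_open, and the custom_id/text field names are assumed stand-ins for the normalized entries.

import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple

def load_jsonl_parallel(paths: List[str], max_workers: int = 8) -> Dict[str, str]:
    # Merged results; mutated only on the main thread below.
    merged: Dict[str, str] = {}
    total_errors = 0

    def process_file(path: str) -> Tuple[Dict[str, str], int]:
        # Each worker fills a private dict, so no locks are required.
        file_data: Dict[str, str] = {}
        errors = 0
        with open(path, "r") as f:  # assumption: smart_open replaced by open()
            for line in f:
                try:
                    entry = json.loads(line)
                    file_data[entry["custom_id"]] = entry["text"]  # assumed field names
                except (json.JSONDecodeError, KeyError):
                    errors += 1
        return file_data, errors

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Fan-out: one task per file.
        futures = [executor.submit(process_file, p) for p in paths]
        # Fan-in: merge per-file results as workers finish.
        for future in as_completed(futures):
            file_data, errors = future.result()
            merged.update(file_data)
            total_errors += errors

    print(f"Loaded {len(merged):,} entries ({total_errors} bad lines)")
    return merged

Threads rather than processes are a reasonable fit here because the work is dominated by I/O (the gold files live on S3 and are read through smart_open), which releases the GIL while waiting; if json.loads ever dominated, the already-imported ProcessPoolExecutor would be the drop-in alternative.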
@@ -287,8 +316,8 @@ def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, rev
 
         # if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
         # continue
-        if "[Error processing this page: overrun" not in pd["eval_text"]:
-            continue
+        # if "[Error processing this page: overrun" not in pd["eval_text"]:
+        # continue
 
         page_eval_data.append(pd)
 
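For context, the filter this hunk disables keeps only pages whose eval text contains the truncation marker; with both blocks commented out, every page now reaches page_eval_data. A sketch of the same check as a reusable predicate (the helper name is illustrative, the marker string is from the source):

# Marker that appears in eval_text when a page's output was truncated.
OVERRUN_MARKER = "[Error processing this page: overrun"

def is_overrun_page(eval_text: str) -> bool:
    # True when the eval pipeline truncated this page's output.
    return OVERRUN_MARKER in eval_text

Re-enabling either commented block restores the corresponding debug filter.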