Faster eval script

Jake Poznanski 2024-10-10 15:22:33 +00:00
parent 931f48c3d1
commit 3245990216


@@ -13,7 +13,7 @@ import sys
 import argparse
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, Dict, List
 from tqdm import tqdm
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from pathlib import Path
@@ -124,34 +124,63 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
 # Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
 # returns map from "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
 # to the gold standard text
-def load_gold_data(gold_data_path: str) -> dict:
+def load_gold_data(gold_data_path: str, max_workers: int = 8) -> dict:
+    """
+    Load gold data from JSONL files in a multithreaded manner.
+
+    Args:
+        gold_data_path (str): Path to the directory containing JSONL files.
+        max_workers (int, optional): Maximum number of threads to use. Defaults to 8.
+
+    Returns:
+        dict: A dictionary containing gold data entries.
+    """
     if not os.path.exists(CACHE_DIR):
         os.makedirs(CACHE_DIR)
 
-    gold_data = {}
-    gold_jsonl_files = list_jsonl_files(gold_data_path)
+    gold_data: Dict[str, str] = {}
+    total_errors = 0
+    total_overruns = 0
 
-    gold_errors = 0
-    gold_overruns = 0
+    gold_jsonl_files: List[str] = list_jsonl_files(gold_data_path)
 
-    for path in gold_jsonl_files:
+    def process_file(path: str) -> Tuple[Dict[str, str], int, int]:
+        """Process a single JSONL file and return its data and error counts."""
+        file_gold_data = {}
+        file_errors = 0
+        file_overruns = 0
         # Load the JSON file
         with smart_open(path, 'r') as f:
             for line in f:
                 data = json.loads(line)
                 data = normalize_json_entry(data)
 
                 if data.error is not None:
-                    gold_errors += 1
+                    file_errors += 1
                 elif data.finish_reason != "stop":
-                    gold_overruns += 1
+                    file_overruns += 1
                 else:
-                    gold_data[data.goldkey] = data.text
+                    file_gold_data[data.goldkey] = data.text
+        return file_gold_data, file_errors, file_overruns
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all file processing tasks
+        futures = [executor.submit(process_file, path) for path in gold_jsonl_files]
+
+        # Gather results as they complete
+        for future in as_completed(futures):
+            try:
+                file_gold_data, file_errors, file_overruns = future.result()
+                gold_data.update(file_gold_data)
+                total_errors += file_errors
+                total_overruns += file_overruns
+            except Exception as e:
+                print(f"Error processing a file: {e}")
 
     print(f"Loaded {len(gold_data):,} gold data entries for comparison")
-    print(f"Gold processing errors: {gold_errors}")
-    print(f"Gold overrun errors: {gold_overruns}")
+    print(f"Gold processing errors: {total_errors}")
+    print(f"Gold overrun errors: {total_overruns}")
     print("-----------------------------------------------------------")
 
     return gold_data
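
The heart of the change is a per-file fan-out/gather pattern: each worker builds its own dict and counters, and only the main thread merges them, so no locks are needed, and future.result() re-raises any worker exception where the try/except can report it. Because the files are read from S3 through smart_open, the work is I/O-bound and a ThreadPoolExecutor is a good fit despite the GIL. Below is a minimal self-contained sketch of the same pattern; parse_file, load_all, and the tab-separated file format are illustrative assumptions, not part of this script.

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple

def parse_file(path: str) -> Tuple[Dict[str, str], int]:
    """Hypothetical worker: parse one key<TAB>text file, counting malformed lines."""
    entries: Dict[str, str] = {}
    errors = 0
    with open(path) as f:
        for line in f:
            key, sep, text = line.rstrip("\n").partition("\t")
            if sep:
                entries[key] = text
            else:
                errors += 1
    return entries, errors

def load_all(paths: List[str], max_workers: int = 8) -> Dict[str, str]:
    """Fan out one task per file, then merge the partial results on the main thread."""
    merged: Dict[str, str] = {}
    total_errors = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(parse_file, p) for p in paths]
        for future in as_completed(futures):
            try:
                entries, errors = future.result()  # re-raises worker exceptions here
                merged.update(entries)             # main thread only: no locking needed
                total_errors += errors
            except Exception as e:
                print(f"Error processing a file: {e}")
    print(f"Loaded {len(merged):,} entries ({total_errors} malformed lines)")
    return merged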
@@ -287,8 +316,8 @@ def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, rev
         # if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
         #     continue
-        if "[Error processing this page: overrun" not in pd["eval_text"]:
-            continue
+        # if "[Error processing this page: overrun" not in pd["eval_text"]:
+        #     continue
 
         page_eval_data.append(pd)
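
The second hunk turns off the overrun-only filter by commenting it out, so every page is reviewed again. If this toggle flips often, one alternative is to hang it off a command-line switch; the script already uses argparse, though the --only-overruns flag and filter_pages helper below are hypothetical, not part of this commit.

import argparse
from typing import Dict, List

def filter_pages(page_data: List[Dict[str, str]], only_overruns: bool) -> List[Dict[str, str]]:
    """Keep all pages, or only those whose eval text hit an overrun error."""
    marker = "[Error processing this page: overrun"
    if not only_overruns:
        return list(page_data)
    return [pd for pd in page_data if marker in pd["eval_text"]]

parser = argparse.ArgumentParser()
parser.add_argument("--only-overruns", action="store_true",
                    help="Review only pages that hit an overrun error")
args = parser.parse_args()
# e.g. page_eval_data = filter_pages(candidate_pages, args.only_overruns)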