# olmocr/pdelfin/train/dataloader.py
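"""
Builds a fine-tuning dataset from OpenAI-style batch responses stored as JSONL files
(on S3 or local disk): responses are parsed into rows, the source PDFs they reference
are cached to a local directory, and pages that cannot produce anchor text are filtered out.
"""
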
import json
import logging
import tempfile
import re
import os
import base64
import glob
import pypdf, pypdf.errors
from functools import partial
from typing import Any, Dict, Optional
from logging import Logger
from filelock import FileLock
import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
from .core.config import DataConfig, SourceConfig
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
from pdelfin.data.renderpdf import get_pdf_media_box_width_height

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard

        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)
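
# For example (hypothetical paths), list_dataset_files("s3://my-bucket/batch_outputs/*.jsonl")
# returns the matching objects as "s3://my-bucket/<key>" URIs, while a local pattern such as
# "/data/batch_outputs/*.jsonl" falls through to glob.glob.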


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files matching the given S3 (or local) glob pattern into a Hugging Face DatasetDict.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use the datasets library to load the JSON files (local paths or S3 URIs)
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an s3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO: Maybe in the future we can parse the response (which is a structured JSON document itself)
    # into its own columns
    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}


def _cache_s3_file(s3_path: str, local_cache_dir: str):
    """
    Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
    """
    bucket, key = parse_s3_path(s3_path)

    # Define the local file path
    local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

    lock_file = f"{local_file_path}.lock"

    # Use a file lock to prevent concurrent writes
    with FileLock(lock_file):
        if not os.path.exists(local_file_path):
            logger.info(f"Downloading {s3_path} to {local_file_path}")
            s3_client = boto3.client(
                's3',
                aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
            )
            s3_client.download_file(bucket, key, local_file_path)
        else:
            pass
            # logger.info(f"File {local_file_path} already exists, skipping download.")

    return local_file_path
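
# A minimal sketch of the resulting cache layout, using a hypothetical bucket and key:
#   _cache_s3_file("s3://my-bucket/docs/report.pdf", "/tmp/pdf_cache")
#   -> "/tmp/pdf_cache/my-bucket__docs_report.pdf"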


def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
    """
    Caches all S3 paths in the dataset to the local cache directory.
    """
    # Define the download function to use in parallel processing
    def cache_file(example):
        s3_path = example["s3_path"]
        if s3_path:
            # Download the file and cache it locally
            local_path = _cache_s3_file(s3_path, pdf_cache_location)
            return {"local_pdf_path": local_path}
        return {"local_pdf_path": None}

    # Map the caching function to the dataset (with parallelism if needed)
    dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)

    return dataset


def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
    if pdf_cache_location is None:
        pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')

    logger.info("Loading fine-tuning dataset from OpenAI-style batch responses")
    response_data = load_jsonl_into_ds(response_glob_path)

    # load_dataset returns a DatasetDict with a single "train" split when no split is specified
    response_data = response_data["train"]

    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # Don't include data where the model cut off due to a length issue or a moderation issue
    logger.info("Filtering on finish_reason == stop")
    final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Cache all the s3_paths that were accessed to a local storage location
    final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)

    # Filter out pages for which no anchor text can be generated, to prevent errors during actual training
    def _can_create_anchor_text(example):
        try:
            anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
            _ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
            return anchor_text is not None
        except Exception:
            return False

    final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)

    return final_dataset
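

# Example usage: a minimal sketch, assuming AWS credentials are configured as the functions
# above expect; the response glob below is a hypothetical placeholder, not a real path.
if __name__ == "__main__":
    dataset = build_finetuning_dataset(
        response_glob_path="s3://my-bucket/openai_batch_responses/*.jsonl",  # hypothetical path
        num_proc=8,
    )
    print(dataset)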