olmocr/pdelfin/train/dataloader.py

import json
import logging
import tempfile
import re
import os
import base64
import glob

import pypdf, pypdf.errors

from functools import partial
from typing import Any, Dict, Optional
from logging import Logger
from filelock import FileLock

import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files matching the glob pattern, either in S3 (for s3:// paths) or on the local filesystem.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")

        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract the prefix before the first wildcard
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)
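
# Example usage (bucket and prefix here are hypothetical, shown only for illustration):
#   list_dataset_files("s3://my-bucket/batch_data/*.jsonl")   # -> ["s3://my-bucket/batch_data/part0.jsonl", ...]
#   list_dataset_files("/data/batch_data/*.jsonl")            # -> local matches via glob.glob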


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files matching the glob pattern into a Hugging Face DatasetDict.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use the datasets library to load the JSON files (from S3 or local disk)
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset
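
# Note: load_dataset with only data_files returns a DatasetDict keyed by split, so callers
# index the default "train" split, e.g. (hypothetical path):
#   ds = load_jsonl_into_ds("s3://my-bucket/batch_responses/*.jsonl")["train"]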


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an s3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO: Maybe in the future we can parse the response (which is itself a structured JSON document)
    # into its own columns
    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}


def _cache_s3_file(s3_path: str, local_cache_dir: str):
    """
    Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
    """
    bucket, key = parse_s3_path(s3_path)

    # Define the local file path
    local_file_path = os.path.join(local_cache_dir, key.replace("/", "_"))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

    lock_file = f"{local_file_path}.lock"

    # Use a file lock to prevent concurrent writes
    with FileLock(lock_file):
        if not os.path.exists(local_file_path):
            logger.info(f"Downloading {s3_path} to {local_file_path}")
            s3_client = boto3.client(
                's3',
                aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY'),
            )
            s3_client.download_file(bucket, key, local_file_path)
        else:
            logger.info(f"File {local_file_path} already exists, skipping download.")

    return local_file_path


def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
    """
    Caches all S3 paths in the dataset to the local cache directory.
    """
    # Define the download function to use in parallel processing
    def cache_file(example):
        s3_path = example["s3_path"]
        if s3_path:
            # Download the file and cache it locally
            local_path = _cache_s3_file(s3_path, pdf_cache_location)
            return {"local_pdf_path": local_path}
        return {"local_pdf_path": None}

    # Map the caching function over the dataset (with parallelism if needed)
    dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)

    return dataset


def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
    if pdf_cache_location is None:
        pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')

    logger.info("Loading fine tuning dataset from OpenAI style batch responses")
    response_data = load_jsonl_into_ds(response_glob_path)
    response_data = response_data["train"]
    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # Don't include data where the model cut off due to a length or moderation issue
    logger.info("Filtering on finish_reason == stop")
    final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Cache all the s3_paths that were accessed to a local storage location
    final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)

    return final_dataset
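

if __name__ == "__main__":
    # Minimal usage sketch: the glob below is hypothetical and assumes OpenAI-style batch
    # response JSONL files exist at that location, with AWS credentials configured for
    # any s3:// paths involved.
    ds = build_finetuning_dataset("s3://my-bucket/openai_batch_responses/*.jsonl")
    print(ds)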