olmocr/pdelfin/train/dataloader.py

100 lines
3.1 KiB
Python
Raw Normal View History

import json
import logging
import random
import re
from typing import Any, Dict, List, Optional

import boto3
from datasets import Dataset, DatasetDict, Features, Value, load_dataset
# Send INFO-and-above records through a root handler, then obtain this
# module's named logger for use by the helpers below.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
def list_s3_files(s3_path: str) -> List[str]:
    """
    Lists objects under an S3 path whose keys match a ``*`` wildcard pattern.

    Only ``*`` is treated specially: it matches any run of characters,
    including ``/``. Every other character of the key pattern is matched
    literally, so keys such as ``data/file.v1+final.jsonl`` work as written.

    Args:
        s3_path: Path of the form ``s3://bucket/key-pattern``, for example
            ``s3://my-bucket/batches/*.jsonl``.

    Returns:
        A list of ``s3://bucket/key`` strings, in the order S3 lists them.

    Raises:
        ValueError: If ``s3_path`` is not of the form ``s3://bucket/key``.
    """
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")
    bucket, prefix_pattern = match.groups()

    # Only list keys under the literal part before the first wildcard.
    prefix = prefix_pattern.split('*')[0]

    # Escape every literal piece so characters like '.', '+', '(' or '[' in the
    # key are not interpreted as regex syntax; each '*' becomes '.*'. DOTALL
    # lets '*' match any character, mirroring fnmatch's translation.
    pattern = re.compile(
        ".*".join(re.escape(part) for part in prefix_pattern.split('*')),
        re.DOTALL,
    )

    # Create the client only once the path is known to be valid.
    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')

    files = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")
    return files
def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files matching an S3 wildcard path with Hugging Face ``datasets``.

    Args:
        s3_glob_path: ``s3://bucket/key-pattern`` path; ``*`` is the only
            wildcard (see ``list_s3_files``).
        first_n_files: If truthy, only the first ``first_n_files`` matched
            files (in listing order) are loaded. ``None`` or ``0`` loads
            every matched file.

    Returns:
        The result of ``load_dataset('json', data_files=...)`` called without
        ``split``, which is a ``DatasetDict`` (``datasets`` places a plain
        list of data files under its ``"train"`` split).

    Raises:
        ValueError: If the path is malformed or no file matches it.
    """
    all_s3_files = list_s3_files(s3_glob_path)
    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    if not all_s3_files:
        # Fail with the offending path rather than an opaque "empty data_files"
        # error from inside load_dataset.
        raise ValueError(f"No files found matching S3 path: {s3_glob_path}")

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        'json',
        data_files=all_s3_files,
    )
    return dataset
def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts the custom id, prompt text and base64 image from one batch request.

    Only messages with ``role == "user"`` are inspected. Their ``content``
    may be a plain string (taken as the prompt text) or a list of parts with
    ``type`` ``"text"`` or ``"image_url"``. Only ``data:image...`` URLs are
    used for the image; other URLs are ignored. When several text parts or
    data-URL images appear, the last one seen wins.

    Args:
        query: One parsed request object, expected to look like
            ``{"custom_id": ..., "body": {"messages": [...]}}``.

    Returns:
        A dict with keys ``custom_id``, ``input_prompt_text`` and
        ``input_prompt_image_base64``. Missing text/image fields are ``""``,
        and a missing ``custom_id`` defaults to ``''``.
    """
    custom_id = query.get('custom_id', '')
    # `or` also covers keys that are present but null in the JSON.
    body = query.get('body') or {}
    messages = body.get('messages') or []

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get('role') != 'user':
            continue  # We are only interested in user messages

        contents = message.get('content') or []
        if isinstance(contents, str):
            # Chat messages may carry their content as a bare string instead
            # of a list of typed parts; iterating it would yield characters.
            input_prompt_text = contents
            continue

        for content_item in contents:
            item_type = content_item.get('type')
            if item_type == 'text':
                input_prompt_text = content_item.get('text', "")
            elif item_type == 'image_url':
                image_url = (content_item.get('image_url') or {}).get('url') or ""
                if image_url.startswith('data:image'):
                    # Keep only the payload after the first comma of the data
                    # URL; a malformed URL with no comma yields "".
                    _, _, input_prompt_image_base64 = image_url.partition(',')

    return {
        'custom_id': custom_id,
        'input_prompt_text': input_prompt_text,
        'input_prompt_image_base64': input_prompt_image_base64
    }
def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> DatasetDict:
    """
    Loads the batch query and batch response JSONL files from S3.

    NOTE: work in progress. The responses are loaded but not yet joined onto
    the queries, so the returned value currently contains query data only.

    Args:
        query_glob_path: ``s3://`` wildcard path of the batch request files.
        response_glob_path: ``s3://`` wildcard path of the batch response files.

    Returns:
        The query data exactly as returned by ``load_jsonl_from_s3``.
    """
    query_ds = load_jsonl_from_s3(query_glob_path)
    response_ds = load_jsonl_from_s3(response_glob_path)  # noqa: F841 -- unused until the merge below exists

    # TODO: merge response_ds into query_ds based on the custom_id field.
    return query_ds