# olmocr/pdelfin/train/dataloader.py


import json
import logging
import multiprocessing
import re
from functools import partial
from typing import Any, Dict, Optional

import boto3
from datasets import Dataset, DatasetDict, Features, Value, load_dataset

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def list_s3_files(s3_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    s3 = boto3.client("s3")
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")

    bucket, prefix_pattern = match.groups()
    prefix = prefix_pattern.split("*")[0]  # Extract the literal prefix before the wildcard
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    files = []
    # Translate the glob into a regex; each "*" matches any run of characters
    pattern = re.compile(prefix_pattern.replace("*", ".*"))
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")
    return files
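
# Example usage (hypothetical bucket and key layout, shown for illustration only):
#   list_s3_files("s3://my-bucket/batch_data/*.jsonl")
#   -> ["s3://my-bucket/batch_data/part-000.jsonl", "s3://my-bucket/batch_data/part-001.jsonl", ...]
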
def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files from the specified S3 glob into a Hugging Face DatasetDict
    (the data lands under the default "train" split).
    """
    all_s3_files = list_s3_files(s3_glob_path)
    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    # Use the datasets library to load JSON files from S3
    dataset = load_dataset(
        "json",
        data_files=all_s3_files,
    )
    return dataset
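
# Example usage (hypothetical path; the rows appear under the "train" split):
#   ds = load_jsonl_from_s3("s3://my-bucket/batch_data/*.jsonl", first_n_files=2)
#   ds["train"][0]  # -> dict with the raw JSONL fields
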
def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts the fields we need from a query entry sent to OpenAI's batch API for vision LMs.
    """
    custom_id = query.get("custom_id", "")
    body = query.get("body", {})
    messages = body.get("messages", [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get("role") != "user":
            continue  # We are only interested in user messages

        contents = message.get("content", [])
        for content_item in contents:
            if content_item.get("type") == "text":
                input_prompt_text = content_item.get("text", "")
            elif content_item.get("type") == "image_url":
                image_url = content_item.get("image_url", {}).get("url", "")
                if image_url.startswith("data:image"):
                    # Extract the base64 payload from the data URL
                    try:
                        input_prompt_image_base64 = image_url.split(",", 1)[1]
                    except IndexError:
                        input_prompt_image_base64 = ""

    return {
        "custom_id": custom_id,
        "input_prompt_text": input_prompt_text,
        "input_prompt_image_base64": input_prompt_image_base64,
    }
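
# A query entry is expected to look roughly like this (illustrative values):
#   {"custom_id": "doc1_page1",
#    "body": {"messages": [{"role": "user",
#                           "content": [{"type": "text", "text": "Transcribe this page..."},
#                                       {"type": "image_url",
#                                        "image_url": {"url": "data:image/png;base64,iVBOR..."}}]}]}}
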
def extract_openai_batch_response(example):
    """
    Pulls the generated text and finish reason out of an OpenAI batch API response entry.
    """
    custom_id = example.get("custom_id", None)
    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])

    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    return {
        "custom_id": custom_id,
        "response": response,
        "finish_reason": finish_reason,
    }
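
# A response entry is expected to look roughly like this (illustrative values):
#   {"custom_id": "doc1_page1",
#    "response": {"body": {"choices": [{"message": {"content": "Page text..."},
#                                       "finish_reason": "stop"}]}}}
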
def merge_query_response(query_example, response_data: Dataset, response_map: Dict[str, int]):
    """
    Attaches the matching response row to a query row, keyed on custom_id.
    """
    custom_id = query_example["custom_id"]
    if custom_id not in response_map:
        return {
            "response": None,
            "finish_reason": None,
        }

    response_row = response_data[response_map[custom_id]]
    return {
        "response": response_row["response"],
        "finish_reason": response_row["finish_reason"],
    }
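
# For a matched custom_id the merged row gains, e.g. (illustrative values):
#   {"response": "Page text...", "finish_reason": "stop"}
# Unmatched queries get None for both fields and are dropped by the filter below.
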
def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
    """
    Builds a merged query+response dataset, keeping only completed generations.
    """
    logger.info("Loading query and response datasets")
    query_data = load_jsonl_from_s3(query_glob_path)
    response_data = load_jsonl_from_s3(response_glob_path)

    # Map both datasets down to the core fields we need, to make them easier to process
    logger.info("Mapping query data")
    query_data = query_data["train"]
    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

    logger.info("Mapping response data")
    response_data = response_data["train"]
    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)

    # Build an in-memory map from custom_id to row index in the response data.
    # This allows quick lookups during the merge step, but it will not scale past a certain point.
    logger.info("Building custom_id to row map")
    custom_id_to_response_row = {}
    for row_id, entry in enumerate(response_data):
        custom_id_to_response_row[entry["custom_id"]] = row_id

    logger.info("Running merge map")
    final_dataset = query_data.map(
        partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
        num_proc=multiprocessing.cpu_count(),
    )

    # Keep only examples whose generation completed normally
    final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")

    return final_dataset
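

# Minimal sketch of how this builder might be driven (the S3 globs below are
# hypothetical placeholders, not paths from this project):
if __name__ == "__main__":
    dataset = build_batch_query_response_vision_dataset(
        query_glob_path="s3://my-bucket/queries/*.jsonl",
        response_glob_path="s3://my-bucket/responses/*.jsonl",
    )
    logger.info(f"Built dataset with {len(dataset)} rows")
    logger.info(f"Columns: {dataset.column_names}")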