olmocr/pdelfin/train/dataloader.py

import json
import logging
import multiprocessing
import re
import random
import base64
import glob
from functools import partial
from typing import Any, Dict, Optional
from logging import Logger

import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard

        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)
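
# Illustrative usage (the paths here are hypothetical, not from this repo). For S3 globs
# the "*" is translated into the regex ".*" and full-matched against every key under the
# prefix; local paths simply fall back to glob.glob.
#
#   remote_files = list_dataset_files("s3://my-bucket/batch_data/*.jsonl")
#   local_files = list_dataset_files("/tmp/batch_data/*.jsonl")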


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset
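
# Sketch of the expected output (assuming all JSONL lines share a schema): with no split
# specified, load_dataset("json", ...) returns a DatasetDict containing a single "train"
# split, one row per JSON line across the matched files.
#
#   ds = load_jsonl_into_ds("s3://my-bucket/batch_queries/*.jsonl", first_n_files=2)
#   ds["train"][0]  # first parsed JSON object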


def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
    """
    Returns the (width, height) of a PNG image given its base64-encoded data,
    without base64-decoding the entire data or loading the PNG itself.
    Should be really fast to support filtering.

    Parameters:
    - base64_data (str): Base64-encoded PNG image data.

    Returns:
    - tuple: (width, height) of the image.

    Raises:
    - ValueError: If the data is not a valid PNG image or the required bytes are not found.
    """
    # PNG signature is 8 bytes
    png_signature_base64 = base64.b64encode(b'\x89PNG\r\n\x1a\n').decode('ascii')
    if not base64_data.startswith(png_signature_base64[:8]):
        raise ValueError('Not a valid PNG file')

    # Positions in the binary data where width and height are stored
    width_start = 16  # Byte position where width starts (0-based indexing)
    width_end = 20    # Byte position where width ends (exclusive)
    height_start = 20
    height_end = 24

    # Compute the byte range needed (from width_start to height_end)
    start_byte = width_start
    end_byte = height_end

    # Calculate base64 character positions
    # Each group of 3 bytes corresponds to 4 base64 characters
    base64_start = (start_byte // 3) * 4
    base64_end = ((end_byte + 2) // 3) * 4  # Add 2 to ensure we cover partial groups

    # Extract the necessary base64 substring
    base64_substring = base64_data[base64_start:base64_end]

    # Decode only the necessary bytes
    decoded_bytes = base64.b64decode(base64_substring)

    # Compute the offset within the decoded bytes
    offset = start_byte % 3

    # Extract width and height bytes
    width_bytes = decoded_bytes[offset:offset + 4]
    height_bytes = decoded_bytes[offset + 4:offset + 8]
    if len(width_bytes) < 4 or len(height_bytes) < 4:
        raise ValueError('Insufficient data to extract dimensions')

    # Convert bytes to integers
    width = int.from_bytes(width_bytes, 'big')
    height = int.from_bytes(height_bytes, 'big')
    return width, height
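
# Worked example of the index arithmetic above (assuming a standard PNG, where the IHDR
# width field occupies bytes 16-19 and the height field bytes 20-23): with start_byte=16
# and end_byte=24, the slice covers base64 characters (16 // 3) * 4 = 20 through
# ((24 + 2) // 3) * 4 = 32; those 12 characters decode to raw bytes 15-23, and the
# offset 16 % 3 = 1 skips byte 15 so the width and height fields line up.
#
# Quick self-check (sketch only; Pillow is not imported by this module):
#
#   from PIL import Image; import io
#   buf = io.BytesIO(); Image.new("RGB", (2000, 1500)).save(buf, format="PNG")
#   b64 = base64.b64encode(buf.getvalue()).decode("ascii")
#   assert get_png_dimensions_from_base64(b64) == (2000, 1500)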


def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
    """
    custom_id = query.get("custom_id", "")
    body = query.get("body", {})
    messages = body.get("messages", [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get("role") != "user":
            continue  # We are only interested in user messages

        contents = message.get("content", [])
        for content_item in contents:
            if content_item.get("type") == "text":
                input_prompt_text = content_item.get("text", "")
            elif content_item.get("type") == "image_url":
                image_url = content_item.get("image_url", {}).get("url", "")
                if image_url.startswith("data:image"):
                    # Extract base64 part from data URL
                    try:
                        base64_data = image_url.split(",", 1)[1]
                        input_prompt_image_base64 = base64_data
                    except IndexError:
                        input_prompt_image_base64 = ""

    # At this point, input_prompt_text is the raw text that was passed to the OpenAI model
    # to generate our silver data. But we want a simplified prompt for this fine-tune,
    # so we extract just the raw page text.
    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

    # Use re.DOTALL to ensure that the dot matches newline characters
    match = re.search(pattern, input_prompt_text, re.DOTALL)

    if match:
        raw_page_text = match.group(1).strip()
    else:
        raw_page_text = ""

    return {
        "custom_id": custom_id,
        "input_prompt_text": input_prompt_text,
        "input_prompt_image_base64": input_prompt_image_base64,
        "raw_page_text": raw_page_text,
    }
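
# Illustrative input record (values are made up; the structure mirrors the OpenAI
# batch-API request lines this function reads):
#
#   query = {
#       "custom_id": "doc123_page1",
#       "body": {"messages": [{"role": "user", "content": [
#           {"type": "text", "text": "... RAW_TEXT_START\nsome page text\nRAW_TEXT_END ..."},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0K..."}},
#       ]}]},
#   }
#   extract_openai_batch_query(query)["raw_page_text"]  # -> "some page text"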


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)
    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    return {"custom_id": custom_id, "response": response, "finish_reason": finish_reason}


def merge_query_response(query_example, response_data: Dataset, response_map: dict[str, int]):
    custom_id = query_example["custom_id"]

    if custom_id not in response_map:
        return {
            "response": None,
            "finish_reason": None,
        }

    response_row = response_data[response_map[custom_id]]

    return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}
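
# Sketch of the merge step (hypothetical values): response_map maps each custom_id to its
# row index in response_data, so the merge is a constant-time dictionary lookup per query
# row rather than a join over the full response dataset.
#
#   merge_query_response({"custom_id": "doc123_page1"}, response_data, {"doc123_page1": 0})
#   # -> {"response": ..., "finish_reason": ...} taken from response_data[0]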


def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
    logger.info("Loading query and response datasets")
    query_data = load_jsonl_into_ds(query_glob_path)
    response_data = load_jsonl_into_ds(response_glob_path)

    # Map the datasets down to the core fields that we're going to need, to make them easier to process
    logger.info("Mapping query data")
    query_data = query_data["train"]
    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)

    logger.info("Mapping response data")
    response_data = response_data["train"]
    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # What we're going to do is build an in-memory map for the response data from custom_id to row.
    # This will let us do quick lookups in the merge step, but it will not scale past a certain point.
    logger.info("Building custom_id to row map")
    custom_id_to_response_row = {}
    for row_id, entry in enumerate(response_data):
        custom_id_to_response_row[entry["custom_id"]] = row_id

    logger.info("Running merge map")
    final_dataset = query_data.map(
        partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
        num_proc=num_proc,
    )

    # Don't include data where the model cut off due to a length issue, or a moderation issue
    final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Keep only entries with a reasonable image size
    def pick_image_sizes(x):
        width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
        return 1800 <= max(width, height) <= 2200

    final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)

    return final_dataset
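
# End-to-end usage sketch (the bucket and prefixes are hypothetical):
#
#   ds = build_batch_query_response_vision_dataset(
#       query_glob_path="s3://my-bucket/batch_queries/*.jsonl",
#       response_glob_path="s3://my-bucket/batch_responses/*.jsonl",
#   )
#   # Each row has: custom_id, input_prompt_text, input_prompt_image_base64,
#   # raw_page_text, response, finish_reason -- filtered to finish_reason == "stop"
#   # and to images whose longest side is between 1800 and 2200 pixels.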