Mirror of https://github.com/allenai/olmocr.git, synced 2025-12-02 02:01:09 +00:00
Starting to write dataloader for visual lm data
commit d22b311340 (parent fb4fc4229e)
pdelfin/train/__init__.py (new file, 0 lines)

pdelfin/train/dataloader.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import json
from datasets import load_dataset, Dataset, Features, Value
import boto3
from typing import Dict, Any
import logging
import re
import random


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def list_s3_files(s3_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    s3 = boto3.client('s3')
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")

    bucket, prefix_pattern = match.groups()
    prefix = prefix_pattern.split('*')[0]  # Extract prefix before the wildcard
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    files = []
    pattern = re.compile(prefix_pattern.replace('*', '.*'))
    for page in pages:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")
    return files


def load_jsonl_from_s3(s3_glob_path: str, first_n_files: int=None) -> Dataset:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_s3_files = list_s3_files(s3_glob_path)

    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        'json',
        data_files=all_s3_files,
    )

    return dataset

def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts necessary fields from a query entry.
    """
    custom_id = query.get('custom_id', '')
    body = query.get('body', {})
    messages = body.get('messages', [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get('role') != 'user':
            continue  # We are only interested in user messages

        contents = message.get('content', [])
        for content_item in contents:
            if content_item.get('type') == 'text':
                input_prompt_text = content_item.get('text', "")
            elif content_item.get('type') == 'image_url':
                image_url = content_item.get('image_url', {}).get('url', "")
                if image_url.startswith('data:image'):
                    # Extract base64 part from data URL
                    try:
                        base64_data = image_url.split(',', 1)[1]
                        input_prompt_image_base64 = base64_data
                    except IndexError:
                        input_prompt_image_base64 = ""

    return {
        'custom_id': custom_id,
        'input_prompt_text': input_prompt_text,
        'input_prompt_image_base64': input_prompt_image_base64
    }

def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
    query_ds = load_jsonl_from_s3(query_glob_path)
    response_ds = load_jsonl_from_s3(response_glob_path)

    # Now merge them based on the custom_id field

    return query_ds
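
The merge promised by the final comment is not implemented yet; build_batch_query_response_vision_dataset currently returns only the query dataset. A minimal sketch of one way the join on custom_id could look, assuming both datasets fit in memory, that the "train" split is used, and that each response row carries a top-level custom_id field (none of this is in the commit; merge_query_and_response is a hypothetical helper):

# Hypothetical sketch, not part of this commit: join query rows to response rows
# by custom_id, dropping queries that have no matching response.
def merge_query_and_response(query_ds: Dataset, response_ds: Dataset) -> Dataset:
    responses_by_id = {row["custom_id"]: row for row in response_ds}

    merged_rows = []
    for query_row in query_ds:
        response_row = responses_by_id.get(query_row["custom_id"])
        if response_row is None:
            continue
        merged_rows.append({**query_row, "response": response_row})

    return Dataset.from_list(merged_rows)

# Inside build_batch_query_response_vision_dataset this would be used roughly as:
#     return merge_query_and_response(query_ds["train"], response_ds["train"])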
pdelfin/train/train.py (new file, 11 lines)
@@ -0,0 +1,11 @@
# Step 1, load the data
# Probably, we want to see just a folder with openai batch input jsonls, plus the batch output jsonls
# TODO: Figure out hyperparameters for image sizing

# Step 2. Load those prompts through and do a forward pass to calculate the loss

# Step 3. Add hugging face accelerate for training

# Step 4. Checkpointing code, both saving and reloading to restart

# Step 5. Move over from interactive session to gantry launch script
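
train.py is only an outline at this point. As a rough illustration of what Steps 2-4 might look like with Hugging Face accelerate (this is not the author's implementation; AutoModelForVision2Seq, the batch format, and the checkpoint path are all assumptions), a minimal loop could be:

# Rough sketch only, not part of this commit. Assumes the dataloader already yields
# model-ready tensors (e.g. input_ids, attention_mask, pixel_values, labels); building
# that collator from the merged query/response dataset is the open question in Step 1.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import AutoModelForVision2Seq

def train(model_name: str, train_dataset, num_epochs: int = 1, lr: float = 1e-5):
    accelerator = Accelerator()
    model = AutoModelForVision2Seq.from_pretrained(model_name)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    # Accelerate handles device placement and, if launched distributed, sharding
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(num_epochs):
        for batch in dataloader:
            # Step 2: forward pass; the model returns the loss when labels are supplied
            outputs = model(**batch)
            accelerator.backward(outputs.loss)
            optimizer.step()
            optimizer.zero_grad()

        # Step 4: save model, optimizer, and RNG state so training can be restarted
        accelerator.save_state(f"checkpoints/epoch_{epoch}")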
@@ -1,6 +1,6 @@
import os
import time

import html
import unittest
import multiprocessing
tests/test_dataloader.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import unittest

from pdelfin.train.dataloader import load_jsonl_from_s3, build_batch_query_response_vision_dataset
from pdelfin.train.dataloader import extract_openai_batch_query

class TestBatchQueryResponseDataset(unittest.TestCase):
    def testLoadS3(self):
        ds = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)

        print(f"Loaded {len(ds)} entries")
        print(ds)
        print(ds["train"])

    def testCombinedQueryResponse(self):
        ds = build_batch_query_response_vision_dataset(query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
                                                       response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json")

        print(ds)

    def testExtractBatch(self):
        query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
        query_data = query_data["train"]
        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

        print(query_data)
        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
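
For reference, a minimal example of the record shape extract_openai_batch_query expects (an OpenAI batch request whose user message carries a text item and a data-URL image) and what it returns; the values below are made up:

# Illustrative only: a made-up query record in the OpenAI batch request format.
query = {
    "custom_id": "doc-0001",
    "body": {
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "Transcribe this page."},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
            ]}
        ]
    },
}

extract_openai_batch_query(query)
# -> {'custom_id': 'doc-0001',
#     'input_prompt_text': 'Transcribe this page.',
#     'input_prompt_image_base64': 'iVBORw0KGgo...'}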