From d22b311340dfe206a5f521998017c93e91eae8ef Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 18 Sep 2024 21:42:09 +0000
Subject: [PATCH] Starting to write dataloader for visual lm data

---
 pdelfin/train/__init__.py   |  0
 pdelfin/train/dataloader.py | 99 +++++++++++++++++++++++++++++++++++++
 pdelfin/train/train.py      | 11 +++++
 tests/test_coherency.py     |  2 +-
 tests/test_dataloader.py    | 26 ++++++++++
 5 files changed, 137 insertions(+), 1 deletion(-)
 create mode 100644 pdelfin/train/__init__.py
 create mode 100644 pdelfin/train/dataloader.py
 create mode 100644 pdelfin/train/train.py
 create mode 100644 tests/test_dataloader.py

diff --git a/pdelfin/train/__init__.py b/pdelfin/train/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
new file mode 100644
index 0000000..fa575d3
--- /dev/null
+++ b/pdelfin/train/dataloader.py
@@ -0,0 +1,99 @@
+import json
+from datasets import load_dataset, Dataset, DatasetDict
+import boto3
+from typing import Dict, Any, Optional
+import logging
+import re
+import random
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def list_s3_files(s3_path: str):
+    """
+    Lists files in the specified S3 path that match the glob pattern.
+    """
+    s3 = boto3.client('s3')
+    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
+    if not match:
+        logger.error(f"Invalid S3 path: {s3_path}")
+        raise ValueError(f"Invalid S3 path: {s3_path}")
+
+    bucket, prefix_pattern = match.groups()
+    prefix = prefix_pattern.split('*')[0]  # Extract prefix before the wildcard
+    paginator = s3.get_paginator('list_objects_v2')
+    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+    files = []
+    pattern = re.compile(re.escape(prefix_pattern).replace(r'\*', '.*'))  # glob -> regex, escaping other metacharacters
+    for page in pages:
+        for obj in page.get('Contents', []):
+            key = obj['Key']
+            if pattern.fullmatch(key):
+                files.append(f"s3://{bucket}/{key}")
+    return files
+
+
+def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
+    """
+    Loads JSONL files from the specified S3 path into a Hugging Face DatasetDict.
+    """
+    all_s3_files = list_s3_files(s3_glob_path)
+
+    if first_n_files:
+        all_s3_files = all_s3_files[:first_n_files]
+
+    # Use datasets library to load JSON files from S3
+    dataset = load_dataset(
+        'json',
+        data_files=all_s3_files,
+    )
+
+    return dataset
+
+def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Extracts necessary fields from a query entry.
+    """
+    custom_id = query.get('custom_id', '')
+    body = query.get('body', {})
+    messages = body.get('messages', [])
+
+    input_prompt_text = ""
+    input_prompt_image_base64 = ""
+
+    for message in messages:
+        if message.get('role') != 'user':
+            continue  # We are only interested in user messages
+
+        contents = message.get('content', [])
+        for content_item in contents:
+            if content_item.get('type') == 'text':
+                input_prompt_text = content_item.get('text', "")
+            elif content_item.get('type') == 'image_url':
+                image_url = content_item.get('image_url', {}).get('url', "")
+                if image_url.startswith('data:image'):
+                    # Extract base64 part from data URL
+                    try:
+                        base64_data = image_url.split(',', 1)[1]
+                        input_prompt_image_base64 = base64_data
+                    except IndexError:
+                        input_prompt_image_base64 = ""
+
+    return {
+        'custom_id': custom_id,
+        'input_prompt_text': input_prompt_text,
+        'input_prompt_image_base64': input_prompt_image_base64
+    }
+
+def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
+    query_ds = load_jsonl_from_s3(query_glob_path)
+    response_ds = load_jsonl_from_s3(response_glob_path)
+
+    # Now merge them based on the custom_id field
+
+
+    return query_ds
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
new file mode 100644
index 0000000..fb64ddc
--- /dev/null
+++ b/pdelfin/train/train.py
@@ -0,0 +1,11 @@
+# Step 1, load the data
+# Probably, we want to see just a folder with openai batch input jsonls, plus the batch output jsonls
+# TODO: Figure out hyperparameters for image sizing
+
+# Step 2. Load those prompts through and do a forward pass to calculate the loss
+
+# Step 3. Add hugging face accelerate for training
+
+# Step 4. Checkpointing code, both saving and reloading to restart
+
+# Step 5. Move over from interactive session to gantry launch script
\ No newline at end of file
diff --git a/tests/test_coherency.py b/tests/test_coherency.py
index 8f7b33b..b388207 100644
--- a/tests/test_coherency.py
+++ b/tests/test_coherency.py
@@ -1,6 +1,6 @@
 import os
 import time
-
+import html
 import unittest
 import multiprocessing
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
new file mode 100644
index 0000000..7f78cf7
--- /dev/null
+++ b/tests/test_dataloader.py
@@ -0,0 +1,26 @@
+import unittest
+
+from pdelfin.train.dataloader import load_jsonl_from_s3, build_batch_query_response_vision_dataset
+from pdelfin.train.dataloader import extract_openai_batch_query
+
+class TestBatchQueryResponseDataset(unittest.TestCase):
+    def testLoadS3(self):
+        ds = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
+
+        print(f"Loaded {len(ds)} entries")
+        print(ds)
+        print(ds["train"])
+
+    def testCombinedQueryResponse(self):
+        ds = build_batch_query_response_vision_dataset(query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
+                                                       response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json")
+
+        print(ds)
+
+    def testExtractBatch(self):
+        query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
+        query_data = query_data["train"]
+        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
+
+        print(query_data)
+        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
\ No newline at end of file
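
Reviewer note (not part of the patch): the dataloader so far only extracts the query side of each batch. A counterpart for the batch output lines might look like the sketch below; the nested "response" -> "body" -> "choices" layout is an assumption about the OpenAI batch output JSONL, not something this commit defines, and the function name is hypothetical.

from typing import Any, Dict

def extract_openai_batch_response(example: Dict[str, Any]) -> Dict[str, Any]:
    # Assumed shape: {"custom_id": ..., "response": {"body": {"choices": [{"message": {"content": ...}}]}}}
    custom_id = example.get("custom_id", "")
    body = example.get("response", {}).get("body", {})
    choices = body.get("choices", [])
    # Take the first choice's message content as the target text, if present
    response_text = choices[0].get("message", {}).get("content", "") if choices else ""
    return {
        "custom_id": custom_id,
        "response_text": response_text,
    }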
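Along the same lines, build_batch_query_response_vision_dataset still leaves the custom_id merge as a TODO and returns the query dataset unchanged. A minimal sketch of one way the merge could be completed is below; it assumes both sides have already been mapped down to flat columns, and "response_text" is the hypothetical column from the sketch above rather than anything this commit defines.

from datasets import Dataset

def merge_query_response(query_ds: Dataset, response_ds: Dataset) -> Dataset:
    # Index responses by custom_id so each query can look up its answer directly
    response_by_id = {row["custom_id"]: row for row in response_ds}

    # Keep only queries that actually have a matching response
    merged = query_ds.filter(lambda row: row["custom_id"] in response_by_id)

    def attach_response(row):
        # Copy the matched response text onto the query row
        row["response_text"] = response_by_id[row["custom_id"]].get("response_text", "")
        return row

    return merged.map(attach_response)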