Mirror of https://github.com/allenai/olmocr.git
Dataloader capable of loading 38k rows reasonably fast
parent d22b311340
commit f4d18cb287
@@ -4,7 +4,8 @@ import boto3
 from typing import Dict, Any
 import logging
 import re
 import random
 import multiprocessing
+from functools import partial
 
 
 # Configure logging
@@ -16,22 +17,22 @@ def list_s3_files(s3_path: str):
     """
     Lists files in the specified S3 path that match the glob pattern.
     """
-    s3 = boto3.client('s3')
+    s3 = boto3.client("s3")
     match = re.match(r"s3://([^/]+)/(.+)", s3_path)
     if not match:
         logger.error(f"Invalid S3 path: {s3_path}")
         raise ValueError(f"Invalid S3 path: {s3_path}")
 
     bucket, prefix_pattern = match.groups()
-    prefix = prefix_pattern.split('*')[0]  # Extract prefix before the wildcard
-    paginator = s3.get_paginator('list_objects_v2')
+    prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
+    paginator = s3.get_paginator("list_objects_v2")
     pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
 
     files = []
-    pattern = re.compile(prefix_pattern.replace('*', '.*'))
+    pattern = re.compile(prefix_pattern.replace("*", ".*"))
     for page in pages:
-        for obj in page.get('Contents', []):
-            key = obj['Key']
+        for obj in page.get("Contents", []):
+            key = obj["Key"]
             if pattern.fullmatch(key):
                 files.append(f"s3://{bucket}/{key}")
     return files
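
As an illustration of the listing logic above (no S3 call involved): the glob is split into a ListObjectsV2 prefix and a regex that must match the full key. The bucket and keys below are made up.

import re

s3_path = "s3://my-bucket/openai_batch_data/*.jsonl"  # hypothetical path
bucket, prefix_pattern = re.match(r"s3://([^/]+)/(.+)", s3_path).groups()

prefix = prefix_pattern.split("*")[0]                    # "openai_batch_data/" bounds the listing
pattern = re.compile(prefix_pattern.replace("*", ".*"))  # note: "." is left unescaped, exactly as in the code above

keys = ["openai_batch_data/part-000.jsonl", "openai_batch_data/readme.txt"]
print(prefix, [k for k in keys if pattern.fullmatch(k)])
# openai_batch_data/ ['openai_batch_data/part-000.jsonl']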
@@ -48,7 +49,7 @@ def load_jsonl_from_s3(s3_glob_path: str, first_n_files: int=None) -> Dataset:
 
     # Use datasets library to load JSON files from S3
     dataset = load_dataset(
-        'json',
+        "json",
         data_files=all_s3_files,
     )
 
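
The load_dataset("json", ...) call behaves the same with local files; a minimal sketch with hypothetical filenames, which also shows why later code indexes the result with ["train"]:

from datasets import load_dataset

# Hypothetical local JSONL files standing in for the S3 objects.
dataset = load_dataset("json", data_files=["batch-000.jsonl", "batch-001.jsonl"])
print(dataset)              # DatasetDict with a single "train" split
print(dataset["train"][0])  # first row, parsed into a dict of columns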
@@ -56,29 +57,29 @@ def load_jsonl_from_s3(s3_glob_path: str, first_n_files: int=None) -> Dataset:
 
 def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     """
-    Extracts necessary fields from a query entry.
+    Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
     """
-    custom_id = query.get('custom_id', '')
-    body = query.get('body', {})
-    messages = body.get('messages', [])
+    custom_id = query.get("custom_id", "")
+    body = query.get("body", {})
+    messages = body.get("messages", [])
 
     input_prompt_text = ""
     input_prompt_image_base64 = ""
 
     for message in messages:
-        if message.get('role') != 'user':
+        if message.get("role") != "user":
             continue  # We are only interested in user messages
 
-        contents = message.get('content', [])
+        contents = message.get("content", [])
         for content_item in contents:
-            if content_item.get('type') == 'text':
-                input_prompt_text = content_item.get('text', "")
-            elif content_item.get('type') == 'image_url':
-                image_url = content_item.get('image_url', {}).get('url', "")
-                if image_url.startswith('data:image'):
+            if content_item.get("type") == "text":
+                input_prompt_text = content_item.get("text", "")
+            elif content_item.get("type") == "image_url":
+                image_url = content_item.get("image_url", {}).get("url", "")
+                if image_url.startswith("data:image"):
                     # Extract base64 part from data URL
                     try:
-                        base64_data = image_url.split(',', 1)[1]
+                        base64_data = image_url.split(",", 1)[1]
                         input_prompt_image_base64 = base64_data
                     except IndexError:
                         input_prompt_image_base64 = ""
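
A quick sanity check of the query extractor: a minimal record in the OpenAI batch-request shape it reads (custom_id plus body.messages), with made-up values; the import path matches the one used in the tests.

from pdelfin.train.dataloader import extract_openai_batch_query

query = {
    "custom_id": "doc-0001_page-1",  # hypothetical id
    "body": {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Transcribe this page."},
                    {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
                ],
            }
        ]
    },
}

row = extract_openai_batch_query(query)
print(row["custom_id"], row["input_prompt_text"])
print(row["input_prompt_image_base64"])  # "iVBORw0KGgo=", the part after the comma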
@@ -89,11 +90,67 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
         'input_prompt_image_base64': input_prompt_image_base64
     }
 
 
+def extract_openai_batch_response(example):
+    custom_id = example.get('custom_id', None)
+    response_body = example.get('response', {}).get('body', {})
+    choices = response_body.get('choices', [])
+    response = ''
+    finish_reason = ''
+    if choices:
+        first_choice = choices[0]
+        message = first_choice.get('message', {})
+        response = message.get('content', '')
+        finish_reason = first_choice.get('finish_reason', '')
+
+    return {
+        'custom_id': custom_id,
+        'response': response,
+        "finish_reason": finish_reason
+    }
+
+
+def merge_query_response(query_example, response_data: Dataset, response_map: dict[str, int]):
+    custom_id = query_example["custom_id"]
+
+    if custom_id not in response_map:
+        return {
+            "response": None,
+            "finish_reason": None,
+        }
+
+    response_row = response_data[response_map[custom_id]]
+
+    return {
+        "response": response_row["response"],
+        "finish_reason": response_row["finish_reason"]
+    }
+
 def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
-    query_ds = load_jsonl_from_s3(query_glob_path)
-    response_ds = load_jsonl_from_s3(response_glob_path)
+    logger.info("Loading query and response datasets")
+    query_data = load_jsonl_from_s3(query_glob_path)
+    response_data = load_jsonl_from_s3(response_glob_path)
 
-    # Now merge them based on the custom_id field
+    # Map the datasets down to the core fields that we're going to need to make them easier to process
+    logger.info("Mapping query data")
+    query_data = query_data["train"]
+    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
+
+    logger.info("Mapping response data")
+    response_data = response_data["train"]
+    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
+
+    # What we're going to do, is build an in-memory map for the response data from custom_id to row
+    # This will let us do quick lookups when we do a merge step, but it will not scale past a certain point
+    logger.info("Building custom_id to row map")
+    custom_id_to_response_row = {}
+    for row_id, entry in enumerate(response_data):
+        custom_id_to_response_row[entry["custom_id"]] = row_id
+
+    logger.info("Running merge map")
+    final_dataset = query_data.map(partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
+                                   num_proc=multiprocessing.cpu_count())
+    final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")
+
+    return final_dataset
 
-    return query_ds
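
Similarly, extract_openai_batch_response only needs response.body.choices[0]; a short sketch with a fabricated record:

from pdelfin.train.dataloader import extract_openai_batch_response

example = {
    "custom_id": "doc-0001_page-1",  # hypothetical id
    "response": {
        "body": {
            "choices": [
                {"message": {"content": "Page 1 text..."}, "finish_reason": "stop"}
            ]
        }
    },
}

print(extract_openai_batch_response(example))
# {'custom_id': 'doc-0001_page-1', 'response': 'Page 1 text...', 'finish_reason': 'stop'}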
@@ -1,7 +1,7 @@
 import unittest
 
 from pdelfin.train.dataloader import load_jsonl_from_s3, build_batch_query_response_vision_dataset
-from pdelfin.train.dataloader import extract_openai_batch_query
+from pdelfin.train.dataloader import extract_openai_batch_query, extract_openai_batch_response
 
 class TestBatchQueryResponseDataset(unittest.TestCase):
     def testLoadS3(self):
@@ -23,4 +23,13 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
 
         print(query_data)
-        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
+        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
+
+    def testExtractResponse(self):
+        response_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
+        response_data = response_data["train"]
+
+        response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
+
+        print(response_data)
+        print(response_data[0])
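
Putting it together, the new builder can be driven with a pair of S3 globs, roughly as below. The bucket paths are placeholders (they differ from the real buckets used in the tests), AWS credentials are required, and rows whose finish_reason is not "stop" are dropped during the merge.

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset

dataset = build_batch_query_response_vision_dataset(
    query_glob_path="s3://my-bucket/openai_batch_data/*.jsonl",    # placeholder
    response_glob_path="s3://my-bucket/openai_batch_done/*.json",  # placeholder
)
print(len(dataset), dataset.column_names)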