import json
import logging
import multiprocessing
import random
import re
from functools import partial
from logging import Logger
from typing import Any, Dict, Optional

import boto3
import datasets
from datasets import Dataset, Features, Value, load_dataset

from .core.config import DataConfig, SourceConfig

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def list_s3_files(s3_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    s3 = boto3.client("s3")
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")

    bucket, prefix_pattern = match.groups()
    prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    files = []
    # Convert the glob-style pattern into a regex so each object key can be matched against it
    pattern = re.compile(prefix_pattern.replace("*", ".*"))
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")

    return files


def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_s3_files = list_s3_files(s3_glob_path)

    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        "json",
        data_files=all_s3_files,
    )

    return dataset


def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts the necessary fields from a query entry passed to OpenAI's batch API for vision LMs.
    """
    custom_id = query.get("custom_id", "")
    body = query.get("body", {})
    messages = body.get("messages", [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get("role") != "user":
            continue  # We are only interested in user messages

        contents = message.get("content", [])
        for content_item in contents:
            if content_item.get("type") == "text":
                input_prompt_text = content_item.get("text", "")
            elif content_item.get("type") == "image_url":
                image_url = content_item.get("image_url", {}).get("url", "")
                if image_url.startswith("data:image"):
                    # Extract base64 part from data URL
                    try:
                        base64_data = image_url.split(",", 1)[1]
                        input_prompt_image_base64 = base64_data
                    except IndexError:
                        input_prompt_image_base64 = ""

    return {
        "custom_id": custom_id,
        "input_prompt_text": input_prompt_text,
        "input_prompt_image_base64": input_prompt_image_base64,
    }


def extract_openai_batch_response(example: Dict[str, Any]) -> Dict[str, Any]:
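    """
    Extracts the response text and finish reason from a response entry returned by OpenAI's batch API.
    """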
    custom_id = example.get("custom_id", None)
    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    return {"custom_id": custom_id, "response": response, "finish_reason": finish_reason}


def merge_query_response(query_example, response_data: Dataset, response_map: Dict[str, int]):
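    """
    Looks up the response row matching a query's custom_id and returns its response text and finish reason.
    """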
    custom_id = query_example["custom_id"]

    if custom_id not in response_map:
        return {
            "response": None,
            "finish_reason": None,
        }

    response_row = response_data[response_map[custom_id]]

    return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}


def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
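    """
    Joins OpenAI batch queries with their responses on custom_id and keeps only completed generations.
    """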
    logger.info("Loading query and response datasets")
    query_data = load_jsonl_from_s3(query_glob_path)
    response_data = load_jsonl_from_s3(response_glob_path)

    # Map the datasets down to the core fields that we need, to make them easier to process
    logger.info("Mapping query data")
    query_data = query_data["train"]
    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

    logger.info("Mapping response data")
    response_data = response_data["train"]
    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)

    # Build an in-memory map from custom_id to response row. This allows quick lookups during the
    # merge step, but it will not scale past a certain point.
    logger.info("Building custom_id to row map")
    custom_id_to_response_row = {}
    for row_id, entry in enumerate(response_data):
        custom_id_to_response_row[entry["custom_id"]] = row_id

    logger.info("Running merge map")
    final_dataset = query_data.map(
        partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
        num_proc=num_proc,
    )

    # Keep only rows whose generation finished normally
    final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")

    return final_dataset


def make_dataset(
    train_data_config: DataConfig,
    valid_data_config: Optional[DataConfig] = None,
    test_data_config: Optional[DataConfig] = None,
    num_proc: int = 32,
    logger: Optional[Logger] = None,
) -> datasets.DatasetDict:
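    """
    Builds a DatasetDict with a train split (and optional validation/test splits) from the configured S3 sources.
    """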
    logger = logger or logging.getLogger(__name__)
    random.seed(train_data_config.seed)

    dataset_splits: Dict[str, datasets.Dataset] = {}
    tmp_train_sets = []

    logger.info("Loading training data from %s sources", len(train_data_config.sources))
    for source in train_data_config.sources:
        tmp_train_sets.append(
            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
        )
    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
    logger.info(
        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
    )

    if valid_data_config:
        tmp_validation_sets = []
        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
        for source in valid_data_config.sources:
            tmp_validation_sets.append(
                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
            )
        dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
        logger.info(
            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
        )

    if test_data_config:
        tmp_test_sets = []
        logger.info("Loading test data from %s sources", len(test_data_config.sources))
        for source in test_data_config.sources:
            tmp_test_sets.append(
                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
            )
        dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
        logger.info(
            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
        )

    return datasets.DatasetDict(**dataset_splits)
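

# Example usage (a minimal sketch): the exact constructor fields of DataConfig and SourceConfig are
# defined in .core.config, so the keyword arguments and S3 paths below are illustrative assumptions,
# not the verified API.
#
#   train_config = DataConfig(
#       seed=42,
#       sources=[
#           SourceConfig(
#               query_glob_path="s3://my-bucket/batch_queries/*.jsonl",
#               response_glob_path="s3://my-bucket/batch_responses/*.jsonl",
#           )
#       ],
#   )
#   splits = make_dataset(train_config, num_proc=8)
#   print(splits["train"][0]["input_prompt_text"])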