Loading dataset from config now

Jake Poznanski 2024-09-23 09:40:24 -07:00
parent ab9458b913
commit ea3af0143c
5 changed files with 74 additions and 58 deletions


@@ -27,63 +27,21 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: fw-edu-all
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-all/*.json.gz
-      backend:
-        - openai
-      size: 100_000
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
       size: 100_000
 valid_data:
   sources:
-    - name: fw-edu-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-10k/valid/*.gz
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
       size: 1500
-    - name: dolma-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-10k/valid/*.gz
-      backend:
-        - openai
-      size: 1500
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
-      backend:
-        - openai
-      size: 3000
-      size: 100_000
 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
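
After this change, a data source is described by a pair of S3 globs (the batch-API inputs and the responses that came back) instead of raw paths. A minimal sketch of reading such a section with PyYAML, assuming train_data sits at the top level of the config file; this loader is illustrative, not the repo's actual config machinery:

import yaml  # assumption: PyYAML is available

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Each source now carries the query/response globs the dataloader consumes.
for source in cfg["train_data"]["sources"]:
    print(source["name"], source["query_glob_path"], source["response_glob_path"])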


@@ -76,7 +76,8 @@ class AwsConfig:
 class SourceConfig:
     name: str = field(help="The name of the source")
     size: int = field(help="Limit size for the source")
-    paths: List[str] = field(help="The paths to the data files")
+    query_glob_path: str = field(help="The S3 glob pointing to the inputs sent to OpenAI to generate the silver data")
+    response_glob_path: str = field(help="The S3 glob pointing to the batch API response JSONs sent back from OpenAI")
     backend: List[str] = field(help="The data generation backend to use to train the model")
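
For reference, a sketch of how one YAML source entry maps onto the new fields, using plain stdlib dataclasses rather than the repo's field(help=...) helper; the class name and the entry literal below are illustrative:

from dataclasses import dataclass
from typing import List

@dataclass
class SourceConfigSketch:
    # Mirrors SourceConfig after this commit: data now comes from paired
    # OpenAI batch query/response globs instead of raw `paths`.
    name: str
    size: int
    query_glob_path: str
    response_glob_path: str
    backend: List[str]

entry = {
    "name": "openai_batch_data_v2_mini",
    "size": 100_000,
    "query_glob_path": "s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
    "response_glob_path": "s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
    "backend": ["openai"],
}
source = SourceConfigSketch(**entry)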


@@ -2,12 +2,17 @@ import json
 import logging
 import multiprocessing
 import re
+import random
 from functools import partial
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from logging import Logger
 import boto3
+import datasets
 from datasets import Dataset, Features, Value, load_dataset
+from .core.config import DataConfig, SourceConfig
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -121,7 +126,7 @@ def merge_query_response(query_example, response_data: Dataset, response_map: dict):
     return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}

-def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
+def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
     logger.info("Loading query and response datasets")
     query_data = load_jsonl_from_s3(query_glob_path)
     response_data = load_jsonl_from_s3(response_glob_path)
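
This function pairs each batch query with its response via custom_id. A minimal pure-Python sketch of that join, with toy in-memory rows standing in for the S3-backed datasets:

# Illustrative: index responses by custom_id, then look up one row per query.
queries = [{"custom_id": "q1", "body": "..."}]
responses = [{"custom_id": "q1", "response": "...", "finish_reason": "stop"}]

custom_id_to_response_row = {row["custom_id"]: i for i, row in enumerate(responses)}

merged = []
for q in queries:
    row = responses[custom_id_to_response_row[q["custom_id"]]]
    merged.append({**q, "response": row["response"], "finish_reason": row["finish_reason"]})

# Mirrors the filter in the real code: keep only completed generations.
merged = [m for m in merged if m["finish_reason"] == "stop"]
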
@@ -145,8 +150,58 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int = 32) -> Dataset:
     logger.info("Running merge map")
     final_dataset = query_data.map(
         partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
-        num_proc=multiprocessing.cpu_count(),
+        num_proc=num_proc,
     )
     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")
     return final_dataset

+
+def make_dataset(
+    train_data_config: DataConfig,
+    valid_data_config: Optional[DataConfig] = None,
+    test_data_config: Optional[DataConfig] = None,
+    num_proc: int = 32,
+    logger: Optional[Logger] = None,
+):
+    logger = logger or logging.getLogger(__name__)
+    random.seed(train_data_config.seed)
+
+    dataset_splits: Dict[str, datasets.Dataset] = {}
+
+    tmp_train_sets = []
+    logger.info("Loading training data from %s sources", len(train_data_config.sources))
+    for source in train_data_config.sources:
+        tmp_train_sets.append(
+            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+        )
+    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    logger.info(
+        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
+    )
+
+    if valid_data_config:
+        tmp_validation_sets = []
+        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
+        for source in valid_data_config.sources:
+            tmp_validation_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+            )
+        dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
+        )
+
+    if test_data_config:
+        tmp_test_sets = []
+        logger.info("Loading test data from %s sources", len(test_data_config.sources))
+        for source in test_data_config.sources:
+            tmp_test_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path, num_proc=num_proc)
+            )
+        dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
+        )
+
+    return datasets.DatasetDict(**dataset_splits)
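
A hedged usage sketch of the new entry point; `config` stands for an already-loaded TrainConfig, and the num_proc value is an arbitrary example:

from pdelfin.train.dataloader import make_dataset

dataset = make_dataset(
    train_data_config=config.train_data,
    valid_data_config=config.valid_data,  # optional; adds a "validation" split
    num_proc=16,
)
print(dataset)  # datasets.DatasetDict with "train" (and "validation" if configured)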


@@ -59,7 +59,7 @@ from .utils import (
 )

-from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataloader import make_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
@@ -125,10 +125,12 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

-    train_ds = build_batch_query_response_vision_dataset(
-        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
-        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
-    )
+    dataset = make_dataset(
+        train_data_config=config.train_data,
+        valid_data_config=config.valid_data,
+        num_proc=config.num_proc,
+        logger=logger,
+    )

     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"