diff --git a/pdelfin/buildsilver/__init__.py b/pdelfin/buildsilver/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pdelfin/train/config/qwen2vl-2b.yaml b/pdelfin/train/config/qwen2vl-2b.yaml
index 9ef18a1..d77d48c 100644
--- a/pdelfin/train/config/qwen2vl-2b.yaml
+++ b/pdelfin/train/config/qwen2vl-2b.yaml
@@ -27,63 +27,21 @@ generate:
 train_data:
   seed: 1337
   sources:
-    - name: fw-edu-all
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-all/*.json.gz
-      backend:
-        - openai
-      size: 100_000
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 100_000
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
       size: 100_000
 
 valid_data:
   sources:
-    - name: fw-edu-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/fw-edu-10k/valid/*.gz
+    - name: openai_batch_data_v2_mini
+      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl
+      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json
       backend:
         - openai
-      size: 1500
-    - name: dolma-10k
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-10k/valid/*.gz
-      backend:
-        - openai
-      size: 1500
-    - name: dclm
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dclm/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v17
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v17/*.zstd
-      backend:
-        - openai
-      size: 1500
-    - name: dolma-v1-small
-      paths:
-        - s3://ai2-tylerm-experimental/experiments/rephrase/v1/dolma-dclm-300k/dolma-v1-small/*.zstd
-      backend:
-        - openai
-      size: 3000
+      size: 100_000
 
 # Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
 hparams:
diff --git a/pdelfin/train/core/config.py b/pdelfin/train/core/config.py
index bd4de46..b3fe925 100644
--- a/pdelfin/train/core/config.py
+++ b/pdelfin/train/core/config.py
@@ -76,7 +76,8 @@ class AwsConfig:
 class SourceConfig:
     name: str = field(help="The name of the source")
     size: int = field(help="Limit size for the source")
-    paths: List[str] = field(help="The paths to the data files")
+    query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
+    response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
     backend: List[str] = field(help="The data generation backend to use to train the model")
 
 
diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index d745c68..5352156 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -2,12 +2,17 @@ import json
 import logging
 import multiprocessing
 import re
+import random
+
 from functools import partial
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from logging import Logger
 
 import boto3
 from datasets import Dataset, Features, Value, load_dataset
 
+from .core.config import DataConfig, SourceConfig
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -121,7 +126,7 @@ def merge_query_response(query_example, response_data: Dataset, response_map: di
     return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}
 
 
-def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
+def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int=32) -> Dataset:
     logger.info("Loading query and response datasets")
     query_data = load_jsonl_from_s3(query_glob_path)
     response_data = load_jsonl_from_s3(response_glob_path)
@@ -145,8 +150,58 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     logger.info("Running merge map")
     final_dataset = query_data.map(
         partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
-        num_proc=multiprocessing.cpu_count(),
+        num_proc=num_proc
     )
     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop")
 
     return final_dataset
+
+
+def make_dataset(
+    train_data_config: DataConfig,
+    valid_data_config: Optional[DataConfig] = None,
+    test_data_config: Optional[DataConfig] = None,
+    num_proc: int = 32,
+    logger: Optional[Logger] = None,
+):
+    logger = logger or get_logger(__name__)
+    random.seed(train_data_config.seed)
+
+    dataset_splits: Dict[str, datasets.Dataset] = {}
+    tmp_train_sets = []
+
+    logger.info("Loading training data from %s sources", len(train_data_config.sources))
+    for source in train_data_config.sources:
+        tmp_train_sets.append(
+            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+        )
+    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    logger.info(
+        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
+    )
+
+    if valid_data_config:
+        tmp_validation_sets = []
+        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
+        for source in valid_data_config.sources:
+            tmp_validation_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            )
+        dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
+        )
+
+    if test_data_config:
+        tmp_test_sets = []
+        logger.info("Loading test data from %s sources", len(test_data_config.sources))
+        for source in test_data_config.sources:
+            tmp_test_sets.append(
+                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
+            )
+        dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+        logger.info(
+            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
+        )
+
+    return datasets.DatasetDict(**dataset_splits)
\ No newline at end of file
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 4e9650c..015133d 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -59,7 +59,7 @@ from .utils import (
 )
 
 
-from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
+from pdelfin.train.dataloader import make_dataset
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
 
 
@@ -125,10 +125,12 @@ def run_train(config: TrainConfig):
 
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
-    train_ds = build_batch_query_response_vision_dataset(
-        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
-        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
-    )
+    dataset = make_dataset(
+        train_data_config=config.train_data,
+        valid_data_config=config.valid_data,
+        num_proc=config.num_proc,
+        logger=logger,
+    )
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
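For reference, the per-source loading that make_dataset performs reduces to one call per (query_glob_path, response_glob_path) pair from the config; a minimal sketch using the globs configured above, assuming S3 credentials are available in the environment:

    from pdelfin.train.dataloader import build_batch_query_response_vision_dataset

    # Each SourceConfig contributes one merged query/response dataset; make_dataset
    # concatenates these per split, and rows with finish_reason != "stop" are dropped.
    source_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
    )
    print(len(source_ds))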