diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index 91e4540..c803191 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -140,10 +140,14 @@ def run_train(config: TrainConfig): # formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) # Convert to an iteratble dataset, so we can apply map and filter without doing a full calculation in advance - formatted_dataset = dataset.to_iterable_dataset(num_shards=64) - formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500) + train_ds = dataset["train"].to_iterable_dataset(num_shards=64) + validation_ds = dataset["validation"] - print(formatted_dataset) + train_ds = train_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500) + validation_ds = validation_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)) + + print(train_ds) + print(validation_ds) print("---------------") save_path = join_path("", config.save.path, run_name.run) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 71c1da1..cc596ed 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -85,3 +85,13 @@ class TestBatchQueryResponseDataset(unittest.TestCase): print(response_data) print(response_data[0]) + + def testIterableDataset(self): + dataset = build_batch_query_response_vision_dataset( + query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", + response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", + ) + processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") + + formatted_dataset = dataset.to_iterable_dataset(num_shards=64) + formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)