Proper use of iterable_dataset

This commit is contained in:
Jake Poznanski 2024-09-26 19:55:54 +00:00
parent 05fdb81da2
commit cf1aa0176e
2 changed files with 17 additions and 3 deletions

View File

@ -140,10 +140,14 @@ def run_train(config: TrainConfig):
# formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
# Convert to an iteratble dataset, so we can apply map and filter without doing a full calculation in advance
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
validation_ds = dataset["validation"]
print(formatted_dataset)
train_ds = train_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
validation_ds = validation_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor))
print(train_ds)
print(validation_ds)
print("---------------")
save_path = join_path("", config.save.path, run_name.run)

View File

@ -85,3 +85,13 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
print(response_data)
print(response_data[0])
def testIterableDataset(self):
dataset = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)