diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index 5352156..2eb49be 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
 from logging import Logger
 
 import boto3
-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 
 from .core.config import DataConfig, SourceConfig
 
@@ -167,7 +167,7 @@ def make_dataset(
     logger = logger or get_logger(__name__)
     random.seed(train_data_config.seed)
 
-    dataset_splits: Dict[str, datasets.Dataset] = {}
+    dataset_splits: Dict[str, Dataset] = {}
 
     tmp_train_sets = []
     logger.info("Loading training data from %s sources", len(train_data_config.sources))
@@ -175,7 +175,7 @@
         tmp_train_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
     logger.info(
         f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
     )
@@ -187,7 +187,7 @@
         tmp_validation_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+    dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
     logger.info(
         f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
     )
@@ -199,9 +199,9 @@
         tmp_test_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+    dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
     logger.info(
         f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
     )
 
-    return datasets.DatasetDict(**dataset_splits)
\ No newline at end of file
+    return DatasetDict(**dataset_splits)
\ No newline at end of file
diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py
index 015133d..7f4f543 100644
--- a/pdelfin/train/train.py
+++ b/pdelfin/train/train.py
@@ -149,8 +149,8 @@ def run_train(config: TrainConfig):
         model = get_peft_model(model=model, peft_config=peft_config)
         log_trainable_parameters(model=model, logger=logger)
 
-    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(train_ds)
+    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(formatted_dataset)
     print("---------------")
 
     save_path = join_path("", config.save.path, run_name.run)
@@ -202,8 +202,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_ds,
-        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["validation"], # pyright: ignore
         tokenizer=processor.tokenizer,
         #data_collator=collator,
         #callbacks=[checkpoint_callback],
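
Note: for reference, a minimal runnable sketch (not part of the diff) of the pattern the two files converge on, using toy data in place of the real query/response sources; `fake_prepare` is a hypothetical stand-in for `batch_prepare_data_for_qwen2_training`.

# Sketch only; assumes the Hugging Face `datasets` library.
from datasets import Dataset, DatasetDict, concatenate_datasets

# dataloader.py side: per-source datasets with matching features are
# concatenated into named splits and wrapped in a DatasetDict.
sources = [
    Dataset.from_dict({"text": ["a", "b"]}),
    Dataset.from_dict({"text": ["c"]}),
]
dataset = DatasetDict(
    train=concatenate_datasets(sources),
    validation=Dataset.from_dict({"text": ["d"]}),
)

# train.py side: DatasetDict.with_transform applies the transform lazily to
# every split and still returns a DatasetDict, so the result can be indexed
# by split name -- which is what train_dataset=formatted_dataset["train"]
# and eval_dataset=formatted_dataset["validation"] rely on.
def fake_prepare(batch):  # hypothetical stand-in for the real prepare fn
    return {"input_ids": [[ord(c) for c in t] for t in batch["text"]]}

formatted_dataset = dataset.with_transform(fake_prepare)
print(formatted_dataset["train"][0])       # {'input_ids': [97]}
print(formatted_dataset["validation"][0])  # {'input_ids': [100]}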