diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index 7f00948..9a22c39 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -241,53 +241,3 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
 
     return final_dataset
-
-
-def make_dataset(
-    train_data_config: DataConfig,
-    valid_data_config: Optional[DataConfig] = None,
-    test_data_config: Optional[DataConfig] = None,
-    num_proc: int = 32,
-    logger: Optional[Logger] = None,
-):
-    logger = logger or get_logger(__name__)
-    random.seed(train_data_config.seed)
-
-    dataset_splits: Dict[str, Dataset] = {}
-    tmp_train_sets = []
-
-    logger.info("Loading training data from %s sources", len(train_data_config.sources))
-    for source in train_data_config.sources:
-        tmp_train_sets.append(
-            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-        )
-    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
-    logger.info(
-        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
-    )
-
-    if valid_data_config:
-        tmp_validation_sets = []
-        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
-        for source in valid_data_config.sources:
-            tmp_validation_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
-        )
-
-    if test_data_config:
-        tmp_test_sets = []
-        logger.info("Loading test data from %s sources", len(test_data_config.sources))
-        for source in test_data_config.sources:
-            tmp_test_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
-        )
-
-    return DatasetDict(**dataset_splits)
\ No newline at end of file
diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py
index 50a8cdd..ad56591 100644
--- a/pdelfin/train/utils.py
+++ b/pdelfin/train/utils.py
@@ -44,7 +44,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
 
 def get_rawdataset_from_source(source) -> Dataset:
     if source.parquet_path is not None:
-        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))
+        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))["train"]
     else:
         return build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
 
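For context on the `utils.py` hunk: `datasets.load_dataset` returns a `DatasetDict` keyed by split (defaulting to `"train"`) when called with `data_files`, not a `Dataset`, so the old code violated the `-> Dataset` annotation on `get_rawdataset_from_source`. A minimal sketch of that behavior, with a hypothetical local parquet glob standing in for `list_dataset_files(source.parquet_path)`:

```python
# Sketch only; "data/*.parquet" is a placeholder path, not from the repo.
from datasets import load_dataset

# load_dataset with data_files yields a DatasetDict with a default "train" split.
ds_dict = load_dataset("parquet", data_files="data/*.parquet")
print(type(ds_dict).__name__)  # DatasetDict

# Indexing the default split gives the Dataset that downstream code expects.
ds = ds_dict["train"]
print(type(ds).__name__)       # Dataset
```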