loading fix for parquets again...

Jake Poznanski 2024-10-07 14:48:53 -07:00
parent fdcd77eadd
commit 3d36545fa5
2 changed files with 1 addition and 51 deletions


@@ -241,53 +241,3 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
     return final_dataset
-
-def make_dataset(
-    train_data_config: DataConfig,
-    valid_data_config: Optional[DataConfig] = None,
-    test_data_config: Optional[DataConfig] = None,
-    num_proc: int = 32,
-    logger: Optional[Logger] = None,
-):
-    logger = logger or get_logger(__name__)
-    random.seed(train_data_config.seed)
-
-    dataset_splits: Dict[str, Dataset] = {}
-
-    tmp_train_sets = []
-    logger.info("Loading training data from %s sources", len(train_data_config.sources))
-    for source in train_data_config.sources:
-        tmp_train_sets.append(
-            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-        )
-    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
-    logger.info(
-        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
-    )
-
-    if valid_data_config:
-        tmp_validation_sets = []
-        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
-        for source in valid_data_config.sources:
-            tmp_validation_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
-        )
-
-    if test_data_config:
-        tmp_test_sets = []
-        logger.info("Loading test data from %s sources", len(test_data_config.sources))
-        for source in test_data_config.sources:
-            tmp_test_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
-        )
-
-    return DatasetDict(**dataset_splits)
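
For reference, the removed make_dataset helper followed the usual Hugging Face datasets pattern: build one Dataset per configured source, concatenate them, and wrap the named splits in a DatasetDict. A minimal, self-contained sketch of that pattern, with hypothetical in-memory data standing in for the real per-source loaders:

from datasets import Dataset, DatasetDict, concatenate_datasets

# Stand-ins for the per-source datasets; the deleted helper built each one
# with build_batch_query_response_vision_dataset(...) instead.
parts = [Dataset.from_dict({"x": [1, 2]}), Dataset.from_dict({"x": [3]})]

# Merge the per-source datasets into one split, then wrap the splits.
splits = {"train": concatenate_datasets(parts)}
print(DatasetDict(**splits))  # DatasetDict with a single 3-row "train" split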


@@ -44,7 +44,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
 def get_rawdataset_from_source(source) -> Dataset:
     if source.parquet_path is not None:
-        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))
+        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))["train"]
     else:
         return build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
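
The fix itself: load_dataset called with data_files returns a DatasetDict keyed by split (defaulting to "train"), not a Dataset, so a caller that expects a Dataset needs the extra indexing. A quick sketch of the behavior, assuming a hypothetical local sample.parquet file:

from datasets import load_dataset

# load_dataset with data_files yields a DatasetDict with a default "train"
# split; "sample.parquet" here is a placeholder local file.
dsdict = load_dataset("parquet", data_files=["sample.parquet"])
print(type(dsdict).__name__)  # DatasetDict
ds = dsdict["train"]          # what the fix extracts: the actual Dataset
print(type(ds).__name__)      # Dataset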