Mirror of https://github.com/allenai/olmocr.git, synced 2025-08-19 06:12:23 +00:00

loading fix for parquets again...

commit 3d36545fa5 (parent fdcd77eadd)
@@ -241,53 +241,3 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
 
     return final_dataset
-
-
-def make_dataset(
-    train_data_config: DataConfig,
-    valid_data_config: Optional[DataConfig] = None,
-    test_data_config: Optional[DataConfig] = None,
-    num_proc: int = 32,
-    logger: Optional[Logger] = None,
-):
-    logger = logger or get_logger(__name__)
-    random.seed(train_data_config.seed)
-
-    dataset_splits: Dict[str, Dataset] = {}
-    tmp_train_sets = []
-
-    logger.info("Loading training data from %s sources", len(train_data_config.sources))
-    for source in train_data_config.sources:
-        tmp_train_sets.append(
-            build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-        )
-    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
-    logger.info(
-        f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
-    )
-
-    if valid_data_config:
-        tmp_validation_sets = []
-        logger.info("Loading validation data from %s sources", len(valid_data_config.sources))
-        for source in valid_data_config.sources:
-            tmp_validation_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
-        )
-
-    if test_data_config:
-        tmp_test_sets = []
-        logger.info("Loading test data from %s sources", len(test_data_config.sources))
-        for source in test_data_config.sources:
-            tmp_test_sets.append(
-                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
-            )
-        dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
-        logger.info(
-            f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
-        )
-
-    return DatasetDict(**dataset_splits)
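For context, the removed make_dataset built one Dataset per configured source, concatenated them per split, and wrapped the splits in a DatasetDict. A rough usage sketch follows; only seed, sources, query_glob_path, response_glob_path, and parquet_path are visible in this commit, so the config class shapes below are assumptions:

# Hypothetical config shapes inferred from the fields referenced in the diff.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class SourceConfig:
    query_glob_path: str
    response_glob_path: str
    parquet_path: Optional[str] = None

@dataclass
class DataConfig:
    seed: int
    sources: List[SourceConfig] = field(default_factory=list)

train_cfg = DataConfig(
    seed=42,
    sources=[
        SourceConfig(
            query_glob_path="s3://bucket/queries/*.jsonl",      # made-up paths
            response_glob_path="s3://bucket/responses/*.jsonl",
        )
    ],
)
# splits = make_dataset(train_cfg)  # would return a DatasetDict with a "train" split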
@@ -44,7 +44,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
 
 def get_rawdataset_from_source(source) -> Dataset:
     if source.parquet_path is not None:
-        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))
+        return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))["train"]
     else:
         return build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
 
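The one-line change above is the actual parquet-loading fix: called without a split argument, datasets.load_dataset returns a DatasetDict keyed by split name, and the generic "parquet" loader places all data_files under the "train" split, so indexing with ["train"] is what yields the Dataset promised by the -> Dataset annotation. A minimal sketch of that behavior (the sample file name and columns are made up; requires pandas and pyarrow):

import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset

# Write a tiny throwaway parquet file to load back.
pd.DataFrame({"query": ["q1", "q2"], "response": ["r1", "r2"]}).to_parquet("sample.parquet")

loaded = load_dataset("parquet", data_files=["sample.parquet"])
assert isinstance(loaded, DatasetDict)   # not a Dataset yet: a dict of splits

train_ds = loaded["train"]               # the generic loader files everything under "train"
assert isinstance(train_ds, Dataset)
print(len(loaded), len(train_ds))        # 1 split vs. 2 rows

Without the index, callers expecting a Dataset, such as the concatenate_datasets calls in make_dataset above, would receive a DatasetDict and fail.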