mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-25 16:30:28 +00:00
DatasetDict fix
This commit is contained in:
parent
decfd7fbc1
commit
e53f782b0f
@ -15,6 +15,7 @@ from tqdm import tqdm
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed
|
||||
from datasets import DatasetDict
|
||||
from datasets.utils import disable_progress_bars
|
||||
from datasets.utils.logging import set_verbosity
|
||||
from peft import LoraConfig, get_peft_model # pyright: ignore
|
||||
@ -137,8 +138,8 @@ def run_train(config: TrainConfig):
|
||||
model = get_peft_model(model=model, peft_config=peft_config)
|
||||
log_trainable_parameters(model=model, logger=logger)
|
||||
|
||||
filtered_dataset = {split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset}
|
||||
|
||||
# Do final filtering, and prep for running model forward()
|
||||
filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
|
||||
formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
print(formatted_dataset)
|
||||
print("---------------")
|
||||
|
Loading…
x
Reference in New Issue
Block a user