mirror of https://github.com/allenai/olmocr.git
Typo

commit 4557a5b296
parent e973de7ba9
@@ -116,9 +116,11 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
     accelerator = accelerate.Accelerator()
 
+    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
+
     # Build and download the dataset on process 0
     if accelerator.is_main_process:
-        make_dataset(config)
+        make_dataset(config, processor)
 
     accelerator.wait_for_everyone()
@@ -128,8 +130,7 @@ def run_train(config: TrainConfig):
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
     )
-    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
 
     if config.lora is not None:
         peft_config = LoraConfig(
             r=config.lora.rank,
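Taken together, the two hunks above move the AutoProcessor construction ahead of the dataset build so that make_dataset can receive it, and drop the now-redundant second construction after model loading. A minimal sketch of the resulting flow in run_train (only the lines shown in the diff are confirmed; the model class name and elided surroundings are assumptions):

    # Sketch of run_train after this commit; Qwen2VLForConditionalGeneration
    # is an assumed model class, not shown in the diff.
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    # Build and download the dataset on process 0 only, then sync all ranks.
    if accelerator.is_main_process:
        make_dataset(config, processor)
    accelerator.wait_for_everyone()

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        config.model.name_or_path, torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
    )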
@@ -14,6 +14,7 @@ from functools import partial
 
 import torch
 import torch.nn.functional as F
+from transformers import AutoProcessor
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
 from datasets import Dataset, concatenate_datasets, DatasetDict
@@ -41,7 +42,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
         return torch.float8_e4m3fn
     return torch.float32
 
-def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
+def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)
 
     # Training sets get all concatenated and shuffled
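With the new signature, callers construct the processor once and pass it in, presumably so dataset preparation can tokenize and format samples with the model's own processor. A hypothetical standalone usage under that assumption (the config loading and unpacked split names are illustrative, not from the diff):

    from transformers import AutoProcessor

    # Hypothetical caller: build the processor once and hand it to
    # make_dataset, which returns two Dataset splits per its annotation.
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
    train_ds, eval_ds = make_dataset(config, processor)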