Mirror of https://github.com/allenai/olmocr.git, synced 2025-09-16 12:08:13 +00:00
Typo

commit 4557a5b296 (parent e973de7ba9)
@@ -116,9 +116,11 @@ def run_train(config: TrainConfig):
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

     accelerator = accelerate.Accelerator()

+    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
+
     # Build and download the dataset on process 0
     if accelerator.is_main_process:
-        make_dataset(config)
+        make_dataset(config, processor)

     accelerator.wait_for_everyone()
@@ -128,7 +130,6 @@ def run_train(config: TrainConfig):
         config.model.name_or_path, torch_dtype=torch.bfloat16,
         _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
     )
-    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

     if config.lora is not None:
         peft_config = LoraConfig(
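Net effect of the two hunks above: the processor is now constructed before the main-process-only dataset build, instead of after the model is loaded further down, because make_dataset now takes it as an argument. A minimal sketch of the resulting control flow, assuming the structure shown in the context lines (setup_environment and make_dataset are the repository's own helpers; their bodies and the rest of run_train are elided here):

import accelerate
from transformers import AutoProcessor

def run_train(config):
    accelerator = accelerate.Accelerator()

    # The processor must exist before the rank-0 dataset build,
    # since make_dataset(config, processor) now consumes it.
    processor = AutoProcessor.from_pretrained(config.model.name_or_path)

    # Only the main process builds and downloads the dataset...
    if accelerator.is_main_process:
        make_dataset(config, processor)

    # ...while every other rank waits at this barrier until it is done.
    accelerator.wait_for_everyone()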
@@ -14,6 +14,7 @@ from functools import partial

 import torch
 import torch.nn.functional as F
+from transformers import AutoProcessor
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
 from datasets import Dataset, concatenate_datasets, DatasetDict
@@ -41,7 +42,7 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
         return torch.float8_e4m3fn
     return torch.float32

-def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
+def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

     # Training sets get all concatenated and shuffled
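The hunk does not show how make_dataset uses the new processor argument, so the following is only an assumption-laden sketch of why a dataset builder would want one: applying tokenization at build time. Dataset and functools.partial mirror the module's existing imports; _tokenize_example, make_dataset_sketch, and the toy data are hypothetical stand-ins, not the repository's code:

from functools import partial

from datasets import Dataset

def _tokenize_example(example: dict, processor) -> dict:
    # Hypothetical preprocessing step: the real pipeline applies its own
    # prompt and image handling; this only shows the processor being
    # consumed while the dataset is constructed.
    ids = processor.tokenizer(example["text"])["input_ids"]
    return {"input_ids": ids}

def make_dataset_sketch(config, processor) -> Dataset:
    # Toy two-example dataset standing in for the downloaded training data.
    ds = Dataset.from_dict({"text": ["page one", "page two"]})
    # partial() binds the processor into the per-example map function.
    return ds.map(partial(_tokenize_example, processor=processor))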