mirror of https://github.com/allenai/olmocr.git
synced 2025-12-03 10:42:02 +00:00
Prepping to train
This commit is contained in:
parent 9d647b13b8
commit cbc667ce78
@@ -46,7 +46,7 @@ hparams:
   batch_size: 1
   eval_batch_size: 1
   gradient_accumulation_steps: 4
-  gradient_checkpointing: false
+  gradient_checkpointing: true
   clip_grad_norm: 1.0
   learning_rate: 1e-4
   max_steps: 10000
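The only hyperparameter change here flips gradient checkpointing on, trading extra forward-pass recomputation for lower activation memory. As a hedged sketch of how that flag is usually wired up with Hugging Face transformers (the checkpoint name and argument values are illustrative, mirroring the config above rather than taken from this repo):

from transformers import AutoModelForCausalLM, TrainingArguments

# Assumed checkpoint, for illustration only.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
model.gradient_checkpointing_enable()  # recompute activations in the backward pass to save memory

# The HF Trainer can also enable it through TrainingArguments:
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,   # batch_size: 1
    gradient_accumulation_steps=4,   # gradient_accumulation_steps: 4
    gradient_checkpointing=True,     # the flag flipped in this hunk
    max_grad_norm=1.0,               # clip_grad_norm: 1.0
    learning_rate=1e-4,
    max_steps=10000,
)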
@@ -1,15 +1,16 @@
 from transformers import (
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )

 from pdelfin.train.core.cli import make_cli
 from pdelfin.train.core.config import TrainConfig

+from tqdm import tqdm
 from .utils import (
     make_dataset
 )


+from torch.utils.data import DataLoader

 def main():
     train_config = make_cli(TrainConfig) # pyright: ignore
@@ -19,14 +20,29 @@ def main():

     print("Training dataset........")
     print(train_dataset)
     print(train_dataset[0])
     print("\n\n")

     print("Validation dataset........")
     print(valid_dataset)
     print(valid_dataset[list(valid_dataset.keys())[0]][0])
     print("\n\n")

     print("Datasets loaded into hugging face cache directory")

+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer, # use the processor's tokenizer
+        max_length=4096,
+        padding=False,
+    )
+
+    train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
+    max_seen_len = 0
+    for entry in tqdm(train_dataloader):
+        num_input_tokens = entry["input_ids"].shape[1]
+        max_seen_len = max(max_seen_len, num_input_tokens)
+
+    print(max_seen_len)

 if __name__ == "__main__":
     main()
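The new tail of main() builds a DataCollatorForSeq2Seq and walks the training DataLoader once, recording the longest tokenized example (max_seen_len); with batch_size=1 and padding=False nothing is actually padded, the loop only reads input_ids.shape[1]. For context, a small self-contained sketch of what this collator does once a batch holds more than one variable-length example (toy features and an assumed tokenizer, not the project's processor):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer, illustration only
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding="longest", label_pad_token_id=-100)

features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 5]) -- padded to the longest example
print(batch["labels"][0])        # trailing positions are -100, so the loss ignores them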
@@ -26,7 +26,8 @@ from transformers import (
     TrainerCallback,
     TrainingArguments,
     Qwen2VLForConditionalGeneration,
-    AutoProcessor
+    AutoProcessor,
+    DataCollatorForSeq2Seq
 )
 from transformers.integrations import WandbCallback
 from transformers.trainer_callback import TrainerControl, TrainerState
@@ -168,12 +169,16 @@ def run_train(config: TrainConfig):
         label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
         max_grad_norm=config.hparams.clip_grad_norm,
         remove_unused_columns=False,
-        eval_on_start=True,
+        #eval_on_start=True,
         metric_for_best_model=config.valid_data.metric_for_best_model,
     )

     # Set the collator
     collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer=processor.tokenizer, # use the processor's tokenizer
+        max_length=config.generate.max_length,
+        padding=False,
+    )

     checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)

     # Initialize Trainer
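packing_collator is project code, but the pad_multiple_of knob it is configured with is a generic trick: pad each batch up to the next multiple of some value so tensor shapes stay friendly to fixed-shape kernels. A rough, hypothetical sketch of that behavior (not the project's implementation):

import torch

def pad_to_multiple(features, pad_token_id, pad_multiple_of=8, label_pad_id=-100):
    # Hypothetical collator: pad every sequence up to the next multiple of pad_multiple_of.
    longest = max(len(f["input_ids"]) for f in features)
    target = ((longest + pad_multiple_of - 1) // pad_multiple_of) * pad_multiple_of

    input_ids = torch.full((len(features), target), pad_token_id, dtype=torch.long)
    labels = torch.full((len(features), target), label_pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(features), target), dtype=torch.long)
    for i, f in enumerate(features):
        n = len(f["input_ids"])
        input_ids[i, :n] = torch.tensor(f["input_ids"], dtype=torch.long)
        labels[i, :n] = torch.tensor(f["labels"], dtype=torch.long)
        attention_mask[i, :n] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

The DataCollatorForSeq2Seq added alongside it exposes the same idea through its pad_to_multiple_of argument, though here it is built with padding=False since batches are size 1.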
@@ -183,16 +188,10 @@ def run_train(config: TrainConfig):
         train_dataset=train_dataset,
         eval_dataset=valid_dataset,
         tokenizer=processor.tokenizer,
-        #Collator is not needed as we are doing batch size 1 for now...
-        #data_collator=collator,
+        data_collator=data_collator,
         callbacks=[checkpoint_callback],
     )

-    # Could not get this to work
-    # if get_rank() == 0:
-    #     # this is a hack to add script and peft config to wandb config
-    #     update_wandb_config(config, trainer, model)
-
     # Train the model
     trainer.train() # pyright: ignore

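The old comments argued a collator was unnecessary at batch size 1; the switch to data_collator=data_collator keeps that batch size but still routes every example through DataCollatorForSeq2Seq, which handles tensor conversion and label padding uniformly so the same pipeline keeps working if the batch size is raised later. A hedged illustration with toy features and a placeholder tokenizer (not the project's processor):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer, illustration only
tok.pad_token = tok.eos_token
collator = DataCollatorForSeq2Seq(tokenizer=tok, padding=False)

single = [{"input_ids": [10, 11, 12], "labels": [10, 11, 12]}]
batch = collator(single)
print(type(batch["input_ids"]), batch["input_ids"].shape)  # <class 'torch.Tensor'> torch.Size([1, 3])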
@@ -48,21 +48,37 @@ def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) ->
 def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
     random.seed(config.train_data.seed)

     # Training sets get all concatenated and shuffled
+    # Retrieve the two target lengths from the first source for comparison
+    first_source = config.train_data.sources[0]
+    target_longest_image_dim = first_source.target_longest_image_dim
+    target_anchor_text_len = first_source.target_anchor_text_len
+
+    # Verify that all sources have the same target lengths
+    for source in config.train_data.sources:
+        if source.target_longest_image_dim != target_longest_image_dim:
+            raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
+        if source.target_anchor_text_len != target_anchor_text_len:
+            raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
+
+    # Concatenate datasets first, unfortunately you can't apply the transform before concatenation due to the library
     train_dataset = concatenate_datasets(
         [
-            get_rawdataset_from_source(config.train_data, source).with_transform(
-                partial(
-                    batch_prepare_data_for_qwen2_training,
-                    processor=processor,
-                    target_longest_image_dim=source.target_longest_image_dim,
-                    target_anchor_text_len=source.target_anchor_text_len,
-                )
-            )
+            get_rawdataset_from_source(config.train_data, source)
             for source in config.train_data.sources
         ]
     )
+
+    # Apply the transform to the concatenated dataset
+    train_dataset = train_dataset.with_transform(
+        partial(
+            batch_prepare_data_for_qwen2_training,
+            processor=processor,
+            target_longest_image_dim=target_longest_image_dim,
+            target_anchor_text_len=target_anchor_text_len,
+        )
+    )

     # Validation sets get put into a datasetdict so each can report a loss separately
     valid_dataset = DatasetDict(
         **{
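The refactor concatenates the raw per-source datasets first and only then attaches a single with_transform, which is why every source must agree on target_longest_image_dim and target_anchor_text_len. A toy sketch of the concatenate-then-transform pattern in the datasets library (illustrative column and transform, not the project's preprocessing):

from datasets import Dataset, concatenate_datasets

ds_a = Dataset.from_dict({"text": ["a", "b"]})
ds_b = Dataset.from_dict({"text": ["c"]})

def upper_batch(batch):
    # with_transform runs lazily at access time on a batch dict,
    # so one transform covers every concatenated source.
    return {"text": [t.upper() for t in batch["text"]]}

merged = concatenate_datasets([ds_a, ds_b]).with_transform(upper_batch)
print(merged[0])  # {'text': 'A'}
print(merged[2])  # {'text': 'C'}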
@@ -24,6 +24,7 @@ dependencies = [
     "pypdf",
     "pymupdf",
     "pypdfium2",
+    "cryptography",
     "lingua-language-detector",
     "Pillow",
     "ftfy",