mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-24 21:57:13 +00:00
Starting batch inference script to measure performance, train script using proper model from config now
This commit is contained in:
parent
b0777dcb87
commit
45f691c718
101
pdelfin/train/batch_inference.py
Normal file
101
pdelfin/train/batch_inference.py
Normal file
@ -0,0 +1,101 @@
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from functools import partial
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed
|
||||
from datasets.utils import disable_progress_bars
|
||||
from datasets.utils.logging import set_verbosity
|
||||
from peft import LoraConfig, get_peft_model # pyright: ignore
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
AutoProcessor
|
||||
)
|
||||
from transformers.integrations import WandbCallback
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import wandb
|
||||
|
||||
from pdelfin.train.core.cli import make_cli, save_config, to_native_types
|
||||
from pdelfin.train.core.config import TrainConfig
|
||||
from pdelfin.train.core.loggers import get_logger
|
||||
from pdelfin.train.core.paths import copy_dir, join_path
|
||||
from pdelfin.train.core.state import BeakerState
|
||||
|
||||
from .utils import (
|
||||
RunName,
|
||||
get_local_dir,
|
||||
log_trainable_parameters,
|
||||
packing_collator,
|
||||
setup_environment,
|
||||
)
|
||||
|
||||
|
||||
from pdelfin.train.dataloader import make_dataset
|
||||
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
|
||||
|
||||
|
||||
def run_train(model_name: str, dataset_path: str):
|
||||
if get_rank() == 0:
|
||||
logger_level = logging.INFO
|
||||
else:
|
||||
logger_level = logging.WARN
|
||||
disable_progress_bars()
|
||||
|
||||
logger = get_logger(__name__, level=logger_level)
|
||||
set_verbosity(logger_level)
|
||||
|
||||
dataset = make_dataset(
|
||||
train_data_config=config.train_data,
|
||||
valid_data_config=config.valid_data,
|
||||
num_proc=config.num_proc,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch.bfloat16, device_map="auto",
|
||||
_attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
|
||||
formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
|
||||
print(formatted_dataset)
|
||||
print("---------------")
|
||||
|
||||
|
||||
with TemporaryDirectory() as output_dir:
|
||||
|
||||
|
||||
|
||||
# Uncomment to test speed of data loader
|
||||
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
|
||||
# for entry in tqdm(train_dataloader):
|
||||
# print("Step!")
|
||||
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
|
||||
|
||||
|
||||
def main():
|
||||
run_inference(model_name="Qwen/Qwen2-VL-2B-Instruct",
|
||||
dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -120,10 +120,10 @@ def run_train(config: TrainConfig):
|
||||
)
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto",
|
||||
config.model.name_or_path, torch_dtype=torch.bfloat16, device_map="auto",
|
||||
_attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
processor = AutoProcessor.from_pretrained(config.model.name_or_path)
|
||||
|
||||
if config.lora is not None:
|
||||
peft_config = LoraConfig(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user