Mirror of https://github.com/allenai/olmocr.git

Hopefully fixing dataloader for now
Parent: 5d35461dd2
Commit: ebd40f9084

@@ -47,6 +47,7 @@ from .utils import (
    log_trainable_parameters,
    packing_collator,
    setup_environment,
    make_dataset
)

@@ -113,6 +114,15 @@ def run_train(config: TrainConfig):
    run_name = RunName.get(config)

    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
    accelerator = accelerate.Accelerator()

    # Build and download the dataset on process 0
    if accelerator.is_main_process:
        make_dataset(config)

    accelerator.wait_for_everyone()

    train_dataset, valid_dataset = make_dataset(config)

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        config.model.name_or_path, torch_dtype=torch.bfloat16,

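The is_main_process / wait_for_everyone ordering above lets rank 0 build and download the dataset once while the other ranks wait, after which every rank calls make_dataset again and presumably hits the datasets cache. A minimal sketch of the same ordering using Accelerate's main_process_first() context manager (an alternative pattern, not what this commit does; make_dataset and config are the names from the diff):

    # Sketch only: equivalent "rank 0 goes first" ordering via accelerate,
    # assuming make_dataset caches its work to disk so later calls are cheap.
    import accelerate

    accelerator = accelerate.Accelerator()

    with accelerator.main_process_first():
        # The main process enters this block first and populates the cache;
        # the other ranks wait at the context manager, then run the same call.
        train_dataset, valid_dataset = make_dataset(config)
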
@@ -132,30 +142,6 @@ def run_train(config: TrainConfig):
    model = get_peft_model(model=model, peft_config=peft_config)
    log_trainable_parameters(model=model, logger=logger)

    random.seed(config.train_data.seed)

    # Training sets get all concatenated and shuffled
    train_dataset = (
        concatenate_datasets(
            [
                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
                for source in config.train_data.sources
            ]
        )
        .filter(partial(filter_by_max_seq_len, processor=processor))
        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    )

    # Validation sets get put into a datasetdict so each can report a loss separately
    valid_dataset = DatasetDict(
        **{
            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
            .filter(partial(filter_by_max_seq_len, processor=processor))
            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
            for source in config.valid_data.sources
        }
    )

    save_path = join_path("", config.save.path, run_name.run)

    save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore

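The DatasetDict of validation sets (built above, and moved into make_dataset below) is what allows each source to report its own loss. An illustrative sketch of how such a dict is typically consumed, assuming the standard Hugging Face Trainer, which accepts a dict for eval_dataset and prefixes each metric with its key; the Trainer wiring here is an assumption, not taken from this diff:

    # Illustrative only: per-source validation losses via a dict of eval datasets.
    from transformers import Trainer, TrainingArguments

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir=save_path, per_device_eval_batch_size=1),
        train_dataset=train_dataset,
        eval_dataset=dict(valid_dataset),  # {"sourceA": ds_a, "sourceB": ds_b, ...}
    )
    metrics = trainer.evaluate()  # reports eval_sourceA_loss, eval_sourceB_loss, ...
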
@@ -1,6 +1,7 @@
import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime

@@ -13,6 +14,7 @@ import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, concatenate_datasets, DatasetDict

from .core.cli import to_native_types
from .core.config import AwsConfig, TrainConfig, WandbConfig

@@ -23,6 +25,9 @@ from .core.state import BeakerState

T = TypeVar("T")

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
    pt = PrecisionType(accelerator.mixed_precision)

@@ -34,6 +39,33 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
        return torch.float8_e4m3fn
    return torch.float32

def make_dataset(config: TrainConfig) -> tuple[Dataset, Dataset]:
    random.seed(config.train_data.seed)

    # Training sets get all concatenated and shuffled
    train_dataset = (
        concatenate_datasets(
            [
                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
                for source in config.train_data.sources
            ]
        )
        .filter(partial(filter_by_max_seq_len, processor=processor))
        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    )

    # Validation sets get put into a datasetdict so each can report a loss separately
    valid_dataset = DatasetDict(
        **{
            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
            .filter(partial(filter_by_max_seq_len, processor=processor))
            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
            for source in config.valid_data.sources
        }
    )

    return train_dataset, valid_dataset


def setup_environment(
    aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str

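Note that make_dataset filters with a processor that is not defined anywhere in this hunk, so it presumably has to come from elsewhere in the module or the caller. A hypothetical variant (not part of this commit) would pass it in explicitly; AutoProcessor and config.model.name_or_path appear elsewhere in the diff, the two-argument signature is an assumption:

    # Hypothetical sketch: thread the processor through as an argument instead of
    # relying on a module- or caller-level `processor` name.
    from transformers import AutoProcessor

    def make_dataset(config: TrainConfig, processor) -> tuple[Dataset, DatasetDict]:
        ...  # same body as above, with every partial(...) using the `processor` parameter

    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
    train_dataset, valid_dataset = make_dataset(config, processor)
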
@@ -37,6 +37,21 @@ class TestBatchQueryResponseDataset(unittest.TestCase):

        print(ds[0])

    def testLocalDS(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
            response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
        )

        print(ds)

        ds.to_parquet("/root/trainds_parquet/bigds.parquet")

        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
        from pdelfin.train.dataprep import filter_by_max_seq_len
        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))

        print(ds[0])

    def testPlotSequenceLengthHistogram(self):
        import plotly.express as px

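The last test is truncated here, so the body below is only a guess at the general shape of a sequence-length histogram, not the repository's code. It assumes the Qwen2-VL processor's tokenizer, plotly express, and a placeholder list of prompt strings:

    # Sketch only: histogram of tokenized prompt lengths with plotly express.
    import plotly.express as px
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    texts = ["example prompt one", "example prompt two"]  # placeholder, not real data
    lengths = [len(processor.tokenizer(t)["input_ids"]) for t in texts]

    fig = px.histogram(x=lengths, nbins=50, labels={"x": "sequence length (tokens)"})
    fig.write_html("seq_len_histogram.html")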