diff --git a/olmocr/train/prepare_checkpoint.py b/olmocr/train/prepare_checkpoint.py
index 7257a1e..b1927c1 100755
--- a/olmocr/train/prepare_checkpoint.py
+++ b/olmocr/train/prepare_checkpoint.py
@@ -340,29 +340,55 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
         if file_path.endswith(".safetensors"):
             sum_state = load_file(file_path, device="cpu")
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = load_file(other_path, device="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state
+
             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             save_file(sum_state, souped_path)
         elif file_path.endswith(".bin"):
             sum_state = torch.load(file_path, map_location="cpu")
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = torch.load(other_path, map_location="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state
+
             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             torch.save(sum_state, souped_path)
         else:
             print(f"Skipping unknown weight file: {rel_path}")
diff --git a/scripts/train/grpotrainer-beaker-multi-gpu.sh b/scripts/train/grpotrainer-beaker-multi-gpu.sh
index 0cdc18c..fec07ce 100755
--- a/scripts/train/grpotrainer-beaker-multi-gpu.sh
+++ b/scripts/train/grpotrainer-beaker-multi-gpu.sh
@@ -163,7 +163,7 @@ for i in range(len(modified_args)):
 setup_commands = [
     # Install dependencies
     "pip install .[train]",
-    "pip install trl wandb",
+    "pip install trl==0.23.0 wandb",
     "pip install transformers==4.55.2",  # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
     "pip install vllm==v0.10.1.1",
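
For reference, a minimal standalone sketch of the fp32-accumulation averaging ("checkpoint souping") that the prepare_checkpoint.py hunk implements for .safetensors shards. The helper name, docstring, and example paths below are illustrative, not part of the patch.

import torch
from safetensors.torch import load_file, save_file


def average_safetensors_shards(paths: list[str], out_path: str) -> None:
    """Average identically-keyed .safetensors files, accumulating in fp32."""
    sum_state = load_file(paths[0], device="cpu")
    # Remember each tensor's original dtype so the averaged result can be cast back.
    original_dtypes = {k: v.dtype for k, v in sum_state.items()}
    # Accumulate in at least fp32 so fp16/bf16 checkpoints do not lose precision.
    for k in sum_state:
        if sum_state[k].dtype in (torch.float16, torch.bfloat16):
            sum_state[k] = sum_state[k].to(torch.float32)
    for other_path in paths[1:]:
        other_state = load_file(other_path, device="cpu")
        if set(sum_state.keys()) != set(other_state.keys()):
            raise ValueError(f"Key mismatch in {other_path}")
        for k in sum_state:
            # Match the accumulator dtype before adding.
            sum_state[k] += other_state[k].to(sum_state[k].dtype)
        del other_state
    n = len(paths)
    for k in sum_state:
        sum_state[k] /= n
        # Cast back to the original storage dtype (e.g. bf16) before saving.
        sum_state[k] = sum_state[k].to(original_dtypes[k])
    save_file(sum_state, out_path)


# Example usage (paths are hypothetical):
# average_safetensors_shards(
#     ["ckpt_a/model.safetensors", "ckpt_b/model.safetensors", "ckpt_c/model.safetensors"],
#     "souped/model.safetensors",
# )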