Souping in fp32

This commit is contained in:
Jake Poznanski 2025-09-26 20:03:29 +00:00
parent 3d6e6a6a01
commit bb06829840
2 changed files with 27 additions and 1 deletions

View File

@@ -340,29 +340,55 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
save_file(sum_state, souped_path)
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
# Store original dtypes for each tensor
original_dtypes = {k: v.dtype for k, v in sum_state.items()}
# Upconvert to at least fp32 for accurate averaging
for k in sum_state:
if sum_state[k].dtype in (torch.float16, torch.bfloat16):
sum_state[k] = sum_state[k].to(torch.float32)
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
# Upconvert other state to match sum_state dtype
if other_state[k].dtype in (torch.float16, torch.bfloat16):
other_state[k] = other_state[k].to(torch.float32)
sum_state[k] += other_state[k]
del other_state
n = len(all_paths)
for k in sum_state:
sum_state[k] /= n
# Cast back to original dtype
sum_state[k] = sum_state[k].to(original_dtypes[k])
torch.save(sum_state, souped_path)
else:
print(f"Skipping unknown weight file: {rel_path}")

View File

@@ -163,7 +163,7 @@ for i in range(len(modified_args)):
setup_commands = [
# Install dependencies
"pip install .[train]",
"pip install trl wandb",
"pip install trl==0.23.0 wandb",
"pip install transformers==4.55.2", # Updated for GRPO compatibility
"pip install flash-attn==2.8.0.post2 --no-build-isolation",
"pip install vllm==v0.10.1.1",