Mirror of https://github.com/allenai/olmocr.git
Souping in fp32
This commit is contained in:
parent 3d6e6a6a01
commit bb06829840
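This change makes checkpoint "souping" (averaging the weights of several fine-tuned checkpoints into one model) accumulate in fp32 even when the checkpoints are stored in fp16 or bf16, then casts the average back to each tensor's original dtype. As a quick illustration of why that matters, here is a minimal sketch (my own, not code from this repo) comparing a bf16 running sum against an fp32 one; the tensor size and checkpoint count are arbitrary:

# Illustrative sketch: error from accumulating an average in bf16 versus
# fp32. The tensors stand in for checkpoint weights; sizes are arbitrary.
import torch

torch.manual_seed(0)
states = [torch.rand(1024) for _ in range(32)]  # stand-ins for checkpoints

# Accumulate in bf16, as averaging half-precision weights directly would.
acc = states[0].to(torch.bfloat16).clone()
for s in states[1:]:
    acc += s.to(torch.bfloat16)
avg_bf16 = (acc / len(states)).float()

# Accumulate in fp32, as this commit does, casting back only at the end.
acc = states[0].clone()
for s in states[1:]:
    acc += s
avg_fp32 = acc / len(states)

exact = torch.stack(states).mean(dim=0)
print("bf16 accumulation max abs error:", (avg_bf16 - exact).abs().max().item())
print("fp32 accumulation max abs error:", (avg_fp32 - exact).abs().max().item())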
@@ -340,29 +340,55 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
         if file_path.endswith(".safetensors"):
             sum_state = load_file(file_path, device="cpu")
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = load_file(other_path, device="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state
 
             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             save_file(sum_state, souped_path)
         elif file_path.endswith(".bin"):
             sum_state = torch.load(file_path, map_location="cpu")
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = torch.load(other_path, map_location="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state
 
             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             torch.save(sum_state, souped_path)
         else:
             print(f"Skipping unknown weight file: {rel_path}")
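For readers who want the new behavior outside of prepare_checkpoints, here is the safetensors branch condensed into a standalone helper. This is a sketch for reuse, not code from the commit; the name soup_safetensors and its signature are my own:

# Standalone condensation of the safetensors branch above; the helper
# name `soup_safetensors` is hypothetical, not part of the repo.
import torch
from safetensors.torch import load_file, save_file

def soup_safetensors(paths: list[str], out_path: str) -> None:
    """Average .safetensors state dicts in fp32, casting back afterwards."""
    sum_state = load_file(paths[0], device="cpu")
    original_dtypes = {k: v.dtype for k, v in sum_state.items()}
    for k in sum_state:
        if sum_state[k].dtype in (torch.float16, torch.bfloat16):
            sum_state[k] = sum_state[k].to(torch.float32)
    for path in paths[1:]:
        other_state = load_file(path, device="cpu")
        if set(sum_state.keys()) != set(other_state.keys()):
            raise ValueError(f"Key mismatch in {path}")
        for k in sum_state:
            v = other_state[k]
            if v.dtype in (torch.float16, torch.bfloat16):
                v = v.to(torch.float32)
            sum_state[k] += v
        del other_state
    n = len(paths)
    for k in sum_state:
        # Divide once at the end, then restore each tensor's stored dtype
        sum_state[k] = (sum_state[k] / n).to(original_dtypes[k])
    save_file(sum_state, out_path)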
@@ -163,7 +163,7 @@ for i in range(len(modified_args)):
 setup_commands = [
     # Install dependencies
     "pip install .[train]",
-    "pip install trl wandb",
+    "pip install trl==0.23.0 wandb",
     "pip install transformers==4.55.2",  # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
     "pip install vllm==v0.10.1.1",
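The surrounding launch script presumably runs these commands on the target machine before training starts; a minimal sketch of such a runner, assuming setup_commands is consumed as plain shell strings (the repo's actual launch mechanism may differ):

# Hypothetical runner for setup_commands above; the repo's real launcher
# may execute these differently (e.g. inside a job spec).
import subprocess

setup_commands = [
    "pip install .[train]",
    "pip install trl==0.23.0 wandb",
]

for cmd in setup_commands:
    # shell=True because each entry is a full shell command string
    subprocess.run(cmd, shell=True, check=True)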