Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-10 15:52:31 +00:00
Souping in fp32
This commit is contained in:
parent 3d6e6a6a01
commit bb06829840
@@ -340,29 +340,55 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
         if file_path.endswith(".safetensors"):
             sum_state = load_file(file_path, device="cpu")
+
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = load_file(other_path, device="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state

             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             save_file(sum_state, souped_path)
         elif file_path.endswith(".bin"):
             sum_state = torch.load(file_path, map_location="cpu")
+
+            # Store original dtypes for each tensor
+            original_dtypes = {k: v.dtype for k, v in sum_state.items()}
+            # Upconvert to at least fp32 for accurate averaging
+            for k in sum_state:
+                if sum_state[k].dtype in (torch.float16, torch.bfloat16):
+                    sum_state[k] = sum_state[k].to(torch.float32)
+
             for other_path in all_paths[1:]:
                 other_state = torch.load(other_path, map_location="cpu")
                 if set(sum_state.keys()) != set(other_state.keys()):
                     raise ValueError(f"Key mismatch in {rel_path}")
                 for k in sum_state:
+                    # Upconvert other state to match sum_state dtype
+                    if other_state[k].dtype in (torch.float16, torch.bfloat16):
+                        other_state[k] = other_state[k].to(torch.float32)
                     sum_state[k] += other_state[k]
                 del other_state

             n = len(all_paths)
             for k in sum_state:
                 sum_state[k] /= n
+                # Cast back to original dtype
+                sum_state[k] = sum_state[k].to(original_dtypes[k])
             torch.save(sum_state, souped_path)
         else:
             print(f"Skipping unknown weight file: {rel_path}")
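The change above is plain model souping (element-wise averaging of checkpoint weights), with the accumulation moved from the checkpoints' native dtype into fp32. A minimal standalone sketch of why that matters, not part of the repo; the tensor size and checkpoint count are made up:

```python
# Illustration only: compares naive bf16 accumulation against the fp32
# accumulation used in this commit.
import torch

torch.manual_seed(0)
n_ckpts = 50
# Pretend these are the same weight tensor loaded from 50 checkpoints.
ckpts = [torch.rand(10_000).to(torch.bfloat16) for _ in range(n_ckpts)]

# Naive souping: sum and divide directly in bf16.
naive = ckpts[0].clone()
for t in ckpts[1:]:
    naive += t
naive /= n_ckpts

# This commit's approach: upcast to fp32, average, cast back to bf16.
acc = ckpts[0].to(torch.float32)
for t in ckpts[1:]:
    acc += t.to(torch.float32)
acc /= n_ckpts
souped = acc.to(torch.bfloat16)

# fp64 reference mean for measuring the error of each approach.
ref = torch.stack([t.to(torch.float64) for t in ckpts]).mean(dim=0)

print("max error, bf16 accumulation:", (naive.double() - ref).abs().max().item())
print("max error, fp32 accumulation:", (souped.double() - ref).abs().max().item())
```

With bf16's 8 mantissa bits, each in-place add can lose up to roughly 0.4% of the running sum, so the naive error grows with the number of checkpoints; the fp32 path pays only the single quantization step of the final cast back.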
@@ -163,7 +163,7 @@ for i in range(len(modified_args)):
 setup_commands = [
     # Install dependencies
     "pip install .[train]",
-    "pip install trl wandb",
+    "pip install trl==0.23.0 wandb",
     "pip install transformers==4.55.2",  # Updated for GRPO compatibility
     "pip install flash-attn==2.8.0.post2 --no-build-isolation",
     "pip install vllm==v0.10.1.1",
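Side note on the setup-command change: trl is now pinned at 0.23.0 alongside the other pinned training dependencies. A small sanity-check snippet, mine rather than the repo's, that reports whether the installed versions match the pins (assuming the distribution names match the pip names used above; vllm's "v" prefix is normalized away by pip, so it is compared without it):

```python
# Check installed package versions against the pins in setup_commands.
from importlib.metadata import PackageNotFoundError, version

pins = {
    "trl": "0.23.0",
    "transformers": "4.55.2",
    "flash-attn": "2.8.0.post2",
    "vllm": "0.10.1.1",
}
for pkg, expected in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{pkg}: {installed} {status}")
```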