Fixes to train.py

This commit is contained in:
Jake Poznanski 2025-10-28 22:15:45 +00:00
parent 91962d64e2
commit 4768db63d2

View File

@ -271,7 +271,7 @@ def main():
# Load model
logger.info(f"Loading model: {config.model.name}")
if "qwen2.5-vl" in config.model.name.lower():
if "qwen2.5-vl" in config.model.name.lower() or "olmocr-2-7b-1025" in config.model.name.lower():
model_class = Qwen2_5_VLForConditionalGeneration
model = model_class.from_pretrained(config.model.name, **model_init_kwargs)
elif "qwen2-vl" in config.model.name.lower():