This commit is contained in:
Jake Poznanski 2025-08-26 20:06:52 +00:00
parent 98f4d62d1e
commit 8327da2415

View File

@ -473,6 +473,13 @@ def main():
trust_remote_code=True,
)
# Load model
logger.info(f"Loading model: {args.model_name}")
if "Qwen2-VL" in args.model_name:
model_class = Qwen2VLForConditionalGeneration
else:
model_class = Qwen2_5_VLForConditionalGeneration
model = model_class.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,