mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-15 10:12:14 +00:00
Small edits
This commit is contained in:
parent
46ffbe9324
commit
14e3f6e97b
@ -29,7 +29,7 @@ def init_model(model_name: str = "ds4sd/SmolDocling-256M-preview"):
|
|||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
# _attn_implementation="flash_attention_2" if device.type == "cuda" else "eager",
|
# _attn_implementation="flash_attention_2" if device.type == "cuda" else "eager",
|
||||||
_attn_implementation="eager",
|
_attn_implementation="eager",
|
||||||
).to(device)
|
).eval().to(device)
|
||||||
|
|
||||||
_cached_model = model
|
_cached_model = model
|
||||||
_cached_processor = processor
|
_cached_processor = processor
|
||||||
|
@ -47,7 +47,7 @@ def run_transformers(
|
|||||||
if _cached_model is None:
|
if _cached_model is None:
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(model, torch_dtype=torch.bfloat16).eval()
|
model = Qwen2VLForConditionalGeneration.from_pretrained(model, torch_dtype=torch.bfloat16).eval()
|
||||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
_cached_model = model
|
_cached_model = model
|
||||||
_cached_processor = processor
|
_cached_processor = processor
|
||||||
|
Loading…
x
Reference in New Issue
Block a user