mirror of https://github.com/allenai/olmocr.git
synced 2025-10-17 11:12:33 +00:00

Fix for languages "no" in yaml

commit c1061146e5
parent df52cb0e0e
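
Why this fix is needed: PyYAML implements YAML 1.1, in which the unquoted scalars
no, yes, on, off (and their capitalized forms) are resolved as booleans rather
than strings. A page whose language front matter is the ISO 639-1 code for
Norwegian, "no", therefore comes back from yaml.safe_load as the boolean False
instead of the string "no", and that bool can then reach a field typed
Optional[str]. A minimal reproduction (the field name primary_language is
illustrative of the language field this commit targets):

    import yaml

    # YAML 1.1 boolean resolution: unquoted "no" is parsed as False.
    print(yaml.safe_load("primary_language: no"))    # {'primary_language': False}
    # Quoting the scalar preserves the string:
    print(yaml.safe_load('primary_language: "no"'))  # {'primary_language': 'no'}
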
@@ -7,8 +7,8 @@ import torch
 from PIL import Image
 from transformers import (
     AutoProcessor,
-    Qwen2VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
 )
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
@@ -16,8 +16,8 @@ from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     PageResponse,
     build_finetuning_prompt,
-    build_openai_silver_data_prompt,
     build_no_anchoring_yaml_prompt,
+    build_openai_silver_data_prompt,
 )
 from olmocr.train.dataloader import FrontMatterParser
 
@@ -52,13 +52,10 @@ def run_transformers(
 
     if _cached_model is None:
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2"
+            model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2"
         ).eval()
         processor = AutoProcessor.from_pretrained(model_name)
 
         model = model.to(device)
 
         _cached_model = model
@@ -69,7 +66,7 @@ def run_transformers(
 
     # Convert the first page of the PDF to a base64-encoded PNG image.
     image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=target_longest_image_dim)
 
     if prompt_template == "yaml":
         prompt = build_no_anchoring_yaml_prompt()
     else:
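
The yaml prompt built above presumably asks the model to return YAML front
matter ahead of the page text; FrontMatterParser, changed below, is what parses
that output. A sketch of such a model output (field names are illustrative),
with the unquoted Norwegian language code that triggers the bug this commit
fixes:

    ---
    primary_language: no
    is_rotation_valid: true
    ---
    Text of the page goes here.
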
(The following hunks are in the FrontMatterParser module, olmocr/train/dataloader
per the import above.)

@@ -9,7 +9,19 @@ from functools import reduce
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeAlias,
+    Union,
+    get_args,
+    get_origin,
+)
 
 import numpy as np
 import yaml
@@ -166,6 +178,12 @@ class FrontMatterParser(PipelineStep):
 
     front_matter_class: Optional[Type] = None
 
+    def _is_optional_str(self, field_type: Type) -> bool:
+        """Check if a type is Optional[str]."""
+        origin = get_origin(field_type)
+        args = get_args(field_type)
+        return origin is Union and type(None) in args and str in args
+
     def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[Dict[str, Any], str]:
         """Extract YAML front matter and text from markdown content."""
         if markdown_content.startswith("---\n"):
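
For context, get_origin and get_args (hence the widened typing import above)
decompose a parameterized annotation at runtime, so the new helper recognizes
Optional[str] structurally instead of relying on an identity check against the
typing object. A quick sketch of what _is_optional_str sees:

    from typing import Optional, Union, get_args, get_origin

    # Optional[str] is shorthand for Union[str, None].
    print(get_origin(Optional[str]) is Union)  # True
    print(get_args(Optional[str]))             # (<class 'str'>, <class 'NoneType'>)

Note that a PEP 604 annotation (str | None) reports types.UnionType rather than
typing.Union on Python 3.10-3.13, so the helper as written targets the
typing.Optional spelling used in this dataclass.
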
@@ -210,8 +228,14 @@ class FrontMatterParser(PipelineStep):
                 kwargs[field_name] = int(value)
             elif field_type is bool and isinstance(value, str):
                 kwargs[field_name] = value.lower() == "true"
-            elif field_type is Optional[str]:
-                kwargs[field_name] = value if value else None
+            elif self._is_optional_str(field_type):
+                # Handle boolean values that YAML might produce (e.g., 'no' -> False)
+                if isinstance(value, bool):
+                    kwargs[field_name] = None
+                elif isinstance(value, str):
+                    kwargs[field_name] = value if value else None
+                else:
+                    kwargs[field_name] = None if not value else value
             else:
                 kwargs[field_name] = value
 
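
End to end, the new branch collapses a YAML-boolean value to None rather than
letting a bool leak into an Optional[str] field. A self-contained sketch of the
fixed coercion path, assuming nothing about the rest of the parser
(FrontMatter and coerce_optional_str are illustrative stand-ins, not repo code):

    from dataclasses import dataclass
    from typing import Optional

    import yaml

    @dataclass
    class FrontMatter:
        primary_language: Optional[str]

    def coerce_optional_str(value):
        """Mirrors the fixed elif branch above."""
        if isinstance(value, bool):
            return None  # YAML 1.1 turned 'no'/'yes' into a bool: treat as absent
        if isinstance(value, str):
            return value if value else None
        return None if not value else value

    raw = yaml.safe_load("primary_language: no")  # {'primary_language': False}
    fm = FrontMatter(primary_language=coerce_optional_str(raw["primary_language"]))
    print(fm)  # FrontMatter(primary_language=None)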