Fix for the language value "no" being parsed as a boolean in YAML

This commit is contained in:
Jake Poznanski 2025-07-30 21:48:20 +00:00
parent df52cb0e0e
commit c1061146e5
2 changed files with 32 additions and 11 deletions

View File

@ -7,8 +7,8 @@ import torch
from PIL import Image
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
@ -16,8 +16,8 @@ from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
build_no_anchoring_yaml_prompt,
build_openai_silver_data_prompt,
)
from olmocr.train.dataloader import FrontMatterParser
@ -52,13 +52,10 @@ def run_transformers(
if _cached_model is None:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2"
model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2"
).eval()
processor = AutoProcessor.from_pretrained(model_name)
model = model.to(device)
_cached_model = model
@ -69,7 +66,7 @@ def run_transformers(
# Convert the first page of the PDF to a base64-encoded PNG image.
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=target_longest_image_dim)
if prompt_template == "yaml":
prompt = build_no_anchoring_yaml_prompt()
else:

View File

@ -9,7 +9,19 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeAlias,
Union,
get_args,
get_origin,
)
import numpy as np
import yaml
@ -166,6 +178,12 @@ class FrontMatterParser(PipelineStep):
front_matter_class: Optional[Type] = None
def _is_optional_str(self, field_type: Type) -> bool:
"""Check if a type is Optional[str]."""
origin = get_origin(field_type)
args = get_args(field_type)
return origin is Union and type(None) in args and str in args
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[Dict[str, Any], str]:
"""Extract YAML front matter and text from markdown content."""
if markdown_content.startswith("---\n"):
@ -210,8 +228,14 @@ class FrontMatterParser(PipelineStep):
kwargs[field_name] = int(value)
elif field_type is bool and isinstance(value, str):
kwargs[field_name] = value.lower() == "true"
elif field_type is Optional[str]:
kwargs[field_name] = value if value else None
elif self._is_optional_str(field_type):
# Handle boolean values that YAML might produce (e.g., 'no' -> False)
if isinstance(value, bool):
kwargs[field_name] = None
elif isinstance(value, str):
kwargs[field_name] = value if value else None
else:
kwargs[field_name] = None if not value else value
else:
kwargs[field_name] = value