mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-11-03 19:45:41 +00:00 
			
		
		
		
	fixed dotsocr runner
This commit is contained in:
		
							parent
							
								
									4f7623c429
								
							
						
					
					
						commit
						68defa23d7
					
				@ -1,5 +1,4 @@
 | 
			
		||||
import base64
 | 
			
		||||
import os
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -9,17 +8,12 @@ from qwen_vl_utils import process_vision_info
 | 
			
		||||
 | 
			
		||||
from olmocr.data.renderpdf import render_pdf_to_base64png
 | 
			
		||||
 | 
			
		||||
# Set LOCAL_RANK as required by DotsOCR
 | 
			
		||||
if "LOCAL_RANK" not in os.environ:
 | 
			
		||||
    os.environ["LOCAL_RANK"] = "0"
 | 
			
		||||
 | 
			
		||||
# Global cache for the model and processor.
 | 
			
		||||
_device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
_model = None
 | 
			
		||||
_processor = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model(model_name: str = "rednote-hilab/dots.ocr"):
 | 
			
		||||
def load_model(model_name: str = "./weights/DotsOCR"):
 | 
			
		||||
    """
 | 
			
		||||
    Load the DotsOCR model and processor if they haven't been loaded already.
 | 
			
		||||
 | 
			
		||||
@ -32,12 +26,12 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
 | 
			
		||||
    """
 | 
			
		||||
    global _model, _processor
 | 
			
		||||
    if _model is None or _processor is None:
 | 
			
		||||
        # Load model following the official repo pattern
 | 
			
		||||
        _model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
            model_name,
 | 
			
		||||
            attn_implementation="flash_attention_2",
 | 
			
		||||
            torch_dtype=torch.bfloat16,
 | 
			
		||||
            device_map="auto",
 | 
			
		||||
            attn_implementation="flash_attention_2",
 | 
			
		||||
            low_cpu_mem_usage=True,
 | 
			
		||||
            trust_remote_code=True
 | 
			
		||||
        )
 | 
			
		||||
        _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 | 
			
		||||
@ -47,7 +41,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
 | 
			
		||||
def run_dotsocr(
 | 
			
		||||
    pdf_path: str,
 | 
			
		||||
    page_num: int = 1,
 | 
			
		||||
    model_name: str = "rednote-hilab/dots.ocr",
 | 
			
		||||
    model_name: str = "./weights/DotsOCR",
 | 
			
		||||
    target_longest_image_dim: int = 1024
 | 
			
		||||
) -> str:
 | 
			
		||||
    """
 | 
			
		||||
@ -59,7 +53,7 @@ def run_dotsocr(
 | 
			
		||||
    Args:
 | 
			
		||||
        pdf_path (str): The local path to the PDF file.
 | 
			
		||||
        page_num (int): The page number to process (default: 1).
 | 
			
		||||
        model_name (str): Hugging Face model name (default: "rednote-hilab/dots.ocr").
 | 
			
		||||
        model_name (str): Hugging Face model name (default: "./weights/DotsOCR").
 | 
			
		||||
        target_longest_image_dim (int): Target dimension for the longest side of the image (default: 1024).
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
@ -75,24 +69,7 @@ def run_dotsocr(
 | 
			
		||||
    image = Image.open(BytesIO(base64.b64decode(image_base64)))
 | 
			
		||||
 | 
			
		||||
    # Define the prompt for layout extraction
 | 
			
		||||
    prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 | 
			
		||||
 | 
			
		||||
1. Bbox format: [x1, y1, x2, y2]
 | 
			
		||||
 | 
			
		||||
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
 | 
			
		||||
 | 
			
		||||
3. Text Extraction & Formatting Rules:
 | 
			
		||||
    - Picture: For the 'Picture' category, the text field should be omitted.
 | 
			
		||||
    - Formula: Format its text as LaTeX.
 | 
			
		||||
    - Table: Format its text as HTML.
 | 
			
		||||
    - All Others (Text, Title, etc.): Format their text as Markdown.
 | 
			
		||||
 | 
			
		||||
4. Constraints:
 | 
			
		||||
    - The output text must be the original text from the image, with no translation.
 | 
			
		||||
    - All layout elements must be sorted according to human reading order.
 | 
			
		||||
 | 
			
		||||
5. Final Output: The entire output must be a single JSON object.
 | 
			
		||||
"""
 | 
			
		||||
    prompt = """Extract the text content from this image."""
 | 
			
		||||
 | 
			
		||||
    messages = [
 | 
			
		||||
        {
 | 
			
		||||
@ -126,8 +103,8 @@ def run_dotsocr(
 | 
			
		||||
    inputs = inputs.to("cuda")
 | 
			
		||||
 | 
			
		||||
    # Inference: Generation of the output
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        generated_ids = model.generate(**inputs, max_new_tokens=24000)
 | 
			
		||||
    # with torch.no_grad():
 | 
			
		||||
    generated_ids = model.generate(**inputs, max_new_tokens=4096)
 | 
			
		||||
 | 
			
		||||
    generated_ids_trimmed = [
 | 
			
		||||
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user