Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-13 00:49:28 +00:00
fixed dotsocr runner
This commit is contained in:
parent 4f7623c429
commit 68defa23d7
@@ -1,5 +1,4 @@
 import base64
-import os
 from io import BytesIO
 
 import torch
@@ -9,17 +8,12 @@ from qwen_vl_utils import process_vision_info
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
-# Set LOCAL_RANK as required by DotsOCR
-if "LOCAL_RANK" not in os.environ:
-    os.environ["LOCAL_RANK"] = "0"
 
-# Global cache for the model and processor.
-_device = "cuda" if torch.cuda.is_available() else "cpu"
 _model = None
 _processor = None
 
 
-def load_model(model_name: str = "rednote-hilab/dots.ocr"):
+def load_model(model_name: str = "./weights/DotsOCR"):
     """
     Load the DotsOCR model and processor if they haven't been loaded already.
 
@@ -32,12 +26,12 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
     """
     global _model, _processor
     if _model is None or _processor is None:
-        # Load model following the official repo pattern
         _model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            attn_implementation="flash_attention_2",
             torch_dtype=torch.bfloat16,
             device_map="auto",
+            attn_implementation="flash_attention_2",
+            low_cpu_mem_usage=True,
             trust_remote_code=True
         )
         _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
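The reordered from_pretrained call above pins attn_implementation="flash_attention_2", which raises at load time when the flash-attn package (or a supported GPU) is missing. A minimal sketch of a guarded loader, assuming an sdpa fallback and a find_spec check (neither is part of this commit):

import importlib.util

import torch
from transformers import AutoModelForCausalLM


def load_model_guarded(model_name: str = "./weights/DotsOCR"):
    # Fall back to PyTorch's built-in SDPA attention when flash-attn is
    # not installed; outputs are equivalent, only speed/memory differ.
    attn = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=attn,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )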
@@ -47,7 +41,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
 def run_dotsocr(
     pdf_path: str,
     page_num: int = 1,
-    model_name: str = "rednote-hilab/dots.ocr",
+    model_name: str = "./weights/DotsOCR",
     target_longest_image_dim: int = 1024
 ) -> str:
     """
@@ -59,7 +53,7 @@ def run_dotsocr(
     Args:
         pdf_path (str): The local path to the PDF file.
         page_num (int): The page number to process (default: 1).
-        model_name (str): Hugging Face model name (default: "rednote-hilab/dots.ocr").
+        model_name (str): Hugging Face model name (default: "./weights/DotsOCR").
         target_longest_image_dim (int): Target dimension for the longest side of the image (default: 1024).
 
     Returns:
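With the new default, the weights are expected in a local ./weights/DotsOCR checkout rather than pulled from the Hub. A minimal usage sketch of the updated entry point; the import path and the sample file name are placeholders, not taken from this commit:

from run_dotsocr import run_dotsocr  # hypothetical module path

text = run_dotsocr(
    "sample.pdf",                    # placeholder PDF path
    page_num=1,
    model_name="./weights/DotsOCR",  # new default: local weights directory
)
print(text)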
@@ -75,24 +69,7 @@ def run_dotsocr(
     image = Image.open(BytesIO(base64.b64decode(image_base64)))
 
     # Define the prompt for layout extraction
-    prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
-
-1. Bbox format: [x1, y1, x2, y2]
-
-2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
-
-3. Text Extraction & Formatting Rules:
-    - Picture: For the 'Picture' category, the text field should be omitted.
-    - Formula: Format its text as LaTeX.
-    - Table: Format its text as HTML.
-    - All Others (Text, Title, etc.): Format their text as Markdown.
-
-4. Constraints:
-    - The output text must be the original text from the image, with no translation.
-    - All layout elements must be sorted according to human reading order.
-
-5. Final Output: The entire output must be a single JSON object.
-"""
+    prompt = """Extract the text content from this image."""
 
     messages = [
         {
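The hunk above replaces dots.ocr's full layout prompt (bboxes, categories, a single JSON object) with a bare text-extraction prompt, so the model now returns plain text instead of layout JSON. If both behaviors were wanted, one hypothetical refactor (not part of this commit) is to keep the prompts as module constants and select one per call:

# Hypothetical constants; only PROMPT_TEXT appears in this commit.
PROMPT_TEXT = """Extract the text content from this image."""
PROMPT_LAYOUT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
"""  # full rule list elided; see the removed lines in the hunk above


def build_prompt(mode: str = "text") -> str:
    # "text" -> plain page text; "layout" -> the JSON layout object.
    return PROMPT_TEXT if mode == "text" else PROMPT_LAYOUT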
@@ -126,8 +103,8 @@ def run_dotsocr(
     inputs = inputs.to("cuda")
 
     # Inference: Generation of the output
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=24000)
+    # with torch.no_grad():
+    generated_ids = model.generate(**inputs, max_new_tokens=4096)
 
     generated_ids_trimmed = [
         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
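On the generation change: recent transformers releases already decorate GenerationMixin.generate() with torch.no_grad(), so commenting out the explicit block should not leak gradient state, and the smaller max_new_tokens=4096 budget bounds memory and latency for a single page. If an explicit guard is still wanted, a short sketch (a suggestion, not part of the commit):

import torch

# inference_mode() is a stricter, slightly faster variant of no_grad().
with torch.inference_mode():
    generated_ids = model.generate(**inputs, max_new_tokens=4096)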