Mirror of https://github.com/allenai/olmocr.git (synced 2025-12-25 06:06:23 +00:00)

Commit 00e3aac058 (parent 4d0d9246b4)
Inference test for Qwen2 and Qwen2.5, work queue fixes; the build is currently still broken.
@@ -16,6 +16,7 @@ import datetime
import tempfile
import random
import re
import glob
import torch
import multiprocessing

@@ -463,19 +464,23 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):


async def sglang_server_task(args, semaphore):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
    download_directory(args.model, model_cache_dir)
    model_name_or_path = args.model

    # Check the rope config and make sure it's got the proper key
    with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
        config_data = json.load(cfin)
    # if "://" in model_name_or_path:
    #     # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
    #     model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
    #     download_directory(model_name_or_path, model_cache_dir)

    if "rope_type" in config_data["rope_scaling"]:
        del config_data["rope_scaling"]["rope_type"]
        config_data["rope_scaling"]["type"] = "mrope"
    # # Check the rope config and make sure it's got the proper key
    # with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
    #     config_data = json.load(cfin)

    with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
        json.dump(config_data, cfout)
    # if "rope_type" in config_data["rope_scaling"]:
    #     del config_data["rope_scaling"]["rope_type"]
    #     config_data["rope_scaling"]["type"] = "mrope"

    # with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
    #     json.dump(config_data, cfout)

    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert to GB
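The rope_scaling rewrite above appears to exist so that the sglang server accepts the checkpoint's config: it drops "rope_type" and forces "type": "mrope" before the config is written back. A minimal standalone sketch of that patch step (the helper name and the defensive .get are illustrative, not the repo's code):

import json
import os

def patch_rope_scaling(model_cache_dir: str) -> None:
    # Drop "rope_type" and force "type": "mrope", mirroring the logic in
    # sglang_server_task above, then write the config back in place.
    config_path = os.path.join(model_cache_dir, "config.json")
    with open(config_path, "r") as cfin:
        config_data = json.load(cfin)

    rope_scaling = config_data.get("rope_scaling", {})
    if "rope_type" in rope_scaling:
        del rope_scaling["rope_type"]
        rope_scaling["type"] = "mrope"

    with open(config_path, "w") as cfout:
        json.dump(config_data, cfout)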
@@ -484,7 +489,7 @@ async def sglang_server_task(args, semaphore):
    cmd = [
        "python3",
        "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--model-path", model_name_or_path,
        "--chat-template", args.model_chat_template,
        # "--context-length", str(args.model_max_context),  # Commented out due to crashes
        "--port", str(SGLANG_SERVER_PORT),
@@ -847,9 +852,7 @@ async def main():

    # Model parameters
    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
                        default="allenai/olmocr-preview")
    parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
@@ -903,6 +906,9 @@ async def main():
    if args.pdfs.startswith("s3://"):
        logger.info(f"Expanding s3 glob at {args.pdfs}")
        s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
    elif any(char in args.pdfs for char in {"*", "?", "[", "]"}):
        logger.info(f"Expanding local glob at {args.pdfs}")
        s3_work_paths = glob.glob(args.pdfs)
    elif os.path.exists(args.pdfs):
        logger.info(f"Loading file at {args.pdfs}")
        with open(args.pdfs, "r") as f:
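With the new elif branch above, --pdfs accepts three forms: an s3:// glob, a local glob, or a path to a plain file listing PDFs. A rough standalone sketch of that dispatch (the list-file parsing and the final error are assumptions, since the hunk is cut off after the open() call; expand_s3_glob and the S3 client are passed in rather than imported):

import glob
import os

def resolve_pdf_paths(pdfs_arg, expand_s3_glob, s3_client):
    # s3:// glob -> expand against the bucket; local glob -> glob.glob;
    # otherwise treat the argument as a file containing one path per line.
    if pdfs_arg.startswith("s3://"):
        return list(expand_s3_glob(s3_client, pdfs_arg))
    if any(char in pdfs_arg for char in {"*", "?", "[", "]"}):
        return glob.glob(pdfs_arg)
    if os.path.exists(pdfs_arg):
        with open(pdfs_arg, "r") as f:
            return [line.strip() for line in f if line.strip()]
    raise ValueError(f"Could not interpret --pdfs value: {pdfs_arg}")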
@@ -63,8 +63,6 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
                key = obj["Key"]
                if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                    matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
        if not matched:
            raise ValueError(f"No objects found for pattern '{s3_glob}'. Check your path or pattern.")
        return matched

    # Case 2: No wildcard → single file or a bare prefix
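For reference, glob.fnmatch.fnmatch applies shell-style wildcard matching to the full object key, so the joined prefix/pattern behaves like a path glob. An illustrative check with made-up keys:

import glob
import posixpath

prefix, pattern = "pdfs/batch1/", "*.pdf"
keys = ["pdfs/batch1/doc_001.pdf", "pdfs/batch1/notes.txt"]
matched = [key for key in keys if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern))]
print(matched)  # ['pdfs/batch1/doc_001.pdf']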
@@ -15,57 +15,67 @@ from tqdm import tqdm
import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore

from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    Qwen2VLConfig
    AutoConfig,
)


from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_finetuning_prompt
from olmocr.prompts.prompts import build_finetuning_prompt, build_openai_silver_data_prompt

from olmocr.train.dataprep import prepare_data_for_qwen2_inference


def build_page_query(local_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64
    }


@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If it doesn't load, change the type:mrope key to "default"

    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    #model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    #local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
    local_pdf_path = "/root/brochure.pdf"
    page = 1

    query = build_page_query(os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"), 8)
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    inputs = prepare_data_for_qwen2_inference(query, processor)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
            ],
        }
    ]

    print(inputs)
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = {
        x: torch.from_numpy(y).unsqueeze(0).to("cuda")
        for (x,y) in inputs.items()
    }
    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
    generated_ids = [
@@ -75,12 +85,12 @@ def run_inference(model_name: str):
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)
    print(output_text[0])



def main():
    run_inference(model_name="/root/model")
    run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")


if __name__ == "__main__":
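The generated_ids expression that the previous hunk truncates is presumably the usual step of stripping the prompt tokens from each generated sequence before decoding; a hedged sketch of that pattern as a standalone helper (the helper name is illustrative, not necessarily the file's exact code):

def trim_prompt_tokens(input_ids_batch, output_ids_batch):
    # Keep only the tokens generated after the prompt for each sequence in the batch.
    return [output[len(input_ids):] for input_ids, output in zip(input_ids_batch, output_ids_batch)]

Something like generated_ids = trim_prompt_tokens(inputs["input_ids"], output_ids) would then feed directly into processor.batch_decode above.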
@@ -93,18 +93,18 @@ class WorkQueue(abc.ABC):
        pass

    @staticmethod
    def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
    def _compute_workgroup_hash(work_paths: List[str]) -> str:
        """
        Compute a deterministic hash for a group of paths.

        Args:
            s3_work_paths: List of paths (local or S3)
            work_paths: List of paths (local or S3)

        Returns:
            SHA1 hash of the sorted paths
        """
        sha1 = hashlib.sha1()
        for path in sorted(s3_work_paths):
        for path in sorted(work_paths):
            sha1.update(path.encode('utf-8'))
        return sha1.hexdigest()
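A standalone copy of the helper above shows why the docstring calls it deterministic: since the paths are sorted before hashing, the work group id depends only on the set of paths, not on discovery order. Illustrative check, not part of the repo:

import hashlib
from typing import List

def compute_workgroup_hash(work_paths: List[str]) -> str:
    # Same logic as WorkQueue._compute_workgroup_hash: SHA1 over the sorted paths.
    sha1 = hashlib.sha1()
    for path in sorted(work_paths):
        sha1.update(path.encode('utf-8'))
    return sha1.hexdigest()

assert compute_workgroup_hash(["b.pdf", "a.pdf"]) == compute_workgroup_hash(["a.pdf", "b.pdf"])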
@@ -189,17 +189,17 @@ class LocalWorkQueue(WorkQueue):
        # Internal queue
        self._queue = asyncio.Queue()

    async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue (local version).

        Args:
            s3_work_paths: Each individual path (local in this context)
            work_paths: Each individual path (local in this context)
                that we will process over
            items_per_group: Number of items to group together in a single work item
        """
        # Treat them as local paths, but keep variable name for consistency
        all_paths = set(s3_work_paths)
        all_paths = set(work_paths)
        logger.info(f"Found {len(all_paths):,} total paths")

        # Load existing work groups from local index
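How populate_queue turns the deduplicated path set into hashed work groups is not visible in this hunk; a plausible sketch of the chunk-then-hash idea implied by items_per_group and _compute_workgroup_hash (group_paths and its reuse of compute_workgroup_hash from the sketch above are illustrative, not the repo's code):

from typing import Dict, List

def group_paths(work_paths: List[str], items_per_group: int) -> Dict[str, List[str]]:
    # Deduplicate and sort for determinism, then chunk into groups of at most
    # items_per_group paths, keyed by each group's workgroup hash.
    paths = sorted(set(work_paths))
    groups = [paths[i:i + items_per_group] for i in range(0, len(paths), items_per_group)]
    return {compute_workgroup_hash(group): group for group in groups}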
@@ -276,7 +276,7 @@ class LocalWorkQueue(WorkQueue):
        # 3) Filter out completed items
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [
            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
            WorkItem(hash=hash_, work_paths=work_queue[hash_])
            for hash_ in remaining_work_hashes
        ]
        random.shuffle(remaining_items)
@@ -415,15 +415,15 @@ class S3WorkQueue(WorkQueue):
        self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
        self._queue = asyncio.Queue()

    async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue.

        Args:
            s3_work_paths: Each individual s3 path that we will process over
            work_paths: Each individual s3 path that we will process over
            items_per_group: Number of items to group together in a single work item
        """
        all_paths = set(s3_work_paths)
        all_paths = set(work_paths)
        logger.info(f"Found {len(all_paths):,} total paths")

        # Load existing work groups
@@ -515,7 +515,7 @@ class S3WorkQueue(WorkQueue):
        # Find remaining work and shuffle
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [
            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
            WorkItem(hash=hash_, work_paths=work_queue[hash_])
            for hash_ in remaining_work_hashes
        ]
        random.shuffle(remaining_items)
@@ -34,10 +34,11 @@ dependencies = [
    "orjson",
    "requests",
    "zstandard",
    "aiohttp>=3.10,<3.11",  # Specific timeout thing is causing issues
    "boto3",
    "torch>=2.4.0",
    "torch==2.5.1",
    "transformers>=4.46.2",
    "sglang[all]==0.4.1",
    "beaker-py",
]
license = {file = "LICENSE"}
@@ -70,10 +71,6 @@ dev = [
    "necessary",
]

inference = [
    "sglang[all]>=0.3.6",
    "beaker-py",
]

train = [
    "torch",