Inference test for Qwen2-VL and Qwen2.5-VL, work queue fixes; current build still broken

Jake Poznanski 2025-01-27 15:58:48 -08:00
parent 4d0d9246b4
commit 00e3aac058
5 changed files with 70 additions and 59 deletions

View File

@@ -16,6 +16,7 @@ import datetime
import tempfile
import random
import re
import glob
import torch
import multiprocessing
@@ -463,19 +464,23 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
async def sglang_server_task(args, semaphore):
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
download_directory(args.model, model_cache_dir)
model_name_or_path = args.model
# Check the rope config and make sure it's got the proper key
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
config_data = json.load(cfin)
# if "://" in model_name_or_path:
# # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
# model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
# download_directory(model_name_or_path, model_cache_dir)
if "rope_type" in config_data["rope_scaling"]:
del config_data["rope_scaling"]["rope_type"]
config_data["rope_scaling"]["type"] = "mrope"
# # Check the rope config and make sure it's got the proper key
# with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
# config_data = json.load(cfin)
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
json.dump(config_data, cfout)
# if "rope_type" in config_data["rope_scaling"]:
# del config_data["rope_scaling"]["rope_type"]
# config_data["rope_scaling"]["type"] = "mrope"
# with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
# json.dump(config_data, cfout)
# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB
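The interleaved hunk above is hard to follow, so here is a sketch of the rope_scaling patch that this commit comments out (assuming, as the old code does, that config.json carries a rope_scaling dict):

import json
import os

def patch_rope_scaling(model_cache_dir: str) -> None:
    # Sketch of the now-disabled logic: replace a "rope_type" key with
    # {"type": "mrope"} and write the config back in place.
    config_path = os.path.join(model_cache_dir, "config.json")
    with open(config_path, "r") as cfin:
        config_data = json.load(cfin)
    if "rope_type" in config_data.get("rope_scaling", {}):
        del config_data["rope_scaling"]["rope_type"]
        config_data["rope_scaling"]["type"] = "mrope"
        with open(config_path, "w") as cfout:
            json.dump(config_data, cfout)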
@@ -484,7 +489,7 @@ async def sglang_server_task(args, semaphore):
cmd = [
"python3",
"-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--model-path", model_name_or_path,
"--chat-template", args.model_chat_template,
# "--context-length", str(args.model_max_context), # Commented out due to crashes
"--port", str(SGLANG_SERVER_PORT),
@@ -847,9 +852,7 @@ async def main():
# Model parameters
parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
"gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
"s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
default="allenai/olmocr-preview")
parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
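Condensed, the model arguments now read roughly as follows (a sketch showing only the flags visible in this hunk, with help strings abbreviated):

import argparse

parser = argparse.ArgumentParser()
# The default is now a single Hugging Face model id instead of a list of
# weka/gs/s3 checkpoint paths.
parser.add_argument('--model', default="allenai/olmocr-preview",
                    help='Path or model id of the model used to convert the pdf')
parser.add_argument('--model_max_context', type=int, default=8192,
                    help='Maximum context length that the model was fine tuned under')
parser.add_argument('--model_chat_template', type=str, default="qwen2-vl",
                    help='Chat template to pass to sglang server')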
@@ -903,6 +906,9 @@ async def main():
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
elif any(char in args.pdfs for char in {"*", "?", "[", "]"}):
logger.info(f"Expanding local glob at {args.pdfs}")
s3_work_paths = glob.glob(args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
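Put together, the path-expansion branch now handles three cases: an s3 glob, a local glob, and a plain file listing one path per line. A sketch (expand_s3_glob is the helper changed in the next file; logging omitted):

import glob
import os

def expand_pdf_paths(pdfs_arg, pdf_s3_client):
    if pdfs_arg.startswith("s3://"):
        # s3 glob: expand against the bucket listing
        return list(expand_s3_glob(pdf_s3_client, pdfs_arg).keys())
    elif any(char in pdfs_arg for char in {"*", "?", "[", "]"}):
        # local glob: expand on the local filesystem
        return glob.glob(pdfs_arg)
    elif os.path.exists(pdfs_arg):
        # plain file containing one pdf path per line
        with open(pdfs_arg, "r") as f:
            return [line.strip() for line in f if line.strip()]
    return []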

View File

@@ -63,8 +63,6 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
key = obj["Key"]
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
if not matched:
raise ValueError(f"No objects found for pattern '{s3_glob}'. Check your path or pattern.")
return matched
# Case 2: No wildcard → single file or a bare prefix
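With the ValueError gone, expand_s3_glob returns an empty dict when nothing matches, so any caller that wants to fail loudly must now check for itself. A minimal caller-side guard (hypothetical, not part of this commit):

s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
if not s3_work_paths:
    # The helper no longer raises; decide at the call site whether an empty
    # match set is an error or simply means there is no work to do.
    logger.warning(f"No objects matched pattern {args.pdfs}")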

View File

@@ -15,57 +15,67 @@ from tqdm import tqdm
import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from transformers import (
AutoModelForCausalLM,
Trainer,
TrainerCallback,
TrainingArguments,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
Qwen2VLConfig
AutoConfig,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_finetuning_prompt
from olmocr.prompts.prompts import build_finetuning_prompt, build_openai_silver_data_prompt
from olmocr.train.dataprep import prepare_data_for_qwen2_inference
def build_page_query(local_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
return {
"input_prompt_text": build_finetuning_prompt(anchor_text),
"input_prompt_image_base64": image_base64
}
@torch.no_grad()
def run_inference(model_name: str):
config = Qwen2VLConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
# If it doesn't load, change the type:mrope key to "default"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
#model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model.eval()
#local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
local_pdf_path = "/root/brochure.pdf"
page = 1
query = build_page_query(os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"), 8)
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
inputs = prepare_data_for_qwen2_inference(query, processor)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
]
print(inputs)
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = {
x: torch.from_numpy(y).unsqueeze(0).to("cuda")
for (x,y) in inputs.items()
}
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
generated_ids = [
@@ -75,12 +85,12 @@ def run_inference(model_name: str):
output_text = processor.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print(output_text)
print(output_text[0])
def main():
run_inference(model_name="/root/model")
run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")
if __name__ == "__main__":
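Because the hunk interleaves the removed Qwen2-VL path with the new Qwen2.5-VL path, here is a consolidated sketch of the new-side code; the hard-coded test paths are turned into parameters, and the olmocr helpers are assumed to be importable as shown in the imports above:

import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_openai_silver_data_prompt

@torch.no_grad()
def run_inference(model_name: str, local_pdf_path: str, page: int = 1) -> str:
    config = AutoConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    # Render the page to an image and pull anchor text out of the PDF itself
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
    inputs = processor(text=[text], images=[main_image], padding=True, return_tensors="pt").to("cuda")

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
    # Standard Qwen-style trimming: drop the prompt tokens before decoding
    generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]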

View File

@@ -93,18 +93,18 @@ class WorkQueue(abc.ABC):
pass
@staticmethod
def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
def _compute_workgroup_hash(work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
s3_work_paths: List of paths (local or S3)
work_paths: List of paths (local or S3)
Returns:
SHA1 hash of the sorted paths
"""
sha1 = hashlib.sha1()
for path in sorted(s3_work_paths):
for path in sorted(work_paths):
sha1.update(path.encode('utf-8'))
return sha1.hexdigest()
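Because the paths are sorted before hashing, the workgroup hash does not depend on input order; for example (hypothetical paths):

a = WorkQueue._compute_workgroup_hash(["s3://bucket/a.pdf", "s3://bucket/b.pdf"])
b = WorkQueue._compute_workgroup_hash(["s3://bucket/b.pdf", "s3://bucket/a.pdf"])
assert a == b  # same group of paths, same hash, regardless of ordering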
@@ -189,17 +189,17 @@ class LocalWorkQueue(WorkQueue):
# Internal queue
self._queue = asyncio.Queue()
async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue (local version).
Args:
s3_work_paths: Each individual path (local in this context)
work_paths: Each individual path (local in this context)
that we will process over
items_per_group: Number of items to group together in a single work item
"""
# Treat them as local paths, but keep variable name for consistency
all_paths = set(s3_work_paths)
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups from local index
@@ -276,7 +276,7 @@ class LocalWorkQueue(WorkQueue):
# 3) Filter out completed items
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
WorkItem(hash=hash_, work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
@@ -415,15 +415,15 @@ class S3WorkQueue(WorkQueue):
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue = asyncio.Queue()
async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
s3_work_paths: Each individual s3 path that we will process over
work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(s3_work_paths)
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups
@@ -515,7 +515,7 @@ class S3WorkQueue(WorkQueue):
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
WorkItem(hash=hash_, work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
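The renamed argument matches the WorkItem field used in both queue implementations. A minimal sketch of that dataclass, with the field set assumed from the constructions above:

from dataclasses import dataclass
from typing import List

@dataclass
class WorkItem:
    # One unit of work: the group hash plus the paths belonging to that group,
    # now called work_paths for both the local and the S3 queue.
    hash: str
    work_paths: List[str]

item = WorkItem(hash="0badc0de", work_paths=["s3://bucket/doc1.pdf", "s3://bucket/doc2.pdf"])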

View File

@@ -34,10 +34,11 @@ dependencies = [
"orjson",
"requests",
"zstandard",
"aiohttp>=3.10,<3.11", # Specific timeout thing is causing issues
"boto3",
"torch>=2.4.0",
"torch==2.5.1",
"transformers>=4.46.2",
"sglang[all]==0.4.1",
"beaker-py",
]
license = {file = "LICENSE"}
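A quick way to confirm an environment actually matches the new pins (a sketch using importlib.metadata; only the exact pins from this hunk are checked):

from importlib.metadata import version

# torch and sglang are now pinned exactly; transformers stays a lower bound.
assert version("torch") == "2.5.1"
assert version("sglang") == "0.4.1"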
@@ -70,10 +71,6 @@ dev = [
"necessary",
]
inference = [
"sglang[all]>=0.3.6",
"beaker-py",
]
train = [
"torch",