mirror of
https://github.com/allenai/olmocr.git
synced 2025-06-27 04:00:02 +00:00
402 lines
16 KiB
Python
402 lines
16 KiB
Python
# The idea is that you have a Qwen2-VL-7B model located here:s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
|
|
|
|
# You need to load it in both hugging face transformers, and send page 1 of edgar.pdf to it from tests/gnarly_pdfs
|
|
# Compare that the temperature 0 sampled result is the same
|
|
|
|
import asyncio
|
|
import unittest
|
|
from unittest.mock import patch, AsyncMock
|
|
import os
|
|
import json
|
|
import tempfile
|
|
import math
|
|
import base64
|
|
import torch
|
|
import numpy as np
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
|
|
from pathlib import Path
|
|
from olmocr.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
|
|
from olmocr.prompts import PageResponse
|
|
from httpx import AsyncClient
|
|
import torch.nn.functional as F
|
|
MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
|
|
|
|
|
|
class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
|
async def asyncSetUp(self):
|
|
# Mock arguments
|
|
self.args = AsyncMock()
|
|
self.args.workspace = "/tmp/test_workspace"
|
|
self.args.model = [MODEL_FINETUNED_PATH]
|
|
self.args.model_chat_template = "qwen2-vl"
|
|
self.args.target_longest_image_dim = 1024
|
|
self.args.target_anchor_text_len = 6000
|
|
self.args.model_max_context = 8192
|
|
|
|
# Create a temporary workspace directory
|
|
os.makedirs(self.args.workspace, exist_ok=True)
|
|
|
|
# Set up a semaphore for server tasks
|
|
self.semaphore = asyncio.Semaphore(1)
|
|
self.maxDiff = None
|
|
|
|
# # Start the sglang server
|
|
# self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
|
|
|
|
# # Wait for the server to become ready
|
|
# await sglang_server_ready()
|
|
|
|
async def test_sglang_server_initialization_and_request(self):
|
|
# Mock data paths
|
|
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
|
|
|
|
# Send a single request to the sglang server for page 1
|
|
async with AsyncClient(timeout=600) as session:
|
|
query = await build_page_query(
|
|
str(self.test_pdf_path),
|
|
page=1,
|
|
target_longest_image_dim=self.args.target_longest_image_dim,
|
|
target_anchor_text_len=self.args.target_anchor_text_len,
|
|
)
|
|
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
|
|
|
|
query["temperature"] = 0.0
|
|
query["logprobs"] = True
|
|
query["top_logprobs"] = 5
|
|
response = await session.post(COMPLETION_URL, json=query)
|
|
|
|
print(response.text)
|
|
|
|
# Check the server response
|
|
self.assertEqual(response.status_code, 200)
|
|
response_data = response.json()
|
|
self.assertIn("choices", response_data)
|
|
self.assertGreater(len(response_data["choices"]), 0)
|
|
|
|
|
|
|
|
model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
|
|
page_response = PageResponse(**model_response_json)
|
|
|
|
print(page_response)
|
|
|
|
self.assertEqual(page_response.natural_text, EDGAR_TEXT)
|
|
|
|
|
|
async def asyncTearDown(self):
|
|
pass
|
|
# # Shut down the server
|
|
# self.my_server_task.cancel()
|
|
# with self.assertRaises(asyncio.CancelledError):
|
|
# await self.my_server_task
|
|
|
|
# # Cleanup temporary workspace
|
|
# if os.path.exists(self.args.workspace):
|
|
# for root, _, files in os.walk(self.args.workspace):
|
|
# for file in files:
|
|
# os.unlink(os.path.join(root, file))
|
|
# os.rmdir(self.args.workspace)
|
|
|
|
|
|
class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
|
|
async def asyncSetUp(self):
|
|
# Set up the Hugging Face model and tokenizer
|
|
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
|
|
download_directory([MODEL_FINETUNED_PATH], model_cache_dir)
|
|
|
|
# Check the rope config and make sure it's got the proper key
|
|
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
|
|
config_data = json.load(cfin)
|
|
|
|
if "rope_type" in config_data["rope_scaling"]:
|
|
del config_data["rope_scaling"]["rope_type"]
|
|
config_data["rope_scaling"]["type"] = "mrope"
|
|
|
|
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
|
|
json.dump(config_data, cfout)
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
|
|
self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
|
|
|
|
|
|
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
|
|
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
self.model.to(self.device)
|
|
|
|
# Path to the test PDF
|
|
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
|
|
self.maxDiff = None
|
|
|
|
async def test_hugging_face_generation(self):
|
|
query = await build_page_query(
|
|
str(self.test_pdf_path),
|
|
page=1,
|
|
target_longest_image_dim=1024,
|
|
target_anchor_text_len=6000,
|
|
)
|
|
|
|
messages = query["messages"]
|
|
|
|
# Apply chat template to get the text
|
|
text = self.processor.apply_chat_template(
|
|
query["messages"], tokenize=False, add_generation_prompt=True
|
|
)
|
|
|
|
image_url = query["messages"][0]["content"][1]["image_url"]["url"]
|
|
|
|
# Remove the "data:image/png;base64," prefix
|
|
base64_image = image_url.split(",")[1]
|
|
|
|
# Decode the base64 string into bytes
|
|
image_data = base64.b64decode(base64_image)
|
|
|
|
# Create a BytesIO object and load it into a PIL image
|
|
main_image = Image.open(BytesIO(image_data))
|
|
|
|
# Process inputs using processor
|
|
inputs = self.processor(
|
|
text=[text],
|
|
images=[main_image],
|
|
padding=True,
|
|
return_tensors="pt",
|
|
)
|
|
|
|
image_indices = [
|
|
idx
|
|
for idx, token in enumerate(inputs["input_ids"][0])
|
|
if token.item() == self.image_token_id
|
|
]
|
|
|
|
print("IMAGE INDICES", image_indices)
|
|
|
|
print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}")
|
|
print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}")
|
|
np.save('/root/pixel_values.npy', inputs['pixel_values'].detach().cpu().numpy())
|
|
|
|
inputs = {key: value.to(self.device) for (key, value) in inputs.items()}
|
|
|
|
generated_tokens = []
|
|
max_steps = 50
|
|
|
|
top_logprobs_hf = []
|
|
|
|
for step in range(max_steps):
|
|
# Generate the output with temperature=0
|
|
generation_output = self.model.generate(
|
|
**inputs,
|
|
temperature=0.0,
|
|
max_new_tokens=1,
|
|
#max_length=8192,
|
|
num_return_sequences=1,
|
|
do_sample=False,
|
|
output_scores=True,
|
|
return_dict_in_generate=True,
|
|
)
|
|
|
|
|
|
# Extract the generated token's log probabilities
|
|
scores = generation_output.scores # Tuple of length 1
|
|
logits = scores[0] # Tensor of shape (batch_size, vocab_size)
|
|
log_probs = F.log_softmax(logits, dim=-1) # Apply log softmax to get log probabilities
|
|
|
|
# Get top 5 tokens and their log probabilities
|
|
topk_log_probs, topk_indices = torch.topk(log_probs[0], k=5)
|
|
topk_tokens = self.tokenizer.convert_ids_to_tokens(topk_indices.tolist())
|
|
|
|
|
|
top_logprobs_hf.append((topk_tokens, topk_log_probs.tolist()))
|
|
|
|
# Pick the top token
|
|
next_token_id = topk_indices[0].unsqueeze(0).unsqueeze(0) # Shape: (1, 1)
|
|
next_token_str = self.tokenizer.convert_ids_to_tokens([next_token_id.item()])[0]
|
|
|
|
generated_tokens.append(next_token_id.item())
|
|
|
|
# Append the next token to input_ids and update attention_mask
|
|
inputs['input_ids'] = torch.cat([inputs['input_ids'], next_token_id], dim=-1)
|
|
inputs['attention_mask'] = torch.cat(
|
|
[inputs['attention_mask'], torch.ones((1, 1), dtype=inputs['attention_mask'].dtype).to(self.device)], dim=-1
|
|
)
|
|
|
|
print(self.tokenizer.decode(generated_tokens))
|
|
|
|
# Now take all the input ids and run them through sglang as a comparison
|
|
async with AsyncClient(timeout=600) as session:
|
|
query["temperature"] = 0.0
|
|
query["max_tokens"] = max_steps
|
|
query["logprobs"] = True
|
|
query["top_logprobs"] = 5
|
|
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
|
|
response = await session.post(COMPLETION_URL, json=query)
|
|
|
|
response_data = response.json()
|
|
|
|
for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
|
|
print("\nTop 5 tokens and their log probabilities:")
|
|
(topk_tokens, topk_log_probs) = top_logprobs_hf[step]
|
|
for token, log_prob, lptokcur in zip(topk_tokens, topk_log_probs, lptok["top_logprobs"]):
|
|
print(f"HF Token: {token} Log Prob: {log_prob:.2f} Prob {math.exp(log_prob)*100:.2f}% SGLANG Token {lptokcur['token']} Logprob {lptokcur['logprob']:.2f} Prob {math.exp(lptokcur['logprob'])*100:.2f}%")
|
|
|
|
async def asyncTearDown(self):
|
|
# Clean up the model and tokenizer
|
|
del self.model
|
|
del self.tokenizer
|
|
torch.cuda.empty_cache()
|
|
|
|
class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
|
|
def setUp(self):
|
|
# Set up the Hugging Face model and tokenizer
|
|
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
|
|
download_directory([MODEL_FINETUNED_PATH], model_cache_dir)
|
|
|
|
# Check the rope config and make sure it's got the proper key
|
|
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
|
|
config_data = json.load(cfin)
|
|
|
|
if "rope_type" in config_data["rope_scaling"]:
|
|
del config_data["rope_scaling"]["rope_type"]
|
|
config_data["rope_scaling"]["type"] = "mrope"
|
|
|
|
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
|
|
json.dump(config_data, cfout)
|
|
|
|
self.model_cache_dir = model_cache_dir
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
|
|
self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
|
|
|
|
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
|
|
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
self.model.to(self.device)
|
|
|
|
# Path to the test PDF
|
|
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
|
|
self.maxDiff = None
|
|
|
|
async def test_vision_encoder(self):
|
|
query = await build_page_query(
|
|
str(self.test_pdf_path),
|
|
page=1,
|
|
target_longest_image_dim=1024,
|
|
target_anchor_text_len=6000,
|
|
)
|
|
|
|
messages = query["messages"]
|
|
|
|
# Apply chat template to get the text
|
|
text = self.processor.apply_chat_template(
|
|
query["messages"], tokenize=False, add_generation_prompt=True
|
|
)
|
|
|
|
image_url = query["messages"][0]["content"][1]["image_url"]["url"]
|
|
|
|
# Remove the "data:image/png;base64," prefix
|
|
base64_image = image_url.split(",")[1]
|
|
|
|
# Decode the base64 string into bytes
|
|
image_data = base64.b64decode(base64_image)
|
|
|
|
# Create a BytesIO object and load it into a PIL image
|
|
main_image = Image.open(BytesIO(image_data))
|
|
|
|
# Process inputs using processor
|
|
inputs = self.processor(
|
|
text=[text],
|
|
images=[main_image],
|
|
padding=True,
|
|
return_tensors="pt",
|
|
)
|
|
|
|
with torch.no_grad():
|
|
hf_output = self.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))
|
|
|
|
|
|
print("HF", hf_output, hf_output.shape)
|
|
|
|
from sglang.srt.configs.model_config import ModelConfig
|
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
from sglang.srt.server_args import ServerArgs, PortArgs
|
|
|
|
model_config = ModelConfig(
|
|
self.model_cache_dir,
|
|
model_override_args="{}"
|
|
)
|
|
|
|
server_args = ServerArgs(model_path=self.model_cache_dir)
|
|
# Initialize model runner
|
|
model_runner = ModelRunner(
|
|
model_config=model_config,
|
|
mem_fraction_static=0.8,
|
|
gpu_id=0,
|
|
tp_rank=0,
|
|
tp_size=1,
|
|
nccl_port=12435,
|
|
server_args=server_args,
|
|
)
|
|
|
|
print(model_runner)
|
|
with torch.no_grad():
|
|
sglang_output = model_runner.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))
|
|
|
|
print("SGLANG", sglang_output, sglang_output.shape)
|
|
|
|
# Convert to float32 for numerical stability if needed
|
|
hf = hf_output.float()
|
|
sg = sglang_output.float()
|
|
|
|
# Basic shape and dtype comparison
|
|
print("\n=== Basic Properties ===")
|
|
print(f"Shapes match: {hf.shape == sg.shape}")
|
|
print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
|
|
print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
|
|
|
|
# Move tensors to CPU for numpy operations
|
|
hf_np = hf.cpu().numpy()
|
|
sg_np = sg.cpu().numpy()
|
|
|
|
# Statistical metrics
|
|
print("\n=== Statistical Metrics ===")
|
|
print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
|
|
print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
|
|
print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
|
|
print(f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}")
|
|
|
|
# Cosine similarity (across feature dimension)
|
|
cos_sim = F.cosine_similarity(hf, sg)
|
|
print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
|
|
print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
|
|
|
|
# Find largest absolute differences
|
|
print("\n=== Largest Absolute Differences ===")
|
|
diffs = torch.abs(hf - sg)
|
|
flat_diffs = diffs.flatten()
|
|
|
|
# Get indices of top 10 differences
|
|
top_k = 10
|
|
top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
|
|
|
|
# Convert flat indices to multidimensional indices
|
|
top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
|
|
|
|
print(f"\nTop {top_k} largest absolute differences:")
|
|
print("Index".ljust(30) + "Difference".ljust(15) + "HF Value".ljust(15) + "SGLang Value")
|
|
print("-" * 75)
|
|
|
|
for i in range(top_k):
|
|
# Get the index tuple for this difference
|
|
idx = tuple(dim[i] for dim in top_indices)
|
|
diff_val = top_values[i].item()
|
|
hf_val = hf[idx].item()
|
|
sg_val = sg[idx].item()
|
|
|
|
# Format the index tuple and values
|
|
idx_str = str(idx)
|
|
print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}") |