# The idea is that you have a Qwen2-VL-7B model located here:
# s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/
#
# You need to load it in both Hugging Face transformers and sglang, send page 1 of edgar.pdf from
# tests/gnarly_pdfs to each, and compare that the temperature 0 sampled results are the same.

import asyncio
import unittest
from unittest.mock import patch, AsyncMock
import os
import json
import tempfile
import math
import base64
import torch
import torch.nn.functional as F
import numpy as np
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
from pathlib import Path
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
from pdelfin.prompts import PageResponse
from httpx import AsyncClient

MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"


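# Note: the sglang server startup in asyncSetUp below is commented out, so these tests assume an
# sglang instance is already serving the finetuned checkpoint on localhost:30000, and that the S3
# checkpoint, the test PDFs under tests/gnarly_pdfs, and a CUDA device are available locally.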
class TestSglangServer(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Mock arguments
        self.args = AsyncMock()
        self.args.workspace = "/tmp/test_workspace"
        self.args.model = [MODEL_FINETUNED_PATH]
        self.args.model_chat_template = "qwen2-vl"
        self.args.target_longest_image_dim = 1024
        self.args.target_anchor_text_len = 6000
        self.args.model_max_context = 8192

        # Create a temporary workspace directory
        os.makedirs(self.args.workspace, exist_ok=True)

        # Set up a semaphore for server tasks
        self.semaphore = asyncio.Semaphore(1)
        self.maxDiff = None

        # # Start the sglang server
        # self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))

        # # Wait for the server to become ready
        # await sglang_server_ready()

    async def test_sglang_server_initialization_and_request(self):
        # Mock data paths
        self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))

        # Send a single request to the sglang server for page 1
        async with AsyncClient(timeout=600) as session:
            query = await build_page_query(
                str(self.test_pdf_path),
                page=1,
                target_longest_image_dim=self.args.target_longest_image_dim,
                target_anchor_text_len=self.args.target_anchor_text_len,
            )
            COMPLETION_URL = "http://localhost:30000/v1/chat/completions"

            query["temperature"] = 0.0
            query["logprobs"] = True
            query["top_logprobs"] = 5
            response = await session.post(COMPLETION_URL, json=query)

        print(response.text)

        # Check the server response
        self.assertEqual(response.status_code, 200)
        response_data = response.json()
        self.assertIn("choices", response_data)
        self.assertGreater(len(response_data["choices"]), 0)

        model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
        page_response = PageResponse(**model_response_json)

        print(page_response)

        # EDGAR_TEXT is the expected transcription for the test page; it is assumed to be defined
        # at module level (its value is not included in this file).
        self.assertEqual(page_response.natural_text, EDGAR_TEXT)

    async def asyncTearDown(self):
        pass
        # # Shut down the server
        # self.my_server_task.cancel()
        # with self.assertRaises(asyncio.CancelledError):
        #     await self.my_server_task

        # # Cleanup temporary workspace
        # if os.path.exists(self.args.workspace):
        #     for root, _, files in os.walk(self.args.workspace):
        #         for file in files:
        #             os.unlink(os.path.join(root, file))
        #     os.rmdir(self.args.workspace)


class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Set up the Hugging Face model and tokenizer
        model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
        download_directory([MODEL_FINETUNED_PATH], model_cache_dir)

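        # Presumably the checkpoint was saved by a transformers version that writes the rope
        # scaling config under a "rope_type" key, which the Qwen2-VL model code used here does not
        # accept, so it is rewritten back to the older {"type": "mrope"} form below.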
        # Check the rope config and make sure it's got the proper key
        with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
            config_data = json.load(cfin)

        if "rope_type" in config_data["rope_scaling"]:
            del config_data["rope_scaling"]["rope_type"]
            config_data["rope_scaling"]["type"] = "mrope"

            with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
                json.dump(config_data, cfout)

        self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
        self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Path to the test PDF
        self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
        self.maxDiff = None

    async def test_hugging_face_generation(self):
        query = await build_page_query(
            str(self.test_pdf_path),
            page=1,
            target_longest_image_dim=1024,
            target_anchor_text_len=6000,
        )

        messages = query["messages"]

        # Apply chat template to get the text
        text = self.processor.apply_chat_template(
            query["messages"], tokenize=False, add_generation_prompt=True
        )

        image_url = query["messages"][0]["content"][1]["image_url"]["url"]

        # Remove the "data:image/png;base64," prefix
        base64_image = image_url.split(",")[1]

        # Decode the base64 string into bytes
        image_data = base64.b64decode(base64_image)

        # Create a BytesIO object and load it into a PIL image
        main_image = Image.open(BytesIO(image_data))

        # Process inputs using processor
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )

        image_indices = [
            idx
            for idx, token in enumerate(inputs["input_ids"][0])
            if token.item() == self.image_token_id
        ]

        print("IMAGE INDICES", image_indices)

        print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}")
        print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}")
        np.save('/root/pixel_values.npy', inputs['pixel_values'].detach().cpu().numpy())

        inputs = {key: value.to(self.device) for (key, value) in inputs.items()}

        generated_tokens = []
        max_steps = 50

        top_logprobs_hf = []

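        # Decode greedily one token at a time (max_new_tokens=1) so the top-5 log probabilities at
        # every step can be recorded and later compared against the logprobs reported by sglang.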
        for step in range(max_steps):
            # Generate the output with temperature=0
            generation_output = self.model.generate(
                **inputs,
                temperature=0.0,
                max_new_tokens=1,
                # max_length=8192,
                num_return_sequences=1,
                do_sample=False,
                output_scores=True,
                return_dict_in_generate=True,
            )

            # Extract the generated token's log probabilities
            scores = generation_output.scores  # Tuple of length 1
            logits = scores[0]  # Tensor of shape (batch_size, vocab_size)
            log_probs = F.log_softmax(logits, dim=-1)  # Apply log softmax to get log probabilities

            # Get top 5 tokens and their log probabilities
            topk_log_probs, topk_indices = torch.topk(log_probs[0], k=5)
            topk_tokens = self.tokenizer.convert_ids_to_tokens(topk_indices.tolist())

            top_logprobs_hf.append((topk_tokens, topk_log_probs.tolist()))

            # Pick the top token
            next_token_id = topk_indices[0].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1)
            next_token_str = self.tokenizer.convert_ids_to_tokens([next_token_id.item()])[0]

            generated_tokens.append(next_token_id.item())

            # Append the next token to input_ids and update attention_mask
            inputs['input_ids'] = torch.cat([inputs['input_ids'], next_token_id], dim=-1)
            inputs['attention_mask'] = torch.cat(
                [inputs['attention_mask'], torch.ones((1, 1), dtype=inputs['attention_mask'].dtype).to(self.device)], dim=-1
            )

        print(self.tokenizer.decode(generated_tokens))

        # Now take all the input ids and run them through sglang as a comparison
        async with AsyncClient(timeout=600) as session:
            query["temperature"] = 0.0
            query["max_tokens"] = max_steps
            query["logprobs"] = True
            query["top_logprobs"] = 5
            COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
            response = await session.post(COMPLETION_URL, json=query)

            response_data = response.json()

            for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
                print("\nTop 5 tokens and their log probabilities:")
                (topk_tokens, topk_log_probs) = top_logprobs_hf[step]
                for token, log_prob, lptokcur in zip(topk_tokens, topk_log_probs, lptok["top_logprobs"]):
                    print(f"HF Token: {token} Log Prob: {log_prob:.2f} Prob {math.exp(log_prob)*100:.2f}%  SGLANG Token {lptokcur['token']} Logprob {lptokcur['logprob']:.2f} Prob {math.exp(lptokcur['logprob'])*100:.2f}%")
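
    # Hypothetical helper (a sketch, not part of the original test): one way to turn the printed
    # comparison above into an assertion is to check that the per-step top-1 log probabilities
    # from Hugging Face and sglang stay within a tolerance, assuming sglang returns top_logprobs
    # sorted by probability. Token strings are deliberately not compared, since the two stacks may
    # render them differently (byte-level BPE pieces vs. decoded text).
    @staticmethod
    def _assert_top1_logprobs_close(top_logprobs_hf, sglang_content, tol=0.15):
        # `tol` is an assumed tolerance and would need tuning for bf16 inference.
        for step, ((hf_tokens, hf_logprobs), sg_entry) in enumerate(zip(top_logprobs_hf, sglang_content)):
            hf_top1 = hf_logprobs[0]
            sg_top1 = sg_entry["top_logprobs"][0]["logprob"]
            assert abs(hf_top1 - sg_top1) < tol, (
                f"Step {step}: HF top-1 logprob {hf_top1:.3f} vs sglang {sg_top1:.3f} differ by more than {tol}"
            )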

    async def asyncTearDown(self):
        # Clean up the model and tokenizer
        del self.model
        del self.tokenizer
        torch.cuda.empty_cache()


class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        # Set up the Hugging Face model and tokenizer
        model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
        download_directory([MODEL_FINETUNED_PATH], model_cache_dir)

        # Check the rope config and make sure it's got the proper key
        with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
            config_data = json.load(cfin)

        if "rope_type" in config_data["rope_scaling"]:
            del config_data["rope_scaling"]["rope_type"]
            config_data["rope_scaling"]["type"] = "mrope"

            with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
                json.dump(config_data, cfout)

        self.model_cache_dir = model_cache_dir

        self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
        self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Path to the test PDF
        self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
        self.maxDiff = None

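    # Compare the vision tower's output between the Hugging Face model and an sglang ModelRunner
    # loaded from the same local checkpoint, on identical pixel values for page 1 of the test PDF.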
    async def test_vision_encoder(self):
        query = await build_page_query(
            str(self.test_pdf_path),
            page=1,
            target_longest_image_dim=1024,
            target_anchor_text_len=6000,
        )

        messages = query["messages"]

        # Apply chat template to get the text
        text = self.processor.apply_chat_template(
            query["messages"], tokenize=False, add_generation_prompt=True
        )

        image_url = query["messages"][0]["content"][1]["image_url"]["url"]

        # Remove the "data:image/png;base64," prefix
        base64_image = image_url.split(",")[1]

        # Decode the base64 string into bytes
        image_data = base64.b64decode(base64_image)

        # Create a BytesIO object and load it into a PIL image
        main_image = Image.open(BytesIO(image_data))

        # Process inputs using processor
        inputs = self.processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            hf_output = self.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))

        print("HF", hf_output, hf_output.shape)

        from sglang.srt.configs.model_config import ModelConfig
        from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
        from sglang.srt.model_executor.forward_batch_info import ForwardBatch
        from sglang.srt.model_executor.model_runner import ModelRunner
        from sglang.srt.sampling.sampling_params import SamplingParams
        from sglang.srt.hf_transformers_utils import get_tokenizer
        from sglang.srt.server_args import ServerArgs, PortArgs

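        # Build a minimal sglang ModelRunner over the same local checkpoint so its vision tower can
        # be called directly and compared against the Hugging Face output above.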
        model_config = ModelConfig(
            self.model_cache_dir,
            model_override_args="{}"
        )

        server_args = ServerArgs(model_path=self.model_cache_dir)
        # Initialize model runner
        model_runner = ModelRunner(
            model_config=model_config,
            mem_fraction_static=0.8,
            gpu_id=0,
            tp_rank=0,
            tp_size=1,
            nccl_port=12435,
            server_args=server_args,
        )

        print(model_runner)
        with torch.no_grad():
            sglang_output = model_runner.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))

        print("SGLANG", sglang_output, sglang_output.shape)

        # Convert to float32 for numerical stability if needed
        hf = hf_output.float()
        sg = sglang_output.float()

        # Basic shape and dtype comparison
        print("\n=== Basic Properties ===")
        print(f"Shapes match: {hf.shape == sg.shape}")
        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")

        # Move tensors to CPU for numpy operations
        hf_np = hf.cpu().numpy()
        sg_np = sg.cpu().numpy()

        # Statistical metrics
        print("\n=== Statistical Metrics ===")
        print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
        print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
        print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
        print(f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}")

        # Cosine similarity (across feature dimension)
        cos_sim = F.cosine_similarity(hf, sg)
        print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
        print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")

        # Find largest absolute differences
        print("\n=== Largest Absolute Differences ===")
        diffs = torch.abs(hf - sg)
        flat_diffs = diffs.flatten()

        # Get indices of top 10 differences
        top_k = 10
        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)

        # Convert flat indices to multidimensional indices
        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)

        print(f"\nTop {top_k} largest absolute differences:")
        print("Index".ljust(30) + "Difference".ljust(15) + "HF Value".ljust(15) + "SGLang Value")
        print("-" * 75)

        for i in range(top_k):
            # Get the index tuple for this difference
            idx = tuple(dim[i] for dim in top_indices)
            diff_val = top_values[i].item()
            hf_val = hf[idx].item()
            sg_val = sg[idx].item()

            # Format the index tuple and values
            idx_str = str(idx)
            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
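
        # A possible hardening step (not in the original test, and the tolerances are guesses):
        # replace the printed diagnostics above with a hard check, e.g.
        # torch.testing.assert_close(hf, sg, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    unittest.main()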
