2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
# The idea is that you have a Qwen2-VL-7B model located at s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
# Load it in both Hugging Face transformers and sglang, then send page 1 of edgar.pdf (from tests/gnarly_pdfs) to each.
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
# Check that the temperature-0 sampled results from the two backends are identical.
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  asyncio  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  unittest  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  unittest . mock  import  patch ,  AsyncMock  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  os  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  json  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  tempfile  
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  math  
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  base64  
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
									
										
										
										
											2024-11-26 08:38:25 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  numpy  as  np  
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  io  import  BytesIO  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  PIL  import  Image  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  transformers  import  AutoProcessor ,  AutoTokenizer ,  Qwen2VLForConditionalGeneration  
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								from  pathlib  import  Path  
						 
					
						
							
								
									
										
										
										
											2025-01-27 18:48:15 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  olmocr . pipeline  import  sglang_server_task ,  sglang_server_ready ,  build_page_query ,  SGLANG_SERVER_PORT ,  render_pdf_to_base64png ,  get_anchor_text ,  download_directory  
						 
					
						
							
								
									
										
										
										
											2025-01-27 18:30:41 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  olmocr . prompts  import  PageResponse  
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								from  httpx  import  AsyncClient  
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  torch . nn . functional  as  F  
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								MODEL_FINETUNED_PATH  =  " s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/ "  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TestSglangServer ( unittest . IsolatedAsyncioTestCase ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    async  def  asyncSetUp ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Mock arguments 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . args  =  AsyncMock ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . args . workspace  =  " /tmp/test_workspace " 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . args . model  =  [ MODEL_FINETUNED_PATH ] 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        self . args . model_chat_template  =  " qwen2-vl " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . args . target_longest_image_dim  =  1024 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . args . target_anchor_text_len  =  6000 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . args . model_max_context  =  8192 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Create a temporary workspace directory 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        os . makedirs ( self . args . workspace ,  exist_ok = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Set up a semaphore for server tasks 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . semaphore  =  asyncio . Semaphore ( 1 ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . maxDiff  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # # Start the sglang server 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore)) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # # Wait for the server to become ready 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # await sglang_server_ready() 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:48:05 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    async  def  test_sglang_server_initialization_and_request ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:48:05 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # Mock data paths 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . test_pdf_path  =  Path ( os . path . join ( os . path . dirname ( __file__ ) ,  " gnarly_pdfs " ,  " ambiguous.pdf " ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:48:05 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        # Send a single request to the sglang server for page 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        async  with  AsyncClient ( timeout = 600 )  as  session : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query  =  await  build_page_query ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                str ( self . test_pdf_path ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                page = 1 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_longest_image_dim = self . args . target_longest_image_dim , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_anchor_text_len = self . args . target_anchor_text_len , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            COMPLETION_URL  =  f " http://localhost: { 30000 } /v1/chat/completions " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " temperature " ]  =  0.0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " logprobs " ]  =  True 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            query [ " top_logprobs " ]  =  5 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            response  =  await  session . post ( COMPLETION_URL ,  json = query ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( response . text ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:34:59 -08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								        # Check the server response 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . assertEqual ( response . status_code ,  200 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        response_data  =  response . json ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . assertIn ( " choices " ,  response_data ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . assertGreater ( len ( response_data [ " choices " ] ) ,  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:39:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        model_response_json  =  json . loads ( response_data [ " choices " ] [ 0 ] [ " message " ] [ " content " ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        page_response  =  PageResponse ( * * model_response_json ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( page_response ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . assertEqual ( page_response . natural_text ,  EDGAR_TEXT ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 09:48:05 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    async  def  asyncTearDown ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:24:21 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        pass 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # # Shut down the server 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # self.my_server_task.cancel() 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # with self.assertRaises(asyncio.CancelledError): 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #     await self.my_server_task 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # # Cleanup temporary workspace 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # if os.path.exists(self.args.workspace): 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #     for root, _, files in os.walk(self.args.workspace): 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #         for file in files: 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #             os.unlink(os.path.join(root, file)) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        #     os.rmdir(self.args.workspace) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  TestHuggingFaceModel ( unittest . IsolatedAsyncioTestCase ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    async  def  asyncSetUp ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Set up the Hugging Face model and tokenizer 
							 
						 
					
						
							
								
									
										
										
										
											2025-01-27 18:30:41 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        model_cache_dir  =  os . path . join ( os . path . expanduser ( ' ~ ' ) ,  ' .cache ' ,  ' olmocr ' ,  ' model ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        download_directory ( [ MODEL_FINETUNED_PATH ] ,  model_cache_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Check the rope config and make sure it's got the proper key 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        with  open ( os . path . join ( model_cache_dir ,  " config.json " ) ,  " r " )  as  cfin : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            config_data  =  json . load ( cfin ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  " rope_type "  in  config_data [ " rope_scaling " ] : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            del  config_data [ " rope_scaling " ] [ " rope_type " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            config_data [ " rope_scaling " ] [ " type " ]  =  " mrope " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            with  open ( os . path . join ( model_cache_dir ,  " config.json " ) ,  " w " )  as  cfout : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                json . dump ( config_data ,  cfout ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . tokenizer  =  AutoTokenizer . from_pretrained ( model_cache_dir ,  trust_remote_code = True ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-03 15:32:53 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . image_token_id  =  self . tokenizer . encode ( " <|image_pad|> " ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . model  =  Qwen2VLForConditionalGeneration . from_pretrained ( model_cache_dir ,  torch_dtype = torch . bfloat16 ,  trust_remote_code = True ) . eval ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . processor  =  AutoProcessor . from_pretrained ( " Qwen/Qwen2-VL-7B-Instruct " ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . device  =  torch . device ( " cuda "  if  torch . cuda . is_available ( )  else  " cpu " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . model . to ( self . device ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Path to the test PDF 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . test_pdf_path  =  Path ( os . path . join ( os . path . dirname ( __file__ ) ,  " gnarly_pdfs " ,  " ambiguous.pdf " ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . maxDiff  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    async  def  test_hugging_face_generation ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        query  =  await  build_page_query ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                str ( self . test_pdf_path ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                page = 1 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_longest_image_dim = 1024 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_anchor_text_len = 6000 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-03 15:32:53 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        messages  =  query [ " messages " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								       # Apply chat template to get the text 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        text  =  self . processor . apply_chat_template ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " messages " ] ,  tokenize = False ,  add_generation_prompt = True 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        image_url  =  query [ " messages " ] [ 0 ] [ " content " ] [ 1 ] [ " image_url " ] [ " url " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Remove the "data:image/png;base64," prefix 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        base64_image  =  image_url . split ( " , " ) [ 1 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Decode the base64 string into bytes 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        image_data  =  base64 . b64decode ( base64_image ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Create a BytesIO object and load it into a PIL image 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        main_image  =  Image . open ( BytesIO ( image_data ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Process inputs using processor 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        inputs  =  self . processor ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            text = [ text ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            images = [ main_image ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            padding = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return_tensors = " pt " , 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-03 15:32:53 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        image_indices  =  [ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            idx 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            for  idx ,  token  in  enumerate ( inputs [ " input_ids " ] [ 0 ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  token . item ( )  ==  self . image_token_id 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " IMAGE INDICES " ,  image_indices ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-26 08:38:25 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( f " image_grid_thw -  { inputs [ ' image_grid_thw ' ] . shape }   { inputs [ ' image_grid_thw ' ] } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " pixel_values -  { inputs [ ' pixel_values ' ] . shape }   { inputs [ ' pixel_values ' ] . detach ( ) . cpu ( ) . numpy ( ) } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        np . save ( ' /root/pixel_values.npy ' ,  inputs [ ' pixel_values ' ] . detach ( ) . cpu ( ) . numpy ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:25:55 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        inputs  =  { key :  value . to ( self . device )  for  ( key ,  value )  in  inputs . items ( ) } 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        generated_tokens  =  [ ] 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        max_steps  =  50 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        top_logprobs_hf  =  [ ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  step  in  range ( max_steps ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # Generate the output with temperature=0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            generation_output  =  self . model . generate ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                * * inputs , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                temperature = 0.0 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                max_new_tokens = 1 , 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-26 08:38:25 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                #max_length=8192, 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                num_return_sequences = 1 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                do_sample = False , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                output_scores = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                return_dict_in_generate = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            # Extract the generated token's log probabilities 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            scores  =  generation_output . scores   # Tuple of length 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            logits  =  scores [ 0 ]   # Tensor of shape (batch_size, vocab_size) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            log_probs  =  F . log_softmax ( logits ,  dim = - 1 )   # Apply log softmax to get log probabilities 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            # Get top 5 tokens and their log probabilities 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            topk_log_probs ,  topk_indices  =  torch . topk ( log_probs [ 0 ] ,  k = 5 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            topk_tokens  =  self . tokenizer . convert_ids_to_tokens ( topk_indices . tolist ( ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            top_logprobs_hf . append ( ( topk_tokens ,  topk_log_probs . tolist ( ) ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            # Pick the top token 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            next_token_id  =  topk_indices [ 0 ] . unsqueeze ( 0 ) . unsqueeze ( 0 )   # Shape: (1, 1) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            next_token_str  =  self . tokenizer . convert_ids_to_tokens ( [ next_token_id . item ( ) ] ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            generated_tokens . append ( next_token_id . item ( ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # Append the next token to input_ids and update attention_mask 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            inputs [ ' input_ids ' ]  =  torch . cat ( [ inputs [ ' input_ids ' ] ,  next_token_id ] ,  dim = - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            inputs [ ' attention_mask ' ]  =  torch . cat ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                [ inputs [ ' attention_mask ' ] ,  torch . ones ( ( 1 ,  1 ) ,  dtype = inputs [ ' attention_mask ' ] . dtype ) . to ( self . device ) ] ,  dim = - 1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 11:00:03 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-03 15:32:53 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( self . tokenizer . decode ( generated_tokens ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # Now take all the input ids and run them through sglang as a comparison 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        async  with  AsyncClient ( timeout = 600 )  as  session : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " temperature " ]  =  0.0 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " max_tokens " ]  =  max_steps  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " logprobs " ]  =  True 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " top_logprobs " ]  =  5 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            COMPLETION_URL  =  f " http://localhost: { 30000 } /v1/chat/completions " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            response  =  await  session . post ( COMPLETION_URL ,  json = query ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            response_data  =  response . json ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            for  step ,  lptok  in  enumerate ( response_data [ " choices " ] [ 0 ] [ " logprobs " ] [ " content " ] ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 16:08:24 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                print ( " \n Top 5 tokens and their log probabilities: " ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                ( topk_tokens ,  topk_log_probs )  =  top_logprobs_hf [ step ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                for  token ,  log_prob ,  lptokcur  in  zip ( topk_tokens ,  topk_log_probs ,  lptok [ " top_logprobs " ] ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 16:08:24 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    print ( f " HF Token:  { token }  Log Prob:  { log_prob : .2f }  Prob  { math . exp ( log_prob ) * 100 : .2f } %  SGLANG Token  { lptokcur [ ' token ' ] }  Logprob  { lptokcur [ ' logprob ' ] : .2f }  Prob  { math . exp ( lptokcur [ ' logprob ' ] ) * 100 : .2f } % " ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:32:18 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 10:12:29 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    async  def  asyncTearDown ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Clean up the model and tokenizer 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        del  self . model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        del  self . tokenizer 
							 
						 
					
						
							
								
									
										
										
										
											2024-11-25 15:36:04 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        torch . cuda . empty_cache ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								class  RawSGLangTest ( unittest . IsolatedAsyncioTestCase ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  setUp ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Set up the Hugging Face model and tokenizer 
							 
						 
					
						
							
								
									
										
										
										
											2025-01-27 18:30:41 +00:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        model_cache_dir  =  os . path . join ( os . path . expanduser ( ' ~ ' ) ,  ' .cache ' ,  ' olmocr ' ,  ' model ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        download_directory ( [ MODEL_FINETUNED_PATH ] ,  model_cache_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Check the rope config and make sure it's got the proper key 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        with  open ( os . path . join ( model_cache_dir ,  " config.json " ) ,  " r " )  as  cfin : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            config_data  =  json . load ( cfin ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  " rope_type "  in  config_data [ " rope_scaling " ] : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            del  config_data [ " rope_scaling " ] [ " rope_type " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            config_data [ " rope_scaling " ] [ " type " ]  =  " mrope " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            with  open ( os . path . join ( model_cache_dir ,  " config.json " ) ,  " w " )  as  cfout : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                json . dump ( config_data ,  cfout ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . model_cache_dir  =  model_cache_dir 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . tokenizer  =  AutoTokenizer . from_pretrained ( model_cache_dir ,  trust_remote_code = True ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . image_token_id  =  self . tokenizer . encode ( " <|image_pad|> " ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . model  =  Qwen2VLForConditionalGeneration . from_pretrained ( model_cache_dir ,  torch_dtype = torch . bfloat16 ,  trust_remote_code = True ) . eval ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . processor  =  AutoProcessor . from_pretrained ( " Qwen/Qwen2-VL-7B-Instruct " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . device  =  torch . device ( " cuda "  if  torch . cuda . is_available ( )  else  " cpu " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . model . to ( self . device ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Path to the test PDF 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . test_pdf_path  =  Path ( os . path . join ( os . path . dirname ( __file__ ) ,  " gnarly_pdfs " ,  " ambiguous.pdf " ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . maxDiff  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    async  def  test_vision_encoder ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        query  =  await  build_page_query ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                str ( self . test_pdf_path ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                page = 1 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_longest_image_dim = 1024 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                target_anchor_text_len = 6000 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        messages  =  query [ " messages " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Apply chat template to get the text 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        text  =  self . processor . apply_chat_template ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query [ " messages " ] ,  tokenize = False ,  add_generation_prompt = True 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        image_url  =  query [ " messages " ] [ 0 ] [ " content " ] [ 1 ] [ " image_url " ] [ " url " ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Remove the "data:image/png;base64," prefix 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        base64_image  =  image_url . split ( " , " ) [ 1 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Decode the base64 string into bytes 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        image_data  =  base64 . b64decode ( base64_image ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Create a BytesIO object and load it into a PIL image 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        main_image  =  Image . open ( BytesIO ( image_data ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Process inputs using processor 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        inputs  =  self . processor ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            text = [ text ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            images = [ main_image ] , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            padding = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return_tensors = " pt " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        with  torch . no_grad ( ) : 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:27:51 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            hf_output  =  self . model . visual ( inputs [ " pixel_values " ] . to ( self . device ) ,  grid_thw = inputs [ " image_grid_thw " ] . to ( self . device ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:27:51 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( " HF " ,  hf_output ,  hf_output . shape ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . configs . model_config  import  ModelConfig 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . managers . schedule_batch  import  Req ,  ScheduleBatch 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . model_executor . forward_batch_info  import  ForwardBatch 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . model_executor . model_runner  import  ModelRunner 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . sampling . sampling_params  import  SamplingParams 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . hf_transformers_utils  import  get_tokenizer 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        from  sglang . srt . server_args  import  ServerArgs ,  PortArgs 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_config  =  ModelConfig ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            self . model_cache_dir , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            model_override_args = " {} " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        server_args  =  ServerArgs ( model_path = self . model_cache_dir ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Initialize model runner 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_runner  =  ModelRunner ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            model_config = model_config , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            mem_fraction_static = 0.8 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            gpu_id = 0 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            tp_rank = 0 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            tp_size = 1 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            nccl_port = 12435 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            server_args = server_args , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( model_runner ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:27:51 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        with  torch . no_grad ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sglang_output  =  model_runner . model . visual ( inputs [ " pixel_values " ] . to ( self . device ) ,  grid_thw = inputs [ " image_grid_thw " ] . to ( self . device ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " SGLANG " ,  sglang_output ,  sglang_output . shape ) 
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:20:10 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2024-12-04 13:27:51 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # Convert to float32 for numerical stability if needed 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        hf  =  hf_output . float ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sg  =  sglang_output . float ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Basic shape and dtype comparison 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " \n === Basic Properties === " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Shapes match:  { hf . shape  ==  sg . shape } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " HF shape:  { hf . shape } , SGLang shape:  { sg . shape } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " HF dtype:  { hf . dtype } , SGLang dtype:  { sg . dtype } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Move tensors to CPU for numpy operations 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        hf_np  =  hf . cpu ( ) . numpy ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sg_np  =  sg . cpu ( ) . numpy ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Statistical metrics 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " \n === Statistical Metrics === " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Mean absolute difference:  { torch . mean ( torch . abs ( hf  -  sg ) ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Max absolute difference:  { torch . max ( torch . abs ( hf  -  sg ) ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Mean squared error:  { torch . mean ( ( hf  -  sg )  * *  2 ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Root mean squared error:  { torch . sqrt ( torch . mean ( ( hf  -  sg )  * *  2 ) ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Cosine similarity (across feature dimension) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        cos_sim  =  F . cosine_similarity ( hf ,  sg ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Mean cosine similarity:  { torch . mean ( cos_sim ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Min cosine similarity:  { torch . min ( cos_sim ) . item ( ) : .6f } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Find largest absolute differences 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " \n === Largest Absolute Differences === " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        diffs  =  torch . abs ( hf  -  sg ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        flat_diffs  =  diffs . flatten ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Get indices of top 10 differences 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        top_k  =  10 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        top_values ,  top_flat_indices  =  torch . topk ( flat_diffs ,  top_k ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # Convert flat indices to multidimensional indices 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        top_indices  =  np . unravel_index ( top_flat_indices . cpu ( ) . numpy ( ) ,  diffs . shape ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " \n Top  { top_k }  largest absolute differences: " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " Index " . ljust ( 30 )  +  " Difference " . ljust ( 15 )  +  " HF Value " . ljust ( 15 )  +  " SGLang Value " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( " - "  *  75 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  i  in  range ( top_k ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # Get the index tuple for this difference 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            idx  =  tuple ( dim [ i ]  for  dim  in  top_indices ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            diff_val  =  top_values [ i ] . item ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            hf_val  =  hf [ idx ] . item ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sg_val  =  sg [ idx ] . item ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # Format the index tuple and values 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            idx_str  =  str ( idx ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            print ( f " { idx_str : <30 } { diff_val : <15.6f } { hf_val : <15.6f } { sg_val : .6f } " )