diff --git a/tests/test_sglang.py b/tests/test_sglang.py index c07e27a..0e94184 100644 --- a/tests/test_sglang.py +++ b/tests/test_sglang.py @@ -132,6 +132,9 @@ class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase): json.dump(config_data, cfout) self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True) + self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0] + + self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval() self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -149,13 +152,13 @@ class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase): target_anchor_text_len=6000, ) + messages = query["messages"] + # Apply chat template to get the text text = self.processor.apply_chat_template( query["messages"], tokenize=False, add_generation_prompt=True ) - print(text) - image_url = query["messages"][0]["content"][1]["image_url"]["url"] # Remove the "data:image/png;base64," prefix @@ -175,6 +178,14 @@ class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase): return_tensors="pt", ) + image_indices = [ + idx + for idx, token in enumerate(inputs["input_ids"][0]) + if token.item() == self.image_token_id + ] + + print("IMAGE INDICES", image_indices) + print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}") print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}") np.save('/root/pixel_values.npy', inputs['pixel_values'].detach().cpu().numpy()) @@ -224,6 +235,8 @@ class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase): [inputs['attention_mask'], torch.ones((1, 1), dtype=inputs['attention_mask'].dtype).to(self.device)], dim=-1 ) + print(self.tokenizer.decode(generated_tokens)) + # Now take all the input ids and run them through sglang as a comparison async with AsyncClient(timeout=600) as session: query["temperature"] = 0.0