# olmocr/tests/test_sglang.py

# The idea is that you have a Qwen2-VL-7B model located here: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/
# You need to load it in both Hugging Face transformers and the sglang server, and send page 1 of edgar.pdf (from tests/gnarly_pdfs) to each
# Compare that the temperature-0 sampled results are the same
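# (To run, assuming the usual layout of this repo: python -m unittest tests.test_sglang)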
import asyncio
import base64
import json
import math
import os
import tempfile
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import AsyncMock, patch

import numpy as np
import torch
import torch.nn.functional as F
from httpx import AsyncClient
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

from pdelfin.beakerpipeline import (
    SGLANG_SERVER_PORT,
    build_page_query,
    download_directory,
    get_anchor_text,
    render_pdf_to_base64png,
    sglang_server_ready,
    sglang_server_task,
)
from pdelfin.prompts import PageResponse

MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"

EDGAR_TEXT = (
    "Edgar, King of England\n\nEdgar (or Eadgar;[1] c. 944 – 8 July 975) was King of the English from 959 until his death in 975. "
"He became king of all England on his brother's death. He was the younger son of King Edmund I and his first wife Ælfgifu. "
"A detailed account of Edgar's reign is not possible, because only a few events were recorded by chroniclers and monastic writers "
"were more interested in recording the activities of the leaders of the church.\n\nEdgar mainly followed the political policies of his predecessors, "
"but there were major changes in the religious sphere. The English Benedictine Reform, which he strongly supported, became a dominant religious and social force.[2] "
"It is seen by historians as a major achievement, and it was accompanied by a literary and artistic flowering, mainly associated with Æthelwold, Bishop of Winchester. "
"Monasteries aggressively acquired estates from lay landowners with Edgar's assistance, leading to disorder when he died and former owners sought to recover their lost property, "
"sometimes by force. Edgar's major administrative reform was the introduction of a standardised coinage in the early 970s to replace the previous decentralised system. "
"He also issued legislative codes which mainly concentrated on improving procedures for enforcement of the law.\n\nEngland had suffered from Viking invasions for over a century "
"when Edgar came to power, but there were none during his reign, which fell in a lull in attacks between the mid-950s and the early 980s.[3] After his death the throne was disputed "
"between the supporters of his two surviving sons; the elder one, Edward the Martyr, was chosen with the support of Dunstan, the Archbishop of Canterbury. Three years later Edward was "
"murdered and succeeded by his younger half-brother, Æthelred the Unready. Later chroniclers presented Edgar's reign as a golden age when England was free from external attacks and internal disorder, especially"
)
class TestSglangServer(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
# Mock arguments
self.args = AsyncMock()
self.args.workspace = "/tmp/test_workspace"
self.args.model = [MODEL_FINETUNED_PATH]
self.args.model_chat_template = "qwen2-vl"
self.args.target_longest_image_dim = 1024
self.args.target_anchor_text_len = 6000
self.args.model_max_context = 8192
# Create a temporary workspace directory
os.makedirs(self.args.workspace, exist_ok=True)
# Set up a semaphore for server tasks
self.semaphore = asyncio.Semaphore(1)
self.maxDiff = None
# # Start the sglang server
# self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
# # Wait for the server to become ready
# await sglang_server_ready()
async def test_sglang_server_initialization_and_request(self):
# Mock data paths
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
# Send a single request to the sglang server for page 1
async with AsyncClient(timeout=600) as session:
query = await build_page_query(
str(self.test_pdf_path),
page=1,
target_longest_image_dim=self.args.target_longest_image_dim,
target_anchor_text_len=self.args.target_anchor_text_len,
)
            COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
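            # Greedy (temperature 0) decoding; per-token logprobs are requested so
            # the server's token distribution can be inspected alongside the text.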
query["temperature"] = 0.0
query["logprobs"] = True
query["top_logprobs"] = 5
response = await session.post(COMPLETION_URL, json=query)
print(response.text)
# Check the server response
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertIn("choices", response_data)
self.assertGreater(len(response_data["choices"]), 0)
model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
print(page_response)
self.assertEqual(page_response.natural_text, EDGAR_TEXT)
async def asyncTearDown(self):
pass
# # Shut down the server
# self.my_server_task.cancel()
# with self.assertRaises(asyncio.CancelledError):
# await self.my_server_task
# # Cleanup temporary workspace
# if os.path.exists(self.args.workspace):
# for root, _, files in os.walk(self.args.workspace):
# for file in files:
# os.unlink(os.path.join(root, file))
# os.rmdir(self.args.workspace)
class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
# Set up the Hugging Face model and tokenizer
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
download_directory([MODEL_FINETUNED_PATH], model_cache_dir)
# Check the rope config and make sure it's got the proper key
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
config_data = json.load(cfin)
if "rope_type" in config_data["rope_scaling"]:
del config_data["rope_scaling"]["rope_type"]
config_data["rope_scaling"]["type"] = "mrope"
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
json.dump(config_data, cfout)
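        # NOTE (assumption): some transformers versions read the mrope settings from
        # rope_scaling["type"] rather than rope_scaling["rope_type"], so the rewrite
        # above keeps the checkpoint's config loadable across versions.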
self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
# Path to the test PDF
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
self.maxDiff = None
async def test_hugging_face_generation(self):
query = await build_page_query(
str(self.test_pdf_path),
page=1,
target_longest_image_dim=1024,
target_anchor_text_len=6000,
)
messages = query["messages"]
# Apply chat template to get the text
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_url = messages[0]["content"][1]["image_url"]["url"]
# Remove the "data:image/png;base64," prefix
base64_image = image_url.split(",")[1]
# Decode the base64 string into bytes
image_data = base64.b64decode(base64_image)
# Create a BytesIO object and load it into a PIL image
main_image = Image.open(BytesIO(image_data))
# Process inputs using processor
inputs = self.processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
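        # Locate the <|image_pad|> placeholder positions in the prompt; handy when
        # debugging preprocessing differences between HF and the sglang server.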
image_indices = [
idx
for idx, token in enumerate(inputs["input_ids"][0])
if token.item() == self.image_token_id
]
print("IMAGE INDICES", image_indices)
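        # Debug output: dump the image grid and raw pixel values so they can be
        # compared offline against what the sglang server computes.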
print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}")
print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}")
np.save('/root/pixel_values.npy', inputs['pixel_values'].detach().cpu().numpy())
inputs = {key: value.to(self.device) for (key, value) in inputs.items()}
generated_tokens = []
max_steps = 100
top_logprobs_hf = []
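        # Decode one token per generate() call so the top-k log probabilities can be
        # captured at every step of greedy decoding.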
for step in range(max_steps):
# Generate the output with temperature=0
generation_output = self.model.generate(
**inputs,
temperature=0.0,
max_new_tokens=1,
#max_length=8192,
num_return_sequences=1,
do_sample=False,
output_scores=True,
return_dict_in_generate=True,
)
# Extract the generated token's log probabilities
scores = generation_output.scores # Tuple of length 1
logits = scores[0] # Tensor of shape (batch_size, vocab_size)
log_probs = F.log_softmax(logits, dim=-1) # Apply log softmax to get log probabilities
# Get top 5 tokens and their log probabilities
topk_log_probs, topk_indices = torch.topk(log_probs[0], k=5)
topk_tokens = self.tokenizer.convert_ids_to_tokens(topk_indices.tolist())
top_logprobs_hf.append((topk_tokens, topk_log_probs.tolist()))
# Pick the top token
next_token_id = topk_indices[0].unsqueeze(0).unsqueeze(0) # Shape: (1, 1)
next_token_str = self.tokenizer.convert_ids_to_tokens([next_token_id.item()])[0]
generated_tokens.append(next_token_id.item())
# Append the next token to input_ids and update attention_mask
inputs['input_ids'] = torch.cat([inputs['input_ids'], next_token_id], dim=-1)
inputs['attention_mask'] = torch.cat(
[inputs['attention_mask'], torch.ones((1, 1), dtype=inputs['attention_mask'].dtype).to(self.device)], dim=-1
)
print(self.tokenizer.decode(generated_tokens))
        # Now send the same page query to the sglang server and compare its top logprobs
async with AsyncClient(timeout=600) as session:
query["temperature"] = 0.0
query["max_tokens"] = max_steps
query["logprobs"] = True
query["top_logprobs"] = 5
            COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
response = await session.post(COMPLETION_URL, json=query)
response_data = response.json()
for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
print("\nTop 5 tokens and their log probabilities:")
(topk_tokens, topk_log_probs) = top_logprobs_hf[step]
for token, log_prob, lptokcur in zip(topk_tokens, topk_log_probs, lptok["top_logprobs"]):
                print(
                    f"HF Token: {token} Log Prob: {log_prob:.2f} Prob {math.exp(log_prob)*100:.2f}% "
                    f"SGLANG Token {lptokcur['token']} Logprob {lptokcur['logprob']:.2f} Prob {math.exp(lptokcur['logprob'])*100:.2f}%"
                )
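
        # A stricter check is sketched below but left commented out (an assumption,
        # not part of the original test): the greedy token choices should agree and
        # the top-1 logprobs should be close, though bf16 kernels in HF and sglang
        # will not match bit-for-bit. HF tokens come back in byte-level BPE form,
        # hence the "Ġ" -> space replacement.
        # for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
        #     hf_tokens, hf_logprobs = top_logprobs_hf[step]
        #     self.assertEqual(hf_tokens[0].replace("Ġ", " "), lptok["token"])
        #     self.assertAlmostEqual(hf_logprobs[0], lptok["logprob"], delta=0.2)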
async def asyncTearDown(self):
# Clean up the model and tokenizer
del self.model
del self.tokenizer
torch.cuda.empty_cache()