mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-25 06:06:23 +00:00
Working on HF test for comparison
This commit is contained in:
parent
5e3080db28
commit
2e4f7d7827
@ -9,18 +9,36 @@ from unittest.mock import patch, AsyncMock
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2VLForConditionalGeneration
|
||||
from pathlib import Path
|
||||
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text
|
||||
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
|
||||
from pdelfin.prompts import PageResponse
|
||||
from httpx import AsyncClient
|
||||
|
||||
# S3 location of the fine-tuned Qwen2-VL checkpoint exercised by the tests in
# this file. Kept in one constant so every test case loads the same weights.
MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"

# Reference transcription that the tests compare a page response's
# natural_text against. The text intentionally stops mid-sentence
# ("... especially"); adjacent literals are concatenated, so the trailing
# space on each fragment is significant.
EDGAR_TEXT = (
    "Edgar, King of England\n\nEdgar (or Eadgar;[1] c. 944 – 8 July 975) was King of the English from 959 until his death in 975. "
    "He became king of all England on his brother's death. He was the younger son of King Edmund I and his first wife Ælfgifu. "
    "A detailed account of Edgar's reign is not possible, because only a few events were recorded by chroniclers and monastic writers "
    "were more interested in recording the activities of the leaders of the church.\n\nEdgar mainly followed the political policies of his predecessors, "
    "but there were major changes in the religious sphere. The English Benedictine Reform, which he strongly supported, became a dominant religious and social force.[2] "
    "It is seen by historians as a major achievement, and it was accompanied by a literary and artistic flowering, mainly associated with Æthelwold, Bishop of Winchester. "
    "Monasteries aggressively acquired estates from lay landowners with Edgar's assistance, leading to disorder when he died and former owners sought to recover their lost property, "
    "sometimes by force. Edgar's major administrative reform was the introduction of a standardised coinage in the early 970s to replace the previous decentralised system. "
    "He also issued legislative codes which mainly concentrated on improving procedures for enforcement of the law.\n\nEngland had suffered from Viking invasions for over a century "
    "when Edgar came to power, but there were none during his reign, which fell in a lull in attacks between the mid-950s and the early 980s.[3] After his death the throne was disputed "
    "between the supporters of his two surviving sons; the elder one, Edward the Martyr, was chosen with the support of Dunstan, the Archbishop of Canterbury. Three years later Edward was "
    "murdered and succeeded by his younger half-brother, Æthelred the Unready. Later chroniclers presented Edgar's reign as a golden age when England was free from external attacks and internal disorder, especially"
)
|
||||
|
||||
class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
# Mock arguments
|
||||
self.args = AsyncMock()
|
||||
self.args.workspace = "/tmp/test_workspace"
|
||||
self.args.model = ["s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"]
|
||||
self.args.model = [MODEL_FINETUNED_PATH]
|
||||
self.args.model_chat_template = "qwen2-vl"
|
||||
self.args.target_longest_image_dim = 1024
|
||||
self.args.target_anchor_text_len = 6000
|
||||
@ -39,8 +57,14 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
# Wait for the server to become ready
|
||||
await sglang_server_ready()
|
||||
|
||||
@patch("pdelfin.beakerpipeline.build_page_query", autospec=True)
|
||||
async def test_sglang_server_initialization_and_request(self, mock_build_page_query):
|
||||
# Mock the build_page_query function to set temperature to 0.0
|
||||
async def mocked_build_page_query(*args, **kwargs):
|
||||
query = await main.build_page_query(*args, **kwargs)
|
||||
query["temperature"] = 0.0 # Override temperature
|
||||
return query
|
||||
|
||||
async def test_sglang_server_initialization_and_request(self):
|
||||
# Mock data paths
|
||||
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
|
||||
|
||||
@ -66,6 +90,8 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
print(page_response)
|
||||
|
||||
self.assertEqual(page_response.natural_text, EDGAR_TEXT)
|
||||
|
||||
|
||||
async def asyncTearDown(self):
|
||||
# Shut down the server
|
||||
@ -79,3 +105,67 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
for file in files:
|
||||
os.unlink(os.path.join(root, file))
|
||||
os.rmdir(self.args.workspace)
|
||||
|
||||
|
||||
class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
    """Run the fine-tuned checkpoint directly through Hugging Face transformers.

    The generated text is compared against EDGAR_TEXT, the same reference the
    sglang server test in this file uses, so the two inference paths can be
    checked against each other.
    """

    async def asyncSetUp(self):
        """Download the checkpoint, normalise its rope config, and load the model."""
        # Set up the Hugging Face model and tokenizer
        model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
        download_directory([MODEL_FINETUNED_PATH], model_cache_dir)

        # Check the rope config and make sure it's got the proper key
        config_path = os.path.join(model_cache_dir, "config.json")
        with open(config_path, "r") as cfin:
            config_data = json.load(cfin)

        # Tolerate a config without a "rope_scaling" section instead of
        # raising KeyError; in that case there is nothing to rename.
        rope_scaling = config_data.get("rope_scaling") or {}
        if "rope_type" in rope_scaling:
            # NOTE(review): the original nesting was lost in the diff view; this
            # assumes the "type" assignment and the rewrite both belonged to
            # the `if` (i.e. "rope_type" is renamed to "type") -- confirm.
            del rope_scaling["rope_type"]
            rope_scaling["type"] = "mrope"

            # Only rewrite the cached config when it actually changed.
            with open(config_path, "w") as cfout:
                json.dump(config_data, cfout)

        self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
        ).eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Path to the test PDF
        self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))

    async def test_hugging_face_generation(self):
        """Greedy-decode a continuation and compare it with EDGAR_TEXT."""
        # Prepare the input text for the model (mock extracted text for page 1 of the PDF)
        # NOTE(review): this is a placeholder text-only prompt; it uses neither
        # self.test_pdf_path nor build_page_query, and max_new_tokens=100 is
        # likely too small to reproduce all of EDGAR_TEXT -- confirm intent.
        input_text = (
            "Extracted content of page 1 of edgar.pdf. "
            "Convert to natural text with proper formatting and summarization:"
        )

        # Tokenize the input
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)

        # Greedy decoding (do_sample=False) is the deterministic equivalent of
        # temperature 0. `temperature` and `max_length` are not passed: the
        # former is ignored without sampling and the latter is overridden by
        # `max_new_tokens`, so both only produced warnings.
        generation_output = self.model.generate(
            **inputs,
            max_new_tokens=100,
            num_return_sequences=1,
            do_sample=False,
        )

        # generate() returns prompt + continuation for decoder-only models;
        # drop the prompt tokens so only the model's answer is compared.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = generation_output[0][prompt_length:]

        # Decode the output
        decoded_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        print(decoded_output)

        # Convert the decoded output into the expected PageResponse structure
        generated_response = PageResponse(natural_text=decoded_output)

        # Assert the output matches the expected text
        self.assertEqual(generated_response.natural_text, EDGAR_TEXT)

    async def asyncTearDown(self):
        """Drop the model and tokenizer and release cached GPU memory."""
        # Clean up the model and tokenizer
        del self.model
        del self.tokenizer
        torch.cuda.empty_cache()
|
||||
Loading…
x
Reference in New Issue
Block a user