mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-02 20:09:08 +00:00
Sglang based unit test
This commit is contained in:
parent
60f24ad2d6
commit
5e3080db28
@ -494,7 +494,13 @@ async def sglang_server_task(args, semaphore):
|
||||
stderr_task = asyncio.create_task(read_stream(proc.stderr))
|
||||
timeout_task = asyncio.create_task(timeout_task())
|
||||
|
||||
await proc.wait()
|
||||
try:
|
||||
await proc.wait()
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Got cancellation for sglang_server_task, terminating server")
|
||||
proc.terminate()
|
||||
raise
|
||||
|
||||
timeout_task.cancel()
|
||||
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
|
||||
|
||||
|
@ -11,6 +11,7 @@ import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text
|
||||
from pdelfin.prompts import PageResponse
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@ -31,17 +32,18 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
# Set up a semaphore for server tasks
|
||||
self.semaphore = asyncio.Semaphore(1)
|
||||
|
||||
# Mock data paths
|
||||
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
|
||||
|
||||
|
||||
async def test_sglang_server_initialization_and_request(self):
|
||||
# Start the sglang server
|
||||
my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
|
||||
self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
|
||||
|
||||
# Wait for the server to become ready
|
||||
await sglang_server_ready()
|
||||
|
||||
|
||||
async def test_sglang_server_initialization_and_request(self):
|
||||
# Mock data paths
|
||||
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
|
||||
|
||||
# Send a single request to the sglang server for page 1
|
||||
async with AsyncClient(timeout=600) as session:
|
||||
query = await build_page_query(
|
||||
@ -64,12 +66,13 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
print(page_response)
|
||||
|
||||
# Shut down the server
|
||||
my_server_task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await my_server_task
|
||||
|
||||
async def asyncTearDown(self):
|
||||
# Shut down the server
|
||||
self.my_server_task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await self.my_server_task
|
||||
|
||||
# Cleanup temporary workspace
|
||||
if os.path.exists(self.args.workspace):
|
||||
for root, _, files in os.walk(self.args.workspace):
|
||||
|
Loading…
x
Reference in New Issue
Block a user