Sglang based unit test

This commit is contained in:
Jake Poznanski 2024-11-25 09:48:05 -08:00
parent 60f24ad2d6
commit 5e3080db28
2 changed files with 19 additions and 10 deletions

View File

@ -494,7 +494,13 @@ async def sglang_server_task(args, semaphore):
stderr_task = asyncio.create_task(read_stream(proc.stderr))
timeout_task = asyncio.create_task(timeout_task())
await proc.wait()
try:
await proc.wait()
except asyncio.CancelledError:
logger.warning("Got cancellation for sglang_server_task, terminating server")
proc.terminate()
raise
timeout_task.cancel()
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)

View File

@ -11,6 +11,7 @@ import json
import tempfile
from pathlib import Path
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text
from pdelfin.prompts import PageResponse
from httpx import AsyncClient
@ -31,17 +32,18 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
# Set up a semaphore for server tasks
self.semaphore = asyncio.Semaphore(1)
# Mock data paths
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
async def test_sglang_server_initialization_and_request(self):
# Start the sglang server
my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
# Wait for the server to become ready
await sglang_server_ready()
async def test_sglang_server_initialization_and_request(self):
# Mock data paths
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
# Send a single request to the sglang server for page 1
async with AsyncClient(timeout=600) as session:
query = await build_page_query(
@ -64,12 +66,13 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
print(page_response)
# Shut down the server
my_server_task.cancel()
with self.assertRaises(asyncio.CancelledError):
await my_server_task
async def asyncTearDown(self):
# Shut down the server
self.my_server_task.cancel()
with self.assertRaises(asyncio.CancelledError):
await self.my_server_task
# Cleanup temporary workspace
if os.path.exists(self.args.workspace):
for root, _, files in os.walk(self.args.workspace):