diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py index 4fa49d5..92c7592 100644 --- a/pdelfin/beakerpipeline.py +++ b/pdelfin/beakerpipeline.py @@ -494,7 +494,13 @@ async def sglang_server_task(args, semaphore): stderr_task = asyncio.create_task(read_stream(proc.stderr)) timeout_task = asyncio.create_task(timeout_task()) - await proc.wait() + try: + await proc.wait() + except asyncio.CancelledError: + logger.warning("Got cancellation for sglang_server_task, terminating server") + proc.terminate() + raise + timeout_task.cancel() await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True) diff --git a/tests/test_sglang.py b/tests/test_sglang.py index d9d3aa7..3382f2b 100644 --- a/tests/test_sglang.py +++ b/tests/test_sglang.py @@ -11,6 +11,7 @@ import json import tempfile from pathlib import Path from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text +from pdelfin.prompts import PageResponse from httpx import AsyncClient @@ -31,17 +32,18 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase): # Set up a semaphore for server tasks self.semaphore = asyncio.Semaphore(1) - # Mock data paths - self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")) - - async def test_sglang_server_initialization_and_request(self): # Start the sglang server - my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore)) + self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore)) # Wait for the server to become ready await sglang_server_ready() + + async def test_sglang_server_initialization_and_request(self): + # Mock data paths + self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")) + # Send a single request to the sglang server for page 1 async with AsyncClient(timeout=600) as session: query = await build_page_query( @@ -64,12 +66,13 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase): print(page_response) - # Shut down the server - my_server_task.cancel() - with self.assertRaises(asyncio.CancelledError): - await my_server_task async def asyncTearDown(self): + # Shut down the server + self.my_server_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await self.my_server_task + # Cleanup temporary workspace if os.path.exists(self.args.workspace): for root, _, files in os.walk(self.args.workspace):