More log probs investigation

This commit is contained in:
Jake Poznanski 2024-11-25 11:24:21 -08:00
parent 28d52602e9
commit 51614efc83

View File

@ -55,20 +55,13 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
self.semaphore = asyncio.Semaphore(1)
self.maxDiff = None
# Start the sglang server
self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
# # Start the sglang server
# self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
# Wait for the server to become ready
await sglang_server_ready()
@patch("pdelfin.beakerpipeline.build_page_query", autospec=True)
async def test_sglang_server_initialization_and_request(self, mock_build_page_query):
# Mock the build_page_query function to set temperature to 0.0
async def mocked_build_page_query(*args, **kwargs):
query = await main.build_page_query(*args, **kwargs)
query["temperature"] = 0.0 # Override temperature
return query
# # Wait for the server to become ready
# await sglang_server_ready()
async def test_sglang_server_initialization_and_request(self):
# Mock data paths
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
@ -80,15 +73,22 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
target_longest_image_dim=self.args.target_longest_image_dim,
target_anchor_text_len=self.args.target_anchor_text_len,
)
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
query["temperature"] = 0.0
query["logprobs"] = True
response = await session.post(COMPLETION_URL, json=query)
print(response.text)
# Check the server response
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertIn("choices", response_data)
self.assertGreater(len(response_data["choices"]), 0)
model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
@ -98,17 +98,18 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
async def asyncTearDown(self):
# Shut down the server
self.my_server_task.cancel()
with self.assertRaises(asyncio.CancelledError):
await self.my_server_task
pass
# # Shut down the server
# self.my_server_task.cancel()
# with self.assertRaises(asyncio.CancelledError):
# await self.my_server_task
# Cleanup temporary workspace
if os.path.exists(self.args.workspace):
for root, _, files in os.walk(self.args.workspace):
for file in files:
os.unlink(os.path.join(root, file))
os.rmdir(self.args.workspace)
# # Cleanup temporary workspace
# if os.path.exists(self.args.workspace):
# for root, _, files in os.walk(self.args.workspace):
# for file in files:
# os.unlink(os.path.join(root, file))
# os.rmdir(self.args.workspace)
class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):