mirror of
https://github.com/allenai/olmocr.git
synced 2026-01-07 12:51:39 +00:00
More log probs investigation
This commit is contained in:
parent
28d52602e9
commit
51614efc83
@ -55,20 +55,13 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
self.semaphore = asyncio.Semaphore(1)
|
||||
self.maxDiff = None
|
||||
|
||||
# Start the sglang server
|
||||
self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
|
||||
# # Start the sglang server
|
||||
# self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
|
||||
|
||||
# Wait for the server to become ready
|
||||
await sglang_server_ready()
|
||||
|
||||
@patch("pdelfin.beakerpipeline.build_page_query", autospec=True)
|
||||
async def test_sglang_server_initialization_and_request(self, mock_build_page_query):
|
||||
# Mock the build_page_query function to set temperature to 0.0
|
||||
async def mocked_build_page_query(*args, **kwargs):
|
||||
query = await main.build_page_query(*args, **kwargs)
|
||||
query["temperature"] = 0.0 # Override temperature
|
||||
return query
|
||||
# # Wait for the server to become ready
|
||||
# await sglang_server_ready()
|
||||
|
||||
async def test_sglang_server_initialization_and_request(self):
|
||||
# Mock data paths
|
||||
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
|
||||
|
||||
@ -80,15 +73,22 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
target_longest_image_dim=self.args.target_longest_image_dim,
|
||||
target_anchor_text_len=self.args.target_anchor_text_len,
|
||||
)
|
||||
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
|
||||
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
|
||||
|
||||
query["temperature"] = 0.0
|
||||
query["logprobs"] = True
|
||||
response = await session.post(COMPLETION_URL, json=query)
|
||||
|
||||
print(response.text)
|
||||
|
||||
# Check the server response
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response_data = response.json()
|
||||
self.assertIn("choices", response_data)
|
||||
self.assertGreater(len(response_data["choices"]), 0)
|
||||
|
||||
|
||||
|
||||
model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
|
||||
page_response = PageResponse(**model_response_json)
|
||||
|
||||
@ -98,17 +98,18 @@ class TestSglangServer(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
|
||||
async def asyncTearDown(self):
|
||||
# Shut down the server
|
||||
self.my_server_task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await self.my_server_task
|
||||
pass
|
||||
# # Shut down the server
|
||||
# self.my_server_task.cancel()
|
||||
# with self.assertRaises(asyncio.CancelledError):
|
||||
# await self.my_server_task
|
||||
|
||||
# Cleanup temporary workspace
|
||||
if os.path.exists(self.args.workspace):
|
||||
for root, _, files in os.walk(self.args.workspace):
|
||||
for file in files:
|
||||
os.unlink(os.path.join(root, file))
|
||||
os.rmdir(self.args.workspace)
|
||||
# # Cleanup temporary workspace
|
||||
# if os.path.exists(self.args.workspace):
|
||||
# for root, _, files in os.walk(self.args.workspace):
|
||||
# for file in files:
|
||||
# os.unlink(os.path.join(root, file))
|
||||
# os.rmdir(self.args.workspace)
|
||||
|
||||
|
||||
class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user