import asyncio
import datetime
import hashlib
import unittest
from typing import Dict, List
from unittest.mock import Mock, call, patch

from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import S3WorkQueue, WorkItem


class TestS3WorkQueue(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
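        # Real botocore clients expose error classes on client.exceptions;
        # mirroring that on the Mock lets the code under test catch
        # s3_client.exceptions.ClientError just as it would against a live client.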
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

    def tearDown(self):
        """Clean up after each test method."""
        pass

    def test_compute_workgroup_hash(self):
        """Test hash computation is deterministic and correct"""
        paths = [
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file1.pdf",
        ]

        # Hash should be the same regardless of order
        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)
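        # Presumably the digest is a 40-character SHA-1 hex string (the index
        # format check in test_populate_queue_new_items relies on that); this
        # test only pins down order-independence.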

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")

        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
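
    # Note: plain unittest.TestCase never invokes asyncSetUp/asyncTearDown
    # (that protocol belongs to unittest.IsolatedAsyncioTestCase), so the two
    # methods below are effectively inert; each @async_test method builds its
    # own event loop instead.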
    def asyncSetUp(self):
        """Set up async test fixtures"""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def asyncTearDown(self):
        """Clean up async test fixtures"""
        self.loop.close()

    def async_test(f):
        """Decorator for async test methods"""

        def wrapper(*args, **kwargs):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.close()

        return wrapper
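
    # On Python 3.8+, unittest.IsolatedAsyncioTestCase would manage the event
    # loop itself; this decorator is a lighter-weight stand-in that keeps the
    # suite on plain TestCase.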

    @async_test
    async def test_populate_queue_new_items(self):
        """Test populating queue with new items"""
        # Mock empty existing index
        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)

                # Verify upload was called with correct data
                self.assertEqual(mock_upload.call_count, 1)
                _, _, lines = mock_upload.call_args[0]
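                # call_args[0] is the positional-args tuple; the signature is
                # presumably upload_zstd_csv(s3_client, path, lines), so the
                # third positional argument holds the index lines.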

                # Should create 2 work groups (2 files + 1 file)
                self.assertEqual(len(lines), 2)

                # Verify format of uploaded lines
                for line in lines:
                    parts = line.split(",")
                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length

    @async_test
    async def test_populate_queue_existing_items(self):
        """Test populating queue with mix of new and existing items"""
        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)

                # Verify upload called with both existing and new items
                _, _, lines = mock_upload.call_args[0]
                self.assertEqual(len(lines), 2)
                self.assertIn(existing_line, lines)

    @async_test
    async def test_initialize_queue(self):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"

        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
                await self.work_queue.initialize_queue()

                # Queue should be empty since all work is completed
                self.assertTrue(self.work_queue._queue.empty())

    @async_test
    async def test_is_completed(self):
        """Test completed work check"""
        work_hash = "testhash123"

        # Test completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))

        # Test incomplete work
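        # (assigning an exception instance to side_effect makes the mock raise
        # it on every call, per unittest.mock semantics)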
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        self.assertFalse(await self.work_queue.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
        """Test getting work items"""
        # Setup test data
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Test getting available work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)

        # Verify lock file was created
        self.s3_client.put_object.assert_called_once()
        bucket, key = self.s3_client.put_object.call_args[1]["Bucket"], self.s3_client.put_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_get_work_completed(self):
        """Test getting work that's already completed"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip completed work

    @async_test
    async def test_get_work_locked(self):
        """Test getting work that's locked by another worker"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate active lock
        recent_time = datetime.datetime.now(datetime.timezone.utc)
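        # A list side_effect is consumed one element per call: the first
        # head_object call (completed-output check) raises 404, the second
        # (lock check) returns a fresh LastModified timestamp.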
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": recent_time},  # Active lock
        ]

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip locked work

    @async_test
    async def test_get_work_stale_lock(self):
        """Test getting work with a stale lock"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate stale lock
        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
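        # One hour is assumed to be comfortably past the queue's lock-freshness
        # window, so this lock should be treated as abandoned and taken over.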
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": stale_time},  # Stale lock
        ]

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)  # Should take work with stale lock

    @async_test
    async def test_mark_done(self):
        """Test marking work as done"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        await self.work_queue.mark_done(work_item)

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
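
    # The index is a plain comma-separated file, yet S3 keys may themselves
    # contain commas; the round-trip below checks that the reader recovers the
    # original paths intact, however the escaping is actually done.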
    @async_test
    async def test_paths_with_commas(self):
        """Test handling of paths that contain commas"""
        # Create paths with commas in them
        paths_with_commas = [
            "s3://test-bucket/data/file1,with,commas.pdf",
            "s3://test-bucket/data/file2,comma.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

        # Mock empty existing index for initial population
        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                # Populate the queue with these paths
                await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)

                # Capture what would be written to the index
                _, _, lines = mock_upload.call_args[0]

        # Now simulate reading back these lines (which have commas in the paths)
        with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
                # Initialize a fresh queue from these lines
                await self.work_queue.initialize_queue()

        # Mock ClientError for head_object (file doesn't exist)
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")

        # Get a work item
        work_item = await self.work_queue.get_work()

        # Now verify we get a work item
        self.assertIsNotNone(work_item, "Should get a work item")

        # Verify the work item has the correct number of paths
        self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")

        # Check that all original paths with commas are preserved
        for path in paths_with_commas:
            self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")

    def test_queue_size(self):
        """Test queue size property"""
        self.assertEqual(self.work_queue.size, 0)

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"])))
        self.assertEqual(self.work_queue.size, 1)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"])))
        self.assertEqual(self.work_queue.size, 2)

        self.loop.close()
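

# A rough end-to-end sketch of the API these tests exercise (assumes a real
# boto3 client and an s3:// workspace; not part of the test suite):
#
#   queue = S3WorkQueue(boto3.client("s3"), "s3://bucket/workspace")
#   await queue.populate_queue(pdf_paths, items_per_group=50)
#   await queue.initialize_queue()
#   while (item := await queue.get_work()) is not None:
#       ...process item.work_paths...
#       await queue.mark_done(item)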

if __name__ == "__main__":
    unittest.main()