Basic work queue from claude

Jake Poznanski 2024-11-18 10:07:03 -08:00
parent 995b1d15fc
commit 04429b2862
2 changed files with 395 additions and 15 deletions

pdelfin/s3_queue.py

@@ -3,6 +3,7 @@ import random
 import logging
 import hashlib
 import tempfile
+import datetime
 from typing import Optional, Tuple, List, Dict, Set
 from dataclasses import dataclass
 import asyncio
@@ -58,26 +59,86 @@ class S3WorkQueue:
         self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
         self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")

         self._queue = asyncio.Queue()

     @staticmethod
     def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
         """
-        Compute a deterministic hash for a group of PDFs.
+        Compute a deterministic hash for a group of paths.

         Args:
-            pdfs: List of PDF S3 paths
+            s3_work_paths: List of S3 paths

         Returns:
-            SHA1 hash of the sorted PDF paths
+            SHA1 hash of the sorted paths
         """
         sha1 = hashlib.sha1()
-        for pdf in sorted(s3_work_paths):
-            sha1.update(pdf.encode('utf-8'))
+        for path in sorted(s3_work_paths):
+            sha1.update(path.encode('utf-8'))
         return sha1.hexdigest()

-    async def populate_queue(self, s3_work_paths: str, items_per_group: int) -> None:
-        pass
+    async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
+        """
+        Add new items to the work queue.
+
+        Args:
+            s3_work_paths: Each individual s3 path that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
+
+        # Load existing work groups
+        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+        existing_path_set = {path for paths in existing_groups.values() for path in paths}
+
+        # Find new paths to process
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            await asyncio.to_thread(
+                upload_zstd_csv,
+                self.s3_client,
+                self._index_path,
+                combined_lines
+            )

     async def initialize_queue(self) -> None:
         """
@@ -116,7 +177,7 @@ class S3WorkQueue:
         # Find remaining work and shuffle
         remaining_work_hashes = set(work_queue) - done_work_hashes
         remaining_items = [
-            WorkItem(hash_=hash_, pdfs=work_queue[hash_])
+            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
             for hash_ in remaining_work_hashes
         ]
         random.shuffle(remaining_items)
@@ -127,7 +188,7 @@ class S3WorkQueue:
             await self._queue.put(item)

         logger.info(f"Initialized queue with {self._queue.qsize()} work items")

     async def is_completed(self, work_hash: str) -> bool:
         """
         Check if a work item has been completed.
@@ -138,7 +199,7 @@ class S3WorkQueue:
         Returns:
             True if the work is completed, False otherwise
         """
-        output_s3_path = "TODO"
+        output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
         bucket, key = parse_s3_path(output_s3_path)
         try:
@@ -151,12 +212,89 @@ class S3WorkQueue:
         except self.s3_client.exceptions.ClientError:
             return False

-    async def get_work(self) -> Optional[WorkItem]:
-        pass
-
-    def mark_done(self, work_item: WorkItem) -> None:
-        """Mark the most recently gotten work item as complete"""
-        pass
+    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
+        """
+        Get the next available work item that isn't completed or locked.
+
+        Args:
+            worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
+
+        Returns:
+            WorkItem if work is available, None if queue is empty
+        """
+        while True:
+            try:
+                work_item = self._queue.get_nowait()
+            except asyncio.QueueEmpty:
+                return None
+
+            # Check if work is already completed
+            if await self.is_completed(work_item.hash):
+                logger.debug(f"Work item {work_item.hash} already completed, skipping")
+                self._queue.task_done()
+                continue
+
+            # Check for worker lock
+            lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+            bucket, key = parse_s3_path(lock_path)
+
+            try:
+                response = await asyncio.to_thread(
+                    self.s3_client.head_object,
+                    Bucket=bucket,
+                    Key=key
+                )
+
+                # Check if lock is stale
+                last_modified = response['LastModified']
+                if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
+                    # Lock is stale, we can take this work
+                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
+                else:
+                    # Lock is active, skip this work
+                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
+                    self._queue.task_done()
+                    continue
+            except self.s3_client.exceptions.ClientError:
+                # No lock exists, we can take this work
+                pass
+
+            # Create our lock file
+            try:
+                await asyncio.to_thread(
+                    self.s3_client.put_object,
+                    Bucket=bucket,
+                    Key=key,
+                    Body=b''
+                )
+            except Exception as e:
+                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
+                self._queue.task_done()
+                continue
+
+            return work_item
+
+    async def mark_done(self, work_item: WorkItem) -> None:
+        """
+        Mark a work item as done by removing its lock file.
+
+        Args:
+            work_item: The WorkItem to mark as done
+        """
+        lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+        bucket, key = parse_s3_path(lock_path)
+
+        try:
+            await asyncio.to_thread(
+                self.s3_client.delete_object,
+                Bucket=bucket,
+                Key=key
+            )
+        except Exception as e:
+            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
+
+        self._queue.task_done()

     @property
     def size(self) -> int:
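
For orientation, here is a minimal sketch of how a worker loop might drive this queue end to end. The boto3 client setup, the workspace path, and process_work_item are hypothetical placeholders; only the S3WorkQueue calls (populate_queue, initialize_queue, get_work, mark_done) come from the diff above.

import asyncio
import boto3

from pdelfin.s3_queue import S3WorkQueue, WorkItem


async def process_work_item(item: WorkItem) -> None:
    # Hypothetical placeholder for real per-group processing of item.s3_work_paths.
    print(f"Processing group {item.hash} with {len(item.s3_work_paths)} paths")


async def run_worker(s3_work_paths: list[str]) -> None:
    # Assumed setup: a plain boto3 S3 client and an example workspace path.
    s3 = boto3.client("s3")
    queue = S3WorkQueue(s3, "s3://my-bucket/workspace")

    # Register the paths, then load the index and drop already-completed groups.
    await queue.populate_queue(s3_work_paths, items_per_group=100)
    await queue.initialize_queue()

    # Pull work until the queue is drained; each item is locked while we hold it.
    while (item := await queue.get_work()) is not None:
        await process_work_item(item)
        await queue.mark_done(item)


if __name__ == "__main__":
    asyncio.run(run_worker(["s3://my-bucket/data/file1.pdf"]))

Because the lock files live under worker_locks/ and mark_done deletes them, several workers can run this loop against the same workspace without being handed the same group twice, and locks older than worker_lock_timeout_secs are reclaimed as stale.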

tests/test_s3_work_queue.py (new file, 242 lines)

@@ -0,0 +1,242 @@
import unittest
import asyncio
import datetime
from unittest.mock import Mock, patch, call
from botocore.exceptions import ClientError
import hashlib
from typing import List, Dict

# Import the classes we're testing
from pdelfin.s3_queue import S3WorkQueue, WorkItem


class TestS3WorkQueue(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")

        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

    def tearDown(self):
        """Clean up after each test method."""
        pass
    def test_compute_workgroup_hash(self):
        """Test hash computation is deterministic and correct"""
        paths = [
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file1.pdf",
        ]

        # Hash should be the same regardless of order
        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

        # Verify hash is actually SHA1
        sha1 = hashlib.sha1()
        for path in sorted(paths):
            sha1.update(path.encode('utf-8'))
        self.assertEqual(hash1, sha1.hexdigest())

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")

        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
        self.assertIsInstance(queue._queue, asyncio.Queue)

    def asyncSetUp(self):
        """Set up async test fixtures"""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def asyncTearDown(self):
        """Clean up async test fixtures"""
        self.loop.close()

    def async_test(f):
        """Decorator for async test methods"""
        def wrapper(*args, **kwargs):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.close()
        return wrapper
    @async_test
    async def test_populate_queue_new_items(self):
        """Test populating queue with new items"""
        # Mock empty existing index
        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[]):
            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)

                # Verify upload was called with correct data
                self.assertEqual(mock_upload.call_count, 1)
                _, _, lines = mock_upload.call_args[0]

                # Should create 2 work groups (2 files + 1 file)
                self.assertEqual(len(lines), 2)

                # Verify format of uploaded lines
                for line in lines:
                    parts = line.split(',')
                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length

    @async_test
    async def test_populate_queue_existing_items(self):
        """Test populating queue with mix of new and existing items"""
        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"

        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[existing_line]):
            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)

                # Verify upload called with both existing and new items
                _, _, lines = mock_upload.call_args[0]
                self.assertEqual(len(lines), 2)
                self.assertIn(existing_line, lines)

    @async_test
    async def test_initialize_queue(self):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]

        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[work_line]):
            with patch('pdelfin.s3_queue.expand_s3_glob', return_value=completed_items):
                await self.work_queue.initialize_queue()

                # Queue should be empty since all work is completed
                self.assertTrue(self.work_queue._queue.empty())
    @async_test
    async def test_is_completed(self):
        """Test completed work check"""
        work_hash = "testhash123"

        # Test completed work
        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))

        # Test incomplete work
        self.s3_client.head_object.side_effect = ClientError(
            {'Error': {'Code': '404', 'Message': 'Not Found'}},
            'HeadObject'
        )
        self.assertFalse(await self.work_queue.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
        """Test getting work items"""
        # Setup test data
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Test getting available work
        self.s3_client.head_object.side_effect = ClientError(
            {'Error': {'Code': '404', 'Message': 'Not Found'}},
            'HeadObject'
        )

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)

        # Verify lock file was created
        self.s3_client.put_object.assert_called_once()
        bucket, key = self.s3_client.put_object.call_args[1]['Bucket'], self.s3_client.put_object.call_args[1]['Key']
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_get_work_completed(self):
        """Test getting work that's already completed"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate completed work
        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip completed work

    @async_test
    async def test_get_work_locked(self):
        """Test getting work that's locked by another worker"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate active lock
        recent_time = datetime.datetime.now(datetime.timezone.utc)
        self.s3_client.head_object.side_effect = [
            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
            {'LastModified': recent_time}  # Active lock
        ]

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip locked work

    @async_test
    async def test_get_work_stale_lock(self):
        """Test getting work with a stale lock"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate stale lock
        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
        self.s3_client.head_object.side_effect = [
            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
            {'LastModified': stale_time}  # Stale lock
        ]

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)  # Should take work with stale lock
    @async_test
    async def test_mark_done(self):
        """Test marking work as done"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        await self.work_queue.mark_done(work_item)

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        bucket, key = self.s3_client.delete_object.call_args[1]['Bucket'], self.s3_client.delete_object.call_args[1]['Key']
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    def test_queue_size(self):
        """Test queue size property"""
        self.assertEqual(self.work_queue.size, 0)

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"])))
        self.assertEqual(self.work_queue.size, 1)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"])))
        self.assertEqual(self.work_queue.size, 2)

        self.loop.close()


if __name__ == '__main__':
    unittest.main()
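
The tests depend only on unittest and unittest.mock, so, assuming they are run from the repository root with the pdelfin package importable, an invocation along these lines should work:

python -m unittest tests.test_s3_work_queue -v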