Basic work queue from claude

Jake Poznanski 2024-11-18 10:07:03 -08:00
parent 995b1d15fc
commit 04429b2862
2 changed files with 395 additions and 15 deletions
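
Note: the diff below constructs WorkItem(hash=..., s3_work_paths=...), but the dataclass definition itself sits outside the hunks shown. A minimal sketch of what it presumably looks like, inferred from the call sites (the field types and the frozen= choice are assumptions, not part of this commit):

    # Hypothetical sketch only: the real WorkItem lives elsewhere in
    # pdelfin/s3_queue.py and is not shown in this commit's hunks.
    from dataclasses import dataclass
    from typing import List

    @dataclass(frozen=True)
    class WorkItem:
        hash: str                 # SHA1 over the group's sorted paths
        s3_work_paths: List[str]  # S3 paths bundled into one unit of work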


@@ -3,6 +3,7 @@ import random
 import logging
 import hashlib
 import tempfile
+import datetime
 from typing import Optional, Tuple, List, Dict, Set
 from dataclasses import dataclass
 import asyncio
@@ -58,26 +59,86 @@ class S3WorkQueue:
         self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
         self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
+        self._queue = asyncio.Queue()
 
     @staticmethod
     def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
         """
-        Compute a deterministic hash for a group of PDFs.
+        Compute a deterministic hash for a group of paths.
 
         Args:
-            pdfs: List of PDF S3 paths
+            s3_work_paths: List of S3 paths
 
         Returns:
-            SHA1 hash of the sorted PDF paths
+            SHA1 hash of the sorted paths
         """
         sha1 = hashlib.sha1()
-        for pdf in sorted(s3_work_paths):
-            sha1.update(pdf.encode('utf-8'))
+        for path in sorted(s3_work_paths):
+            sha1.update(path.encode('utf-8'))
         return sha1.hexdigest()
 
-    async def populate_queue(self, s3_work_paths: str, items_per_group: int) -> None:
-        pass
+    async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
+        """
+        Add new items to the work queue.
+
+        Args:
+            s3_work_paths: Each individual s3 path that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
+
+        # Load existing work groups
+        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
+
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+        existing_path_set = {path for paths in existing_groups.values() for path in paths}
+
+        # Find new paths to process
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            await asyncio.to_thread(
+                upload_zstd_csv,
+                self.s3_client,
+                self._index_path,
+                combined_lines
+            )
 
     async def initialize_queue(self) -> None:
         """
@@ -116,7 +177,7 @@ class S3WorkQueue:
         # Find remaining work and shuffle
         remaining_work_hashes = set(work_queue) - done_work_hashes
         remaining_items = [
-            WorkItem(hash_=hash_, pdfs=work_queue[hash_])
+            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
             for hash_ in remaining_work_hashes
         ]
         random.shuffle(remaining_items)
@@ -138,7 +199,7 @@ class S3WorkQueue:
         Returns:
             True if the work is completed, False otherwise
         """
-        output_s3_path = "TODO"
+        output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
         bucket, key = parse_s3_path(output_s3_path)
 
         try:
@@ -151,12 +212,89 @@ class S3WorkQueue:
         except self.s3_client.exceptions.ClientError:
             return False
 
-    async def get_work(self) -> Optional[WorkItem]:
-        pass
-
-    def mark_done(self, work_item: WorkItem) -> None:
-        """Mark the most recently gotten work item as complete"""
-        pass
+    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
+        """
+        Get the next available work item that isn't completed or locked.
+
+        Args:
+            worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
+
+        Returns:
+            WorkItem if work is available, None if queue is empty
+        """
+        while True:
+            try:
+                work_item = self._queue.get_nowait()
+            except asyncio.QueueEmpty:
+                return None
+
+            # Check if work is already completed
+            if await self.is_completed(work_item.hash):
+                logger.debug(f"Work item {work_item.hash} already completed, skipping")
+                self._queue.task_done()
+                continue
+
+            # Check for worker lock
+            lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+            bucket, key = parse_s3_path(lock_path)
+
+            try:
+                response = await asyncio.to_thread(
+                    self.s3_client.head_object,
+                    Bucket=bucket,
+                    Key=key
+                )
+
+                # Check if lock is stale
+                last_modified = response['LastModified']
+                if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
+                    # Lock is stale, we can take this work
+                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
+                else:
+                    # Lock is active, skip this work
+                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
+                    self._queue.task_done()
+                    continue
+            except self.s3_client.exceptions.ClientError:
+                # No lock exists, we can take this work
+                pass
+
+            # Create our lock file
+            try:
+                await asyncio.to_thread(
+                    self.s3_client.put_object,
+                    Bucket=bucket,
+                    Key=key,
+                    Body=b''
+                )
+            except Exception as e:
+                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
+                self._queue.task_done()
+                continue
+
+            return work_item
+
+    async def mark_done(self, work_item: WorkItem) -> None:
+        """
+        Mark a work item as done by removing its lock file.
+
+        Args:
+            work_item: The WorkItem to mark as done
+        """
+        lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+        bucket, key = parse_s3_path(lock_path)
+
+        try:
+            await asyncio.to_thread(
+                self.s3_client.delete_object,
+                Bucket=bucket,
+                Key=key
+            )
+        except Exception as e:
+            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
+
+        self._queue.task_done()
 
     @property
     def size(self) -> int:

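Taken together, these methods form a small S3-backed work-claiming protocol: populate_queue writes the grouped index, initialize_queue loads and shuffles the unfinished groups, and get_work/mark_done claim and release items via lock files. A minimal driver sketch, assuming a standard boto3 client and an illustrative bucket name (neither is part of this commit):

    # Sketch of the intended worker loop; client setup and bucket names are assumed.
    import asyncio
    import boto3

    from pdelfin.s3_queue import S3WorkQueue

    async def main():
        queue = S3WorkQueue(boto3.client("s3"), "s3://my-bucket/workspace")

        # Producer side: group new paths into work items and persist the index.
        await queue.populate_queue(
            ["s3://my-bucket/data/a.pdf", "s3://my-bucket/data/b.pdf"],
            items_per_group=2,
        )

        # Worker side: load the index, then claim items until none remain.
        await queue.initialize_queue()
        while (item := await queue.get_work()) is not None:
            # ... process item.s3_work_paths, write results/output_{item.hash}.jsonl ...
            await queue.mark_done(item)

    asyncio.run(main())
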
tests/test_s3_work_queue.py Normal file

@@ -0,0 +1,242 @@
import unittest
import asyncio
import datetime
from unittest.mock import Mock, patch, call
from botocore.exceptions import ClientError
import hashlib
from typing import List, Dict

# Import the classes we're testing
from pdelfin.s3_queue import S3WorkQueue, WorkItem


class TestS3WorkQueue(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")

        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

    def tearDown(self):
        """Clean up after each test method."""
        pass

    def test_compute_workgroup_hash(self):
        """Test hash computation is deterministic and correct"""
        paths = [
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file1.pdf",
        ]

        # Hash should be the same regardless of order
        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

        # Verify hash is actually SHA1
        sha1 = hashlib.sha1()
        for path in sorted(paths):
            sha1.update(path.encode('utf-8'))
        self.assertEqual(hash1, sha1.hexdigest())

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")

        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
        self.assertIsInstance(queue._queue, asyncio.Queue)

    def asyncSetUp(self):
        """Set up async test fixtures"""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def asyncTearDown(self):
        """Clean up async test fixtures"""
        self.loop.close()

    def async_test(f):
        """Decorator for async test methods"""
        def wrapper(*args, **kwargs):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.close()
        return wrapper

    @async_test
    async def test_populate_queue_new_items(self):
        """Test populating queue with new items"""
        # Mock empty existing index
        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[]):
            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)

                # Verify upload was called with correct data
                self.assertEqual(mock_upload.call_count, 1)
                _, _, lines = mock_upload.call_args[0]

                # Should create 2 work groups (2 files + 1 file)
                self.assertEqual(len(lines), 2)

                # Verify format of uploaded lines
                for line in lines:
                    parts = line.split(',')
                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length

    @async_test
    async def test_populate_queue_existing_items(self):
        """Test populating queue with mix of new and existing items"""
        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"

        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[existing_line]):
            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)

                # Verify upload called with both existing and new items
                _, _, lines = mock_upload.call_args[0]
                self.assertEqual(len(lines), 2)
                self.assertIn(existing_line, lines)

    @async_test
    async def test_initialize_queue(self):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]

        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[work_line]):
            with patch('pdelfin.s3_queue.expand_s3_glob', return_value=completed_items):
                await self.work_queue.initialize_queue()

                # Queue should be empty since all work is completed
                self.assertTrue(self.work_queue._queue.empty())

    @async_test
    async def test_is_completed(self):
        """Test completed work check"""
        work_hash = "testhash123"

        # Test completed work
        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))

        # Test incomplete work
        self.s3_client.head_object.side_effect = ClientError(
            {'Error': {'Code': '404', 'Message': 'Not Found'}},
            'HeadObject'
        )
        self.assertFalse(await self.work_queue.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
        """Test getting work items"""
        # Setup test data
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Test getting available work
        self.s3_client.head_object.side_effect = ClientError(
            {'Error': {'Code': '404', 'Message': 'Not Found'}},
            'HeadObject'
        )

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)

        # Verify lock file was created
        self.s3_client.put_object.assert_called_once()
        bucket, key = self.s3_client.put_object.call_args[1]['Bucket'], self.s3_client.put_object.call_args[1]['Key']
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_get_work_completed(self):
        """Test getting work that's already completed"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate completed work
        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip completed work

    @async_test
    async def test_get_work_locked(self):
        """Test getting work that's locked by another worker"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate active lock
        recent_time = datetime.datetime.now(datetime.timezone.utc)
        self.s3_client.head_object.side_effect = [
            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
            {'LastModified': recent_time}  # Active lock
        ]

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip locked work

    @async_test
    async def test_get_work_stale_lock(self):
        """Test getting work with a stale lock"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate stale lock
        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
        self.s3_client.head_object.side_effect = [
            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
            {'LastModified': stale_time}  # Stale lock
        ]

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)  # Should take work with stale lock

    @async_test
    async def test_mark_done(self):
        """Test marking work as done"""
        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        await self.work_queue.mark_done(work_item)

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        bucket, key = self.s3_client.delete_object.call_args[1]['Bucket'], self.s3_client.delete_object.call_args[1]['Key']
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    def test_queue_size(self):
        """Test queue size property"""
        self.assertEqual(self.work_queue.size, 0)

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"])))
        self.assertEqual(self.work_queue.size, 1)

        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"])))
        self.assertEqual(self.work_queue.size, 2)

        self.loop.close()


if __name__ == '__main__':
    unittest.main()
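
Side note: on Python 3.8+ the hand-rolled async_test decorator above could be replaced with unittest.IsolatedAsyncioTestCase, which manages an event loop per test. A sketch of the equivalent structure (the test shown is illustrative, not from the commit):

    # Sketch only: IsolatedAsyncioTestCase runs each `async def` test on its
    # own event loop, so no custom decorator is needed.
    import unittest
    from unittest.mock import Mock
    from botocore.exceptions import ClientError

    from pdelfin.s3_queue import S3WorkQueue

    class TestS3WorkQueueAsync(unittest.IsolatedAsyncioTestCase):
        async def test_get_work_empty_queue(self):
            s3_client = Mock()
            s3_client.exceptions.ClientError = ClientError
            queue = S3WorkQueue(s3_client, "s3://test-bucket/workspace")
            # Nothing was queued, so get_work should report an empty queue.
            self.assertIsNone(await queue.get_work())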