diff --git a/pdelfin/s3_queue.py b/pdelfin/s3_queue.py
index 43a2fe9..5c91ed5 100644
--- a/pdelfin/s3_queue.py
+++ b/pdelfin/s3_queue.py
@@ -3,6 +3,7 @@ import random
 import logging
 import hashlib
 import tempfile
+import datetime
 from typing import Optional, Tuple, List, Dict, Set
 from dataclasses import dataclass
 import asyncio
@@ -58,26 +59,86 @@ class S3WorkQueue:
         self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
         self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
 
+        self._queue = asyncio.Queue()
 
     @staticmethod
     def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
         """
-        Compute a deterministic hash for a group of PDFs.
+        Compute a deterministic hash for a group of paths.
 
         Args:
-            pdfs: List of PDF S3 paths
+            s3_work_paths: List of S3 paths
 
         Returns:
-            SHA1 hash of the sorted PDF paths
+            SHA1 hash of the sorted paths
         """
         sha1 = hashlib.sha1()
-        for pdf in sorted(s3_work_paths):
-            sha1.update(pdf.encode('utf-8'))
+        for path in sorted(s3_work_paths):
+            sha1.update(path.encode('utf-8'))
         return sha1.hexdigest()
 
+    async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
+        """
+        Add new items to the work queue.
+
+        Args:
+            s3_work_paths: Each individual s3 path that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
 
-    async def populate_queue(self, s3_work_paths: str, items_per_group: int) -> None:
-        pass
+        # Load existing work groups
+        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+
+        existing_path_set = {path for paths in existing_groups.values() for path in paths}
+
+        # Find new paths to process
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            await asyncio.to_thread(
+                upload_zstd_csv,
+                self.s3_client,
+                self._index_path,
+                combined_lines
+            )
 
     async def initialize_queue(self) -> None:
         """
@@ -116,7 +177,7 @@ class S3WorkQueue:
         # Find remaining work and shuffle
         remaining_work_hashes = set(work_queue) - done_work_hashes
         remaining_items = [
-            WorkItem(hash_=hash_, pdfs=work_queue[hash_])
+            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
             for hash_ in remaining_work_hashes
         ]
         random.shuffle(remaining_items)
@@ -127,7 +188,7 @@ class S3WorkQueue:
             await self._queue.put(item)
 
         logger.info(f"Initialized queue with {self._queue.qsize()} work items")
-
+
     async def is_completed(self, work_hash: str) -> bool:
         """
         Check if a work item has been completed.
@@ -138,7 +199,7 @@ class S3WorkQueue:
         Returns:
             True if the work is completed, False otherwise
         """
-        output_s3_path = ""TODO""
+        output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
         bucket, key = parse_s3_path(output_s3_path)
 
         try:
@@ -151,12 +212,89 @@ class S3WorkQueue:
         except self.s3_client.exceptions.ClientError:
             return False
 
-    async def get_work(self) -> Optional[WorkItem]:
-        pass
+    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
+        """
+        Get the next available work item that isn't completed or locked.
+
+        Args:
+            worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
+
+        Returns:
+            WorkItem if work is available, None if queue is empty
+        """
+        while True:
+            try:
+                work_item = self._queue.get_nowait()
+            except asyncio.QueueEmpty:
+                return None
 
-    def mark_done(self, work_item: WorkItem) -> None:
-        """Mark the most recently gotten work item as complete"""
-        pass
+            # Check if work is already completed
+            if await self.is_completed(work_item.hash):
+                logger.debug(f"Work item {work_item.hash} already completed, skipping")
+                self._queue.task_done()
+                continue
+
+            # Check for worker lock
+            lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+            bucket, key = parse_s3_path(lock_path)
+
+            try:
+                response = await asyncio.to_thread(
+                    self.s3_client.head_object,
+                    Bucket=bucket,
+                    Key=key
+                )
+
+                # Check if lock is stale
+                last_modified = response['LastModified']
+                if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
+                    # Lock is stale, we can take this work
+                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
+                else:
+                    # Lock is active, skip this work
+                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
+                    self._queue.task_done()
+                    continue
+
+            except self.s3_client.exceptions.ClientError:
+                # No lock exists, we can take this work
+                pass
+
+            # Create our lock file
+            try:
+                await asyncio.to_thread(
+                    self.s3_client.put_object,
+                    Bucket=bucket,
+                    Key=key,
+                    Body=b''
+                )
+            except Exception as e:
+                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
+                self._queue.task_done()
+                continue
+
+            return work_item
+
+    async def mark_done(self, work_item: WorkItem) -> None:
+        """
+        Mark a work item as done by removing its lock file.
+
+        Args:
+            work_item: The WorkItem to mark as done
+        """
+        lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+        bucket, key = parse_s3_path(lock_path)
+
+        try:
+            await asyncio.to_thread(
+                self.s3_client.delete_object,
+                Bucket=bucket,
+                Key=key
+            )
+        except Exception as e:
+            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
+
+        self._queue.task_done()
 
     @property
     def size(self) -> int:
diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py
new file mode 100644
index 0000000..a78bd6a
--- /dev/null
+++ b/tests/test_s3_work_queue.py
@@ -0,0 +1,242 @@
+import unittest
+import asyncio
+import datetime
+from unittest.mock import Mock, patch, call
+from botocore.exceptions import ClientError
+import hashlib
+from typing import List, Dict
+
+# Import the classes we're testing
+from pdelfin.s3_queue import S3WorkQueue, WorkItem
+
+class TestS3WorkQueue(unittest.TestCase):
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        self.s3_client = Mock()
+        self.s3_client.exceptions.ClientError = ClientError
+        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
+        self.sample_paths = [
+            "s3://test-bucket/data/file1.pdf",
+            "s3://test-bucket/data/file2.pdf",
+            "s3://test-bucket/data/file3.pdf",
+        ]
+
+    def tearDown(self):
+        """Clean up after each test method."""
+        pass
+
+    def test_compute_workgroup_hash(self):
+        """Test hash computation is deterministic and correct"""
+        paths = [
+            "s3://test-bucket/data/file2.pdf",
+            "s3://test-bucket/data/file1.pdf",
+        ]
+
+        # Hash should be the same regardless of order
+        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
+        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
+        self.assertEqual(hash1, hash2)
+
+        # Verify hash is actually SHA1
+        sha1 = hashlib.sha1()
+        for path in sorted(paths):
+            sha1.update(path.encode('utf-8'))
+        self.assertEqual(hash1, sha1.hexdigest())
+
+    def test_init(self):
+        """Test initialization of S3WorkQueue"""
+        client = Mock()
+        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")
+
+        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
+        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
+        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
+        self.assertIsInstance(queue._queue, asyncio.Queue)
+
+    def asyncSetUp(self):
+        """Set up async test fixtures"""
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    def asyncTearDown(self):
+        """Clean up async test fixtures"""
+        self.loop.close()
+
+    def async_test(f):
+        """Decorator for async test methods"""
+        def wrapper(*args, **kwargs):
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                return loop.run_until_complete(f(*args, **kwargs))
+            finally:
+                loop.close()
+        return wrapper
+
+    @async_test
+    async def test_populate_queue_new_items(self):
+        """Test populating queue with new items"""
+        # Mock empty existing index
+        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[]):
+            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
+                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)
+
+                # Verify upload was called with correct data
+                self.assertEqual(mock_upload.call_count, 1)
+                _, _, lines = mock_upload.call_args[0]
+
+                # Should create 2 work groups (2 files + 1 file)
+                self.assertEqual(len(lines), 2)
+
+                # Verify format of uploaded lines
+                for line in lines:
+                    parts = line.split(',')
+                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
+                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length
+
+    @async_test
+    async def test_populate_queue_existing_items(self):
+        """Test populating queue with mix of new and existing items"""
+        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
+        new_paths = ["s3://test-bucket/data/new1.pdf"]
+
+        # Create existing index content
+        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
+        existing_line = f"{existing_hash},{existing_paths[0]}"
+
+        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[existing_line]):
+            with patch('pdelfin.s3_queue.upload_zstd_csv') as mock_upload:
+                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)
+
+                # Verify upload called with both existing and new items
+                _, _, lines = mock_upload.call_args[0]
+                self.assertEqual(len(lines), 2)
+                self.assertIn(existing_line, lines)
+
+    @async_test
+    async def test_initialize_queue(self):
+        """Test queue initialization"""
+        # Mock work items and completed items
+        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
+        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
+        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
+
+        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
+
+        with patch('pdelfin.s3_queue.download_zstd_csv', return_value=[work_line]):
+            with patch('pdelfin.s3_queue.expand_s3_glob', return_value=completed_items):
+                await self.work_queue.initialize_queue()
+
+                # Queue should be empty since all work is completed
+                self.assertTrue(self.work_queue._queue.empty())
+
+    @async_test
+    async def test_is_completed(self):
+        """Test completed work check"""
+        work_hash = "testhash123"
+
+        # Test completed work
+        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}
+        self.assertTrue(await self.work_queue.is_completed(work_hash))
+
+        # Test incomplete work
+        self.s3_client.head_object.side_effect = ClientError(
+            {'Error': {'Code': '404', 'Message': 'Not Found'}},
+            'HeadObject'
+        )
+        self.assertFalse(await self.work_queue.is_completed(work_hash))
+
+    @async_test
+    async def test_get_work(self):
+        """Test getting work items"""
+        # Setup test data
+        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
+        await self.work_queue._queue.put(work_item)
+
+        # Test getting available work
+        self.s3_client.head_object.side_effect = ClientError(
+            {'Error': {'Code': '404', 'Message': 'Not Found'}},
+            'HeadObject'
+        )
+        result = await self.work_queue.get_work()
+        self.assertEqual(result, work_item)
+
+        # Verify lock file was created
+        self.s3_client.put_object.assert_called_once()
+        bucket, key = self.s3_client.put_object.call_args[1]['Bucket'], self.s3_client.put_object.call_args[1]['Key']
+        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
+
+    @async_test
+    async def test_get_work_completed(self):
+        """Test getting work that's already completed"""
+        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
+        await self.work_queue._queue.put(work_item)
+
+        # Simulate completed work
+        self.s3_client.head_object.return_value = {'LastModified': datetime.datetime.now(datetime.timezone.utc)}
+
+        result = await self.work_queue.get_work()
+        self.assertIsNone(result)  # Should skip completed work
+
+    @async_test
+    async def test_get_work_locked(self):
+        """Test getting work that's locked by another worker"""
+        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
+        await self.work_queue._queue.put(work_item)
+
+        # Simulate active lock
+        recent_time = datetime.datetime.now(datetime.timezone.utc)
+        self.s3_client.head_object.side_effect = [
+            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
+            {'LastModified': recent_time}  # Active lock
+        ]
+
+        result = await self.work_queue.get_work()
+        self.assertIsNone(result)  # Should skip locked work
+
+    @async_test
+    async def test_get_work_stale_lock(self):
+        """Test getting work with a stale lock"""
+        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
+        await self.work_queue._queue.put(work_item)
+
+        # Simulate stale lock
+        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
+        self.s3_client.head_object.side_effect = [
+            ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject'),  # Not completed
+            {'LastModified': stale_time}  # Stale lock
+        ]
+
+        result = await self.work_queue.get_work()
+        self.assertEqual(result, work_item)  # Should take work with stale lock
+
+    @async_test
+    async def test_mark_done(self):
+        """Test marking work as done"""
+        work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
+        await self.work_queue._queue.put(work_item)
+
+        await self.work_queue.mark_done(work_item)
+
+        # Verify lock file was deleted
+        self.s3_client.delete_object.assert_called_once()
+        bucket, key = self.s3_client.delete_object.call_args[1]['Bucket'], self.s3_client.delete_object.call_args[1]['Key']
+        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
+
+    def test_queue_size(self):
+        """Test queue size property"""
+        self.assertEqual(self.work_queue.size, 0)
+
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"])))
+        self.assertEqual(self.work_queue.size, 1)
+
+        self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"])))
+        self.assertEqual(self.work_queue.size, 2)
+
+        self.loop.close()
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file