diff --git a/olmocr/version.py b/olmocr/version.py
index 6693d87..03a705c 100644
--- a/olmocr/version.py
+++ b/olmocr/version.py
@@ -2,7 +2,7 @@ _MAJOR = "0"
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "61"
+_PATCH = "62"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
diff --git a/olmocr/work_queue.py b/olmocr/work_queue.py
index 32e0903..8e74abe 100644
--- a/olmocr/work_queue.py
+++ b/olmocr/work_queue.py
@@ -1,7 +1,9 @@
 import abc
 import asyncio
+import csv
 import datetime
 import hashlib
+import io
 import logging
 import os
 import random
@@ -32,6 +34,35 @@ class WorkQueue(abc.ABC):
     Base class defining the interface for a work queue.
     """
 
+    @staticmethod
+    def _encode_csv_row(row: List[str]) -> str:
+        """
+        Encodes a row of data for CSV storage with proper escaping.
+
+        Args:
+            row: List of strings to encode
+
+        Returns:
+            CSV-encoded string with proper escaping of commas and quotes
+        """
+        output = io.StringIO()
+        writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
+        writer.writerow(row)
+        return output.getvalue().strip()
+
+    @staticmethod
+    def _decode_csv_row(line: str) -> List[str]:
+        """
+        Decodes a CSV row with proper unescaping.
+
+        Args:
+            line: CSV-encoded string
+
+        Returns:
+            List of unescaped string values
+        """
+        return next(csv.reader([line]))
+
     @abc.abstractmethod
     async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
         """
@@ -217,10 +248,11 @@ class LocalWorkQueue(WorkQueue):
         existing_groups = {}
         for line in existing_lines:
             if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_paths = parts[1:]
-                existing_groups[group_hash] = group_paths
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    group_hash = parts[0]
+                    group_paths = parts[1:]
+                    existing_groups[group_hash] = group_paths
 
         existing_path_set = {p for paths in existing_groups.values() for p in paths}
         new_paths = all_paths - existing_path_set
@@ -249,7 +281,8 @@ class LocalWorkQueue(WorkQueue):
         for group_hash, group_paths in new_groups:
             combined_groups[group_hash] = group_paths
 
-        combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
+        # Use proper CSV encoding with escaping for paths that may contain commas
+        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
 
         if new_groups:
             # Write the combined data back to disk in zstd CSV format
@@ -262,7 +295,12 @@ class LocalWorkQueue(WorkQueue):
         """
         # 1) Read the index
         work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
-        work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
+        work_queue = {}
+        for line in work_queue_lines:
+            if line.strip():
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    work_queue[parts[0]] = parts[1:]
 
         # 2) Determine which items are completed by scanning local results/*.jsonl
         if not os.path.isdir(self._results_dir):
@@ -422,10 +460,11 @@ class S3WorkQueue(WorkQueue):
         existing_groups = {}
         for line in existing_lines:
             if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_paths = parts[1:]
-                existing_groups[group_hash] = group_paths
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    group_hash = parts[0]
+                    group_paths = parts[1:]
+                    existing_groups[group_hash] = group_paths
 
         existing_path_set = {path for paths in existing_groups.values() for path in paths}
 
@@ -456,7 +495,8 @@ class S3WorkQueue(WorkQueue):
         for group_hash, group_paths in new_groups:
             combined_groups[group_hash] = group_paths
 
-        combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
+        # Use proper CSV encoding with escaping for paths that may contain commas
+        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
 
         if new_groups:
             await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
@@ -473,7 +513,12 @@ class S3WorkQueue(WorkQueue):
         work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
 
         # Process work queue lines
-        work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
+        work_queue = {}
+        for line in work_queue_lines:
+            if line.strip():
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    work_queue[parts[0]] = parts[1:]
 
         # Get set of completed work hashes
         done_work_hashes = {
diff --git a/scripts/scan_dolmadocs.py b/scripts/scan_dolmadocs.py
index a166734..6d45ff2 100644
--- a/scripts/scan_dolmadocs.py
+++ b/scripts/scan_dolmadocs.py
@@ -117,8 +117,8 @@ def list_result_files(s3_client, workspace_path):
         if "Contents" in page:
             all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
 
-        if len(all_files) > 1000:
-            break
+        # if len(all_files) > 1000:
+        #     break
 
     return all_files
 
diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py
index dc0e905..a8ce1cd 100644
--- a/tests/test_s3_work_queue.py
+++ b/tests/test_s3_work_queue.py
@@ -214,6 +214,44 @@ class TestS3WorkQueue(unittest.TestCase):
         bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
         self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
 
+    @async_test
+    async def test_paths_with_commas(self):
+        """Test handling of paths that contain commas"""
+        # Create paths with commas in them
+        paths_with_commas = ["s3://test-bucket/data/file1,with,commas.pdf", "s3://test-bucket/data/file2,comma.pdf", "s3://test-bucket/data/file3.pdf"]
+
+        # Mock empty existing index for initial population
+        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
+            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
+                # Populate the queue with these paths
+                await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)
+
+                # Capture what would be written to the index
+                _, _, lines = mock_upload.call_args[0]
+
+        # Now simulate reading back these lines (which have commas in the paths)
+        with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
+            with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
+                # Initialize a fresh queue from these lines
+                await self.work_queue.initialize_queue()
+
+        # Mock ClientError for head_object (file doesn't exist)
+        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
+
+        # Get a work item
+        work_item = await self.work_queue.get_work()
+
+        # Now verify we get a work item
+        self.assertIsNotNone(work_item, "Should get a work item")
+
+        # Verify the work item has the correct number of paths
+        self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")
+
+        # Check that all original paths with commas are preserved
+        for path in paths_with_commas:
+            print(path)
+            self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")
+
     def test_queue_size(self):
         """Test queue size property"""
         self.assertEqual(self.work_queue.size, 0)
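Note: for reference, the round-trip that the new WorkQueue._encode_csv_row / _decode_csv_row helpers rely on can be sketched with nothing but the stdlib csv module. This is an illustrative standalone snippet, not part of the patch; the names encode_row/decode_row and the example hash and paths are made up.

    import csv
    import io

    def encode_row(row):
        # QUOTE_MINIMAL only quotes fields that contain commas, quotes, or newlines.
        buf = io.StringIO()
        csv.writer(buf, quoting=csv.QUOTE_MINIMAL).writerow(row)
        return buf.getvalue().strip()

    def decode_row(line):
        # csv.reader undoes the quoting, restoring the original fields.
        return next(csv.reader([line]))

    row = ["0123abcd", "s3://bucket/data/file1,with,commas.pdf", "s3://bucket/data/file3.pdf"]
    line = encode_row(row)
    # line == '0123abcd,"s3://bucket/data/file1,with,commas.pdf",s3://bucket/data/file3.pdf'
    assert decode_row(line) == row

Because quoted fields may themselves contain commas, an index line can no longer be split naively with str.split(","), which is why the dict-comprehension parsers above were replaced with calls to _decode_csv_row.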