Bump version to fix work queue handling of paths that contain commas (the index delimiter)

This commit is contained in:
Jake Poznanski 2025-04-15 18:50:13 +00:00
parent 786b14aef5
commit 1d0c560455
4 changed files with 98 additions and 15 deletions
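
The underlying bug: each line of the work-queue index is a group hash followed by its paths, previously joined and re-split on bare commas, so any path that itself contains a comma gets torn into extra fields when the index is read back. The fix routes encoding and decoding through Python's csv module, which quotes fields containing the delimiter. A minimal sketch of the difference, using a made-up hash and path purely for illustration:

import csv
import io

row = ["abc123", "s3://bucket/data/report,final.pdf"]

# Naive join/split tears the path into two fields on the way back.
naive = ",".join(row)
print(naive.split(","))  # ['abc123', 's3://bucket/data/report', 'final.pdf']

# A csv round-trip quotes the offending field and preserves it intact.
buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_MINIMAL).writerow(row)
encoded = buf.getvalue().strip()  # 'abc123,"s3://bucket/data/report,final.pdf"'
print(next(csv.reader([encoded])))  # ['abc123', 's3://bucket/data/report,final.pdf']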

View File

@ -2,7 +2,7 @@ _MAJOR = "0"
_MINOR = "1"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "61"
_PATCH = "62"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""

View File

@ -1,7 +1,9 @@
import abc
import asyncio
import csv
import datetime
import hashlib
import io
import logging
import os
import random
@ -32,6 +34,35 @@ class WorkQueue(abc.ABC):
Base class defining the interface for a work queue.
"""
@staticmethod
def _encode_csv_row(row: List[str]) -> str:
"""
Encodes a row of data for CSV storage with proper escaping.
Args:
row: List of strings to encode
Returns:
CSV-encoded string with proper escaping of commas and quotes
"""
output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
writer.writerow(row)
return output.getvalue().strip()
@staticmethod
def _decode_csv_row(line: str) -> List[str]:
"""
Decodes a CSV row with proper unescaping.
Args:
line: CSV-encoded string
Returns:
List of unescaped string values
"""
return next(csv.reader([line]))
@abc.abstractmethod
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
@ -217,10 +248,11 @@ class LocalWorkQueue(WorkQueue):
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {p for paths in existing_groups.values() for p in paths}
new_paths = all_paths - existing_path_set
@ -249,7 +281,8 @@ class LocalWorkQueue(WorkQueue):
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
# Use proper CSV encoding with escaping for paths that may contain commas
combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
# Write the combined data back to disk in zstd CSV format
@ -262,7 +295,12 @@ class LocalWorkQueue(WorkQueue):
"""
# 1) Read the index
work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
work_queue[parts[0]] = parts[1:]
# 2) Determine which items are completed by scanning local results/*.jsonl
if not os.path.isdir(self._results_dir):
@ -422,10 +460,11 @@ class S3WorkQueue(WorkQueue):
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {path for paths in existing_groups.values() for path in paths}
@ -456,7 +495,8 @@ class S3WorkQueue(WorkQueue):
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
# Use proper CSV encoding with escaping for paths that may contain commas
combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
@ -473,7 +513,12 @@ class S3WorkQueue(WorkQueue):
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
# Process work queue lines
work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
work_queue[parts[0]] = parts[1:]
# Get set of completed work hashes
done_work_hashes = {
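
The previous index parsing was a one-line dict comprehension over a naive split; the new code expands it into an explicit loop around _decode_csv_row, which is easier to read and to guard against empty rows. If the comprehension style were preferred, an equivalent compact form is possible; a sketch with the same effect (not what the commit uses):

work_queue = {
    parts[0]: parts[1:]
    for line in work_queue_lines
    if line.strip() and (parts := WorkQueue._decode_csv_row(line.strip()))
}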

View File

@ -117,8 +117,8 @@ def list_result_files(s3_client, workspace_path):
if "Contents" in page:
all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
if len(all_files) > 1000:
break
# if len(all_files) > 1000:
# break
return all_files
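
With the early exit commented out, the listing now relies entirely on the paginator to walk every result page rather than stopping after roughly 1000 files. For reference, a minimal standalone sketch of the standard boto3 list_objects_v2 pagination pattern in play here (the function name, bucket, and prefix arguments are assumptions for illustration, not the repository's exact code):

import boto3

def list_jsonl_keys(bucket: str, prefix: str) -> list:
    # Walk every page returned by S3; no artificial cap on the number of files.
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    all_files = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if obj["Key"].endswith((".jsonl", ".json")):
                all_files.append(f"s3://{bucket}/{obj['Key']}")
    return all_files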

View File

@ -214,6 +214,44 @@ class TestS3WorkQueue(unittest.TestCase):
bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
@async_test
async def test_paths_with_commas(self):
"""Test handling of paths that contain commas"""
# Create paths with commas in them
paths_with_commas = ["s3://test-bucket/data/file1,with,commas.pdf", "s3://test-bucket/data/file2,comma.pdf", "s3://test-bucket/data/file3.pdf"]
# Mock empty existing index for initial population
with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
# Populate the queue with these paths
await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)
# Capture what would be written to the index
_, _, lines = mock_upload.call_args[0]
# Now simulate reading back these lines (which have commas in the paths)
with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
# Initialize a fresh queue from these lines
await self.work_queue.initialize_queue()
# Mock ClientError for head_object (file doesn't exist)
self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
# Get a work item
work_item = await self.work_queue.get_work()
# Now verify we get a work item
self.assertIsNotNone(work_item, "Should get a work item")
# Verify the work item has the correct number of paths
self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")
# Check that all original paths with commas are preserved
for path in paths_with_commas:
self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")
def test_queue_size(self):
"""Test queue size property"""
self.assertEqual(self.work_queue.size, 0)