Mirror of https://github.com/allenai/olmocr.git (synced 2026-01-06 04:12:30 +00:00)
Upping version to fix issue with work queue and delimited paths
This commit is contained in:
parent 786b14aef5
commit 1d0c560455
@@ -2,7 +2,7 @@ _MAJOR = "0"
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "61"
+_PATCH = "62"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
@@ -1,7 +1,9 @@
 import abc
 import asyncio
+import csv
 import datetime
 import hashlib
+import io
 import logging
 import os
 import random
@@ -32,6 +34,35 @@ class WorkQueue(abc.ABC):
     Base class defining the interface for a work queue.
     """
 
+    @staticmethod
+    def _encode_csv_row(row: List[str]) -> str:
+        """
+        Encodes a row of data for CSV storage with proper escaping.
+
+        Args:
+            row: List of strings to encode
+
+        Returns:
+            CSV-encoded string with proper escaping of commas and quotes
+        """
+        output = io.StringIO()
+        writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
+        writer.writerow(row)
+        return output.getvalue().strip()
+
+    @staticmethod
+    def _decode_csv_row(line: str) -> List[str]:
+        """
+        Decodes a CSV row with proper unescaping.
+
+        Args:
+            line: CSV-encoded string
+
+        Returns:
+            List of unescaped string values
+        """
+        return next(csv.reader([line]))
+
     @abc.abstractmethod
     async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
         """
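Side note (not part of the diff): these helpers exist because a plain comma join/split, which is what the code below used to do, cannot round-trip paths that themselves contain commas, while the csv module quotes such fields. A minimal standalone sketch using only the standard library; the demo_paths value is invented for illustration:

import csv
import io

# A row whose second field is an S3 path containing commas
demo_paths = ["abc123", "s3://bucket/data/file1,with,commas.pdf"]

# Naive approach: the commas inside the path are indistinguishable from delimiters
naive_line = ",".join(demo_paths)
print(naive_line.split(","))  # ['abc123', 's3://bucket/data/file1', 'with', 'commas.pdf'] -- path mangled

# CSV approach, mirroring _encode_csv_row / _decode_csv_row above
buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_MINIMAL).writerow(demo_paths)
encoded = buf.getvalue().strip()
print(encoded)                      # abc123,"s3://bucket/data/file1,with,commas.pdf"
print(next(csv.reader([encoded])))  # ['abc123', 's3://bucket/data/file1,with,commas.pdf'] -- intact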
@@ -217,10 +248,11 @@ class LocalWorkQueue(WorkQueue):
         existing_groups = {}
         for line in existing_lines:
             if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_paths = parts[1:]
-                existing_groups[group_hash] = group_paths
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    group_hash = parts[0]
+                    group_paths = parts[1:]
+                    existing_groups[group_hash] = group_paths
 
         existing_path_set = {p for paths in existing_groups.values() for p in paths}
         new_paths = all_paths - existing_path_set
@@ -249,7 +281,8 @@ class LocalWorkQueue(WorkQueue):
         for group_hash, group_paths in new_groups:
             combined_groups[group_hash] = group_paths
 
-        combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
+        # Use proper CSV encoding with escaping for paths that may contain commas
+        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
 
         if new_groups:
             # Write the combined data back to disk in zstd CSV format
@@ -262,7 +295,12 @@ class LocalWorkQueue(WorkQueue):
         """
         # 1) Read the index
         work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
-        work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
+        work_queue = {}
+        for line in work_queue_lines:
+            if line.strip():
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    work_queue[parts[0]] = parts[1:]
 
         # 2) Determine which items are completed by scanning local results/*.jsonl
         if not os.path.isdir(self._results_dir):
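To make the effect of the new parsing loop concrete, here is a standalone sketch of building the same {group_hash: [paths]} index mapping from encoded lines; the sample lines and hash values are invented for illustration:

import csv

# Invented index lines as _encode_csv_row would produce them:
# first field is the group hash, remaining fields are the group's paths
sample_lines = [
    'hash1,s3://bucket/a.pdf,"s3://bucket/b,with,commas.pdf"',
    "hash2,s3://bucket/c.pdf",
    "",  # blank lines are skipped, as in the loop above
]

work_queue = {}
for line in sample_lines:
    if line.strip():
        parts = next(csv.reader([line.strip()]))  # same logic as _decode_csv_row
        if parts:  # ensure there is at least one field
            work_queue[parts[0]] = parts[1:]

print(work_queue)
# {'hash1': ['s3://bucket/a.pdf', 's3://bucket/b,with,commas.pdf'], 'hash2': ['s3://bucket/c.pdf']}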
@@ -422,10 +460,11 @@ class S3WorkQueue(WorkQueue):
         existing_groups = {}
         for line in existing_lines:
             if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_paths = parts[1:]
-                existing_groups[group_hash] = group_paths
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    group_hash = parts[0]
+                    group_paths = parts[1:]
+                    existing_groups[group_hash] = group_paths
 
         existing_path_set = {path for paths in existing_groups.values() for path in paths}
 
@@ -456,7 +495,8 @@ class S3WorkQueue(WorkQueue):
         for group_hash, group_paths in new_groups:
             combined_groups[group_hash] = group_paths
 
-        combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
+        # Use proper CSV encoding with escaping for paths that may contain commas
+        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
 
         if new_groups:
             await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
@@ -473,7 +513,12 @@ class S3WorkQueue(WorkQueue):
         work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
 
         # Process work queue lines
-        work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
+        work_queue = {}
+        for line in work_queue_lines:
+            if line.strip():
+                parts = self._decode_csv_row(line.strip())
+                if parts:  # Ensure we have at least one part
+                    work_queue[parts[0]] = parts[1:]
 
         # Get set of completed work hashes
         done_work_hashes = {
@@ -117,8 +117,8 @@ def list_result_files(s3_client, workspace_path):
         if "Contents" in page:
             all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
 
-        if len(all_files) > 1000:
-            break
+        # if len(all_files) > 1000:
+        #     break
 
     return all_files
 
@@ -214,6 +214,44 @@ class TestS3WorkQueue(unittest.TestCase):
         bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
         self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
 
+    @async_test
+    async def test_paths_with_commas(self):
+        """Test handling of paths that contain commas"""
+        # Create paths with commas in them
+        paths_with_commas = ["s3://test-bucket/data/file1,with,commas.pdf", "s3://test-bucket/data/file2,comma.pdf", "s3://test-bucket/data/file3.pdf"]
+
+        # Mock empty existing index for initial population
+        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
+            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
+                # Populate the queue with these paths
+                await self.work_queue.populate_queue(paths_with_commas, items_per_group=3)
+
+                # Capture what would be written to the index
+                _, _, lines = mock_upload.call_args[0]
+
+        # Now simulate reading back these lines (which have commas in the paths)
+        with patch("olmocr.work_queue.download_zstd_csv", return_value=lines):
+            with patch("olmocr.work_queue.expand_s3_glob", return_value=[]):
+                # Initialize a fresh queue from these lines
+                await self.work_queue.initialize_queue()
+
+        # Mock ClientError for head_object (file doesn't exist)
+        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
+
+        # Get a work item
+        work_item = await self.work_queue.get_work()
+
+        # Now verify we get a work item
+        self.assertIsNotNone(work_item, "Should get a work item")
+
+        # Verify the work item has the correct number of paths
+        self.assertEqual(len(work_item.work_paths), len(paths_with_commas), "Work item should have the correct number of paths")
+
+        # Check that all original paths with commas are preserved
+        for path in paths_with_commas:
+            print(path)
+            self.assertIn(path, work_item.work_paths, f"Path with commas should be preserved: {path}")
+
     def test_queue_size(self):
         """Test queue size property"""
         self.assertEqual(self.work_queue.size, 0)
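If you want to run just this new test from a checkout (the test file name is not visible in this diff, so the keyword filter below avoids guessing it), something like the following should work, with -s showing the print output:

pytest -k test_paths_with_commas -s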