Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-01 04:46:16 +00:00)

Better work queue

commit e499413089
parent 04429b2862
@@ -24,9 +24,10 @@ from PIL import Image
 from pypdf import PdfReader
 from functools import partial
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, List, Dict, Set
 from concurrent.futures import ProcessPoolExecutor
 
+from pdelfin.s3_queue import S3WorkQueue, WorkItem
 from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
 from pdelfin.data.renderpdf import render_pdf_to_base64png
 from pdelfin.prompts import build_finetuning_prompt, PageResponse
@@ -123,160 +124,6 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
     }
 
 
-def compute_workgroup_sha1(work_group: list[str]) -> str:
-    sha1 = hashlib.sha1()
-    # Ensure consistent ordering by sorting the list
-    for pdf in sorted(work_group):
-        sha1.update(pdf.encode('utf-8'))
-    return sha1.hexdigest()
-
-
-async def populate_pdf_work_queue(args):
-    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-
-    if args.pdfs.startswith("s3://"):
-        logger.info(f"Expanding s3 glob at {args.pdfs}")
-        all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
-    elif os.path.exists(args.pdfs):
-        logger.info(f"Loading file at {args.pdfs}")
-        with open(args.pdfs, "r") as f:
-            all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
-    else:
-        raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
-
-    all_pdfs = set(all_pdfs)
-    logger.info(f"Found {len(all_pdfs):,} total pdf paths")
-
-    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-
-    # Parse existing work items into groups
-    existing_groups = {}
-    for line in existing_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            existing_groups[group_hash] = group_pdfs
-    existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
-
-    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
-
-    # Remove existing PDFs from all_pdfs
-    new_pdfs = all_pdfs - existing_pdf_set
-    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
-
-    sample_size = min(100, len(new_pdfs))
-    sampled_pdfs = random.sample(list(new_pdfs), sample_size)
-
-    page_counts = []
-
-    for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimial length"):
-        try:
-            # Download the PDF to a temp file
-            with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
-                s3_bucket, s3_key = parse_s3_path(pdf)
-                pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
-                tmp_file.flush()
-                reader = PdfReader(tmp_file.name)
-                page_counts.append(len(reader.pages))
-        except Exception as e:
-            logger.warning(f"Failed to read {pdf}: {e}")
-
-    if page_counts:
-        avg_pages_per_pdf = sum(page_counts) / len(page_counts)
-    else:
-        logger.warning("Could not read any PDFs to estimate average page count.")
-        avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
-
-    group_size = max(1, int(args.pages_per_group / avg_pages_per_pdf))
-    logger.info(f"Calculated group_size: {group_size} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
-
-    new_groups = []
-    current_group = []
-    for pdf in sorted(new_pdfs): # Sort for consistency
-        current_group.append(pdf)
-        if len(current_group) == group_size:
-            group_hash = compute_workgroup_sha1(current_group)
-            new_groups.append((group_hash, current_group))
-            current_group = []
-    if current_group:
-        group_hash = compute_workgroup_sha1(current_group)
-        new_groups.append((group_hash, current_group))
-
-    logger.info(f"Created {len(new_groups):,} new work groups")
-
-    # Combine existing groups with new groups
-    combined_groups = existing_groups.copy()
-    for group_hash, group_pdfs in new_groups:
-        combined_groups[group_hash] = group_pdfs
-
-    # Prepare lines to write back
-    combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
-
-    # Upload the combined work items back to S3
-    if new_groups:
-        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
-
-    logger.info("Completed adding new PDFs.")
-
-async def load_pdf_work_queue(args) -> asyncio.Queue:
-    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-    output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
-
-    # Define the two blocking I/O operations
-    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
-    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
-
-    # Run both tasks concurrently
-    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
-
-    # Process the work queue lines
-    work_queue = {
-        parts[0]: parts[1:]
-        for line in work_queue_lines
-        if (parts := line.strip().split(",")) and line.strip()
-    }
-
-    # Extract done work hashes
-    done_work_hashes = {
-        os.path.basename(item)[len('output_'):-len('.jsonl')]
-        for item in done_work_items
-        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
-    }
-
-    # Determine remaining work
-    remaining_work_hashes = set(work_queue) - done_work_hashes
-    #remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
-    remaining_work_queue = {
-        hash_: work_queue[hash_]
-        for hash_ in remaining_work_hashes
-    }
-
-    # Populate the asyncio.Queue with remaining work
-    queue = asyncio.Queue()
-    shuffled_items = list(remaining_work_queue.items())
-    random.shuffle(shuffled_items)
-
-    for work, pdfs in shuffled_items:
-        await queue.put((work, pdfs))
-
-    return queue
-
-async def work_item_completed(args, work_hash: str) -> bool:
-    # Check if work item has already been completed
-    output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
-    bucket, key = parse_s3_path(output_s3_path)
-
-    try:
-        # Check if the output file already exists
-        await asyncio.to_thread(workspace_s3.head_object, Bucket=bucket, Key=key)
-        return True
-    except workspace_s3.exceptions.ClientError as e:
-        pass
-
-    return False
-
-
 async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = 3
@@ -442,31 +289,32 @@ def build_dolma_document(pdf_s3_path, page_results):
     }
     return dolma_doc
 
-async def worker(args, queue, semaphore, worker_id):
+async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
     while True:
-        [work_hash, pdfs] = await queue.get()
+        # Wait until allowed to proceed
+        await semaphore.acquire()
 
-        try:
-            await tracker.clear_work(worker_id)
+        work_item = await work_queue.get_work()
 
-            # Wait until allowed to proceed
-            await semaphore.acquire()
+        if work_item is None:
+            logger.info(f"Worker {worker_id} exiting due to empty queue")
+            semaphore.release()
+            break
 
-            if await work_item_completed(args, work_hash):
-                logger.info(f"Work {work_hash} was already completed, skipping")
-                continue
-            else:
-                logger.info(f"Proceeding with {work_hash} on worker {worker_id}")
+        logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
+        await tracker.clear_work(worker_id)
 
+        try:
             async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
                                              connector=aiohttp.TCPConnector(limit=1000)) as session:
                 async with asyncio.TaskGroup() as tg:
-                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs]
-                    logger.info(f"Created all tasks for {work_hash}")
+                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
+                    logger.info(f"Created all tasks for {work_item.hash}")
 
-                logger.info(f"Finished TaskGroup for worker on {work_hash}")
+                logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
 
-            logger.info(f"Closed ClientSession for {work_hash}")
+            logger.info(f"Closed ClientSession for {work_item.hash}")
 
             dolma_docs = []
             for task in dolma_tasks:
@@ -479,7 +327,7 @@ async def worker(args, queue, semaphore, worker_id):
                 if result is not None:
                     dolma_docs.append(result)
 
-            logger.info(f"Got {len(dolma_docs)} docs for {work_hash}")
+            logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")
 
             # Write the Dolma documents to a local temporary file in JSONL format
             with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
@@ -489,7 +337,7 @@ async def worker(args, queue, semaphore, worker_id):
                 tf.flush()
 
                 # Define the output S3 path using the work_hash
-                output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
+                output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl')
 
                 bucket, key = parse_s3_path(output_s3_path)
                 workspace_s3.upload_file(tf.name, bucket, key)
@@ -501,9 +349,10 @@ async def worker(args, queue, semaphore, worker_id):
             # Update last batch time
             last_batch_time = time.perf_counter()
         except Exception as e:
-            logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
+            logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
         finally:
-            queue.task_done()
+            await work_queue.mark_done(work_item)
+            semaphore.release()
 
 
 async def sglang_server_task(args, semaphore):
@@ -563,6 +412,7 @@ async def sglang_server_task(args, semaphore):
 
         if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
             server_printed_ready_message = True
+            last_semaphore_release = time.time()
 
         match = re.search(r'#running-req: (\d+)', line)
         if match:
@@ -631,10 +481,10 @@ async def sglang_server_ready():
     raise Exception("sglang server did not become ready after waiting.")
 
 
-async def metrics_reporter(queue):
+async def metrics_reporter(work_queue):
     while True:
         # Leading newlines preserve table formatting in logs
-        logger.info(f"Queue remaining: {queue.qsize()}")
+        logger.info(f"Queue remaining: {work_queue.size}")
         logger.info("\n" + str(metrics))
         logger.info("\n" + str(await tracker.get_status_table()))
         await asyncio.sleep(10)
@@ -716,13 +566,14 @@ def submit_beaker_job(args):
 
     print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
 
+
 def print_stats(args):
     import concurrent.futures
     from tqdm import tqdm
 
     # Get total work items and completed items
-    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-    output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
+    index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
+    output_glob = os.path.join(args.workspace, "results", "*.jsonl")
 
     work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
     done_work_items = expand_s3_glob(workspace_s3, output_glob)
@@ -825,9 +676,54 @@ async def main():
 
     check_poppler_version()
 
+    # Create work queue
+    work_queue = S3WorkQueue(workspace_s3, args.workspace)
+
     if args.pdfs:
         logger.info("Got --pdfs argument, going to add to the work queue")
-        await populate_pdf_work_queue(args)
+
+        # Expand s3 paths
+        if args.pdfs.startswith("s3://"):
+            logger.info(f"Expanding s3 glob at {args.pdfs}")
+            s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
+        elif os.path.exists(args.pdfs):
+            logger.info(f"Loading file at {args.pdfs}")
+            with open(args.pdfs, "r") as f:
+                s3_work_paths = list(filter(None, (line.strip() for line in f)))
+        else:
+            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
+
+        s3_work_paths = set(s3_work_paths)
+        logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")
+
+        # Estimate average pages per pdf
+        sample_size = min(100, len(s3_work_paths))
+        sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
+        page_counts = []
+
+        for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
+            try:
+                # Download the PDF to a temp file
+                with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
+                    s3_bucket, s3_key = parse_s3_path(pdf)
+                    pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
+                    tmp_file.flush()
+                    reader = PdfReader(tmp_file.name)
+                    page_counts.append(len(reader.pages))
+            except Exception as e:
+                logger.warning(f"Failed to read {pdf}: {e}")
+
+        if page_counts:
+            avg_pages_per_pdf = sum(page_counts) / len(page_counts)
+        else:
+            logger.warning("Could not read any PDFs to estimate average page count.")
+            avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
+
+        items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
+        logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
+
+        # Now call populate_queue
+        await work_queue.populate_queue(s3_work_paths, items_per_group)
 
     if args.stats:
         print_stats(args)
@@ -839,6 +735,9 @@ async def main():
 
     logger.info(f"Starting pipeline with PID {os.getpid()}")
 
+    # Initialize the work queue
+    await work_queue.initialize_queue()
+
     # Create a semaphore to control worker access
     # We only allow one worker to move forward with requests, until the server has no more requests in its queue
     # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
@@ -847,9 +746,6 @@ async def main():
 
     sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
-
-    work_queue = await load_pdf_work_queue(args)
-    logger.info(f"Work queue prepared with {work_queue.qsize()} items")
 
     await sglang_server_ready()
 
     metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@@ -860,16 +756,9 @@ async def main():
         task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
         worker_tasks.append(task)
 
-    # Wait for the queue to be fully processed
-    await work_queue.join()
-
-    # Cancel our worker tasks.
-    for task in worker_tasks:
-        task.cancel()
-
-    # Wait until all worker tasks are cancelled.
-    await asyncio.gather(*worker_tasks, return_exceptions=True)
+    # Wait for all worker tasks to finish
+    await asyncio.gather(*worker_tasks)
 
     # Wait for server to stop
     process_pool.shutdown(wait=False)
 
@@ -877,11 +766,12 @@ async def main():
     metrics_task.cancel()
     logger.info("Work done")
 
+
 if __name__ == "__main__":
     asyncio.run(main())
 
 # TODO
-# - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
+# X Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
 # X Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things
 # - Add logging of failed pages and have the stats function read them
 # X Add the page rotation check and mechanism
@@ -37,12 +37,6 @@ class TestS3WorkQueue(unittest.TestCase):
         hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
         self.assertEqual(hash1, hash2)
 
-        # Verify hash is actually SHA1
-        sha1 = hashlib.sha1()
-        for path in sorted(paths):
-            sha1.update(path.encode('utf-8'))
-        self.assertEqual(hash1, sha1.hexdigest())
-
     def test_init(self):
         """Test initialization of S3WorkQueue"""
         client = Mock()
@@ -51,7 +45,6 @@ class TestS3WorkQueue(unittest.TestCase):
         self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
         self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
         self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
-        self.assertIsInstance(queue._queue, asyncio.Queue)
 
     def asyncSetUp(self):
         """Set up async test fixtures"""
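The diff shows only the call sites of the new queue: main() builds an S3WorkQueue over the workspace, seeds it with populate_queue(s3_work_paths, items_per_group), calls initialize_queue() once, and each worker() loops on get_work()/mark_done() until get_work() returns None. The real implementation lives in pdelfin.s3_queue and is backed by the work_index_list.csv.zstd index and results/*.jsonl outputs in S3, none of which appears in this diff. The sketch below is only a minimal in-memory stand-in of that surface: the class name InMemoryWorkQueue, the demo paths, and all internals are illustrative assumptions, while the method names, the WorkItem fields, and the sorted-SHA1 workgroup hash follow the diff and TestS3WorkQueue.

import asyncio
import hashlib
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional


@dataclass(frozen=True)
class WorkItem:
    """One group of PDFs, addressed by the SHA1 of its sorted paths."""
    hash: str
    s3_work_paths: List[str]


class InMemoryWorkQueue:
    """Hypothetical stand-in exposing the surface worker() and metrics_reporter() rely on."""

    def __init__(self) -> None:
        self._groups: Dict[str, List[str]] = {}
        self._queue: asyncio.Queue = asyncio.Queue()

    @staticmethod
    def _compute_workgroup_hash(s3_work_paths: Iterable[str]) -> str:
        # Same ordering rule the unit test checks: sort before hashing.
        sha1 = hashlib.sha1()
        for path in sorted(s3_work_paths):
            sha1.update(path.encode("utf-8"))
        return sha1.hexdigest()

    async def populate_queue(self, s3_work_paths: Iterable[str], items_per_group: int) -> None:
        # Group the paths and remember each group under its hash.
        paths = sorted(set(s3_work_paths))
        for i in range(0, len(paths), items_per_group):
            group = paths[i:i + items_per_group]
            self._groups[self._compute_workgroup_hash(group)] = group

    async def initialize_queue(self) -> None:
        # The real class presumably reconciles the S3 index against results/*.jsonl here;
        # this stand-in simply enqueues every known group.
        for group_hash, group in self._groups.items():
            self._queue.put_nowait(WorkItem(hash=group_hash, s3_work_paths=group))

    async def get_work(self) -> Optional[WorkItem]:
        # Non-blocking: None tells the worker to release the semaphore and exit.
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

    async def mark_done(self, work_item: WorkItem) -> None:
        self._queue.task_done()

    @property
    def size(self) -> int:
        return self._queue.qsize()


async def demo() -> None:
    queue = InMemoryWorkQueue()
    await queue.populate_queue((f"s3://bucket/pdfs/doc_{i}.pdf" for i in range(7)), items_per_group=3)
    await queue.initialize_queue()
    while (item := await queue.get_work()) is not None:
        print(item.hash[:8], len(item.s3_work_paths), "pdfs")
        await queue.mark_done(item)


if __name__ == "__main__":
    asyncio.run(demo())

Having get_work() return None on exhaustion is what lets the reworked worker() release the semaphore and break cleanly, replacing the old work_queue.join() plus task.cancel() shutdown in main().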