Better work queue

Jake Poznanski 2024-11-18 11:04:51 -08:00
parent 04429b2862
commit e499413089
2 changed files with 82 additions and 199 deletions

View File

@ -24,9 +24,10 @@ from PIL import Image
from pypdf import PdfReader
from functools import partial
from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, List, Dict, Set
from concurrent.futures import ProcessPoolExecutor
+from pdelfin.s3_queue import S3WorkQueue, WorkItem
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
@ -123,160 +124,6 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
}
def compute_workgroup_sha1(work_group: list[str]) -> str:
sha1 = hashlib.sha1()
# Ensure consistent ordering by sorting the list
for pdf in sorted(work_group):
sha1.update(pdf.encode('utf-8'))
return sha1.hexdigest()
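
Side note: the sort before hashing is what makes the group ID order-independent, so the same set of PDFs always resolves to the same work group. A standalone sketch of that property (the helper name below is illustrative):

# Sketch: the same set of paths yields the same group ID regardless of
# the order in which the paths arrive.
import hashlib

def group_sha1(paths):
    sha1 = hashlib.sha1()
    for p in sorted(paths):  # sorting pins down the iteration order
        sha1.update(p.encode("utf-8"))
    return sha1.hexdigest()

assert group_sha1(["s3://b/z.pdf", "s3://b/a.pdf"]) == group_sha1(["s3://b/a.pdf", "s3://b/z.pdf"])

The test file below checks exactly this invariant against S3WorkQueue._compute_workgroup_hash.
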
async def populate_pdf_work_queue(args):
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
all_pdfs = set(all_pdfs)
logger.info(f"Found {len(all_pdfs):,} total pdf paths")
existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
# Parse existing work items into groups
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
existing_groups[group_hash] = group_pdfs
existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
# Remove existing PDFs from all_pdfs
new_pdfs = all_pdfs - existing_pdf_set
logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
sample_size = min(100, len(new_pdfs))
sampled_pdfs = random.sample(list(new_pdfs), sample_size)
page_counts = []
for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
try:
# Download the PDF to a temp file
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
s3_bucket, s3_key = parse_s3_path(pdf)
pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")
if page_counts:
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
else:
logger.warning("Could not read any PDFs to estimate average page count.")
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
group_size = max(1, int(args.pages_per_group / avg_pages_per_pdf))
logger.info(f"Calculated group_size: {group_size} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
new_groups = []
current_group = []
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == group_size:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine existing groups with new groups
combined_groups = existing_groups.copy()
for group_hash, group_pdfs in new_groups:
combined_groups[group_hash] = group_pdfs
# Prepare lines to write back
combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
# Upload the combined work items back to S3
if new_groups:
upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
logger.info("Completed adding new PDFs.")
async def load_pdf_work_queue(args) -> asyncio.Queue:
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
# Define the two blocking I/O operations
download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
# Run both tasks concurrently
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
# Process the work queue lines
work_queue = {
parts[0]: parts[1:]
for line in work_queue_lines
if (parts := line.strip().split(",")) and line.strip()
}
# Extract done work hashes
done_work_hashes = {
os.path.basename(item)[len('output_'):-len('.jsonl')]
for item in done_work_items
if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
}
# Determine remaining work
remaining_work_hashes = set(work_queue) - done_work_hashes
#remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
remaining_work_queue = {
hash_: work_queue[hash_]
for hash_ in remaining_work_hashes
}
# Populate the asyncio.Queue with remaining work
queue = asyncio.Queue()
shuffled_items = list(remaining_work_queue.items())
random.shuffle(shuffled_items)
for work, pdfs in shuffled_items:
await queue.put((work, pdfs))
return queue
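
Note that completion is tracked purely by filename convention: a finished group leaves output_<hash>.jsonl behind, so the hash comes back out of a result path with a string slice. In isolation (a sketch with a hypothetical helper name):

# Sketch: recover the work-group hash from a result filename.
import os

def hash_from_output_path(path):
    name = os.path.basename(path)
    if name.startswith("output_") and name.endswith(".jsonl"):
        return name[len("output_"):-len(".jsonl")]
    return None  # not a result file

# e.g. hash_from_output_path("dolma_documents/output_abc123.jsonl") -> "abc123"
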
async def work_item_completed(args, work_hash: str) -> bool:
# Check if work item has already been completed
output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
bucket, key = parse_s3_path(output_s3_path)
try:
# Check if the output file already exists
await asyncio.to_thread(workspace_s3.head_object, Bucket=bucket, Key=key)
return True
except workspace_s3.exceptions.ClientError as e:
pass
return False
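
One subtlety here: head_object raises ClientError both when the key is missing and when access fails, so swallowing every ClientError quietly treats a permission problem as "not completed". A stricter variant would check the status code (a sketch, not what the code above does):

# Sketch: only a 404 means "output does not exist"; anything else is a
# real failure and should propagate instead of being masked.
import botocore.exceptions

def s3_object_exists(s3_client, bucket, key):
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
        return True
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == "404":
            return False
        raise
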
async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = 3
@ -442,31 +289,32 @@ def build_dolma_document(pdf_s3_path, page_results):
}
return dolma_doc
-async def worker(args, queue, semaphore, worker_id):
+async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
while True:
-[work_hash, pdfs] = await queue.get()
-# Wait until allowed to proceed
-await semaphore.acquire()
-try:
-await tracker.clear_work(worker_id)
+work_item = await work_queue.get_work()
+# Wait until allowed to proceed
+await semaphore.acquire()
+if work_item is None:
+logger.info(f"Worker {worker_id} exiting due to empty queue")
+semaphore.release()
+break
-if await work_item_completed(args, work_hash):
-logger.info(f"Work {work_hash} was already completed, skipping")
-continue
-else:
-logger.info(f"Proceeding with {work_hash} on worker {worker_id}")
+logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
+await tracker.clear_work(worker_id)
+try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
connector=aiohttp.TCPConnector(limit=1000)) as session:
async with asyncio.TaskGroup() as tg:
-dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs]
-logger.info(f"Created all tasks for {work_hash}")
+dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
+logger.info(f"Created all tasks for {work_item.hash}")
-logger.info(f"Finished TaskGroup for worker on {work_hash}")
+logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
-logger.info(f"Closed ClientSession for {work_hash}")
+logger.info(f"Closed ClientSession for {work_item.hash}")
dolma_docs = []
for task in dolma_tasks:
@ -479,7 +327,7 @@ async def worker(args, queue, semaphore, worker_id):
if result is not None:
dolma_docs.append(result)
logger.info(f"Got {len(dolma_docs)} docs for {work_hash}")
logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")
# Write the Dolma documents to a local temporary file in JSONL format
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
@ -489,7 +337,7 @@ async def worker(args, queue, semaphore, worker_id):
tf.flush()
# Define the output S3 path using the work_hash
-output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
+output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl')
bucket, key = parse_s3_path(output_s3_path)
workspace_s3.upload_file(tf.name, bucket, key)
@ -501,9 +349,10 @@ async def worker(args, queue, semaphore, worker_id):
# Update last batch time
last_batch_time = time.perf_counter()
except Exception as e:
-logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
+logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
finally:
-queue.task_done()
+await work_queue.mark_done(work_item)
semaphore.release()
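
The reworked worker leans on a small queue interface throughout. S3WorkQueue itself lives in pdelfin/s3_queue.py and is not shown in this diff; reconstructed from its call sites in this file, its shape is roughly the following (an inferred sketch, not the actual implementation):

# Assumed interface of S3WorkQueue, inferred from how this file uses it.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class WorkItem:
    hash: str                 # order-independent sha1 over the group's paths
    s3_work_paths: List[str]  # the PDFs belonging to this group

class S3WorkQueue:
    def __init__(self, s3_client, workspace_path: str): ...
    async def populate_queue(self, s3_work_paths, items_per_group: int) -> None: ...
    async def initialize_queue(self) -> None: ...        # index minus completed results
    async def get_work(self) -> Optional[WorkItem]: ...  # None once the queue is drained
    async def mark_done(self, work_item: WorkItem) -> None: ...
    @property
    def size(self) -> int: ...
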
async def sglang_server_task(args, semaphore):
@ -563,6 +412,7 @@ async def sglang_server_task(args, semaphore):
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
server_printed_ready_message = True
last_semaphore_release = time.time()
match = re.search(r'#running-req: (\d+)', line)
if match:
@ -631,10 +481,10 @@ async def sglang_server_ready():
raise Exception("sglang server did not become ready after waiting.")
-async def metrics_reporter(queue):
+async def metrics_reporter(work_queue):
while True:
# Leading newlines preserve table formatting in logs
logger.info(f"Queue remaining: {queue.qsize()}")
logger.info(f"Queue remaining: {work_queue.size}")
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10)
@ -716,13 +566,14 @@ def submit_beaker_job(args):
print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
def print_stats(args):
import concurrent.futures
from tqdm import tqdm
# Get total work items and completed items
-index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
+index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
+output_glob = os.path.join(args.workspace, "results", "*.jsonl")
work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
done_work_items = expand_s3_glob(workspace_s3, output_glob)
@ -825,9 +676,54 @@ async def main():
check_poppler_version()
+# Create work queue
+work_queue = S3WorkQueue(workspace_s3, args.workspace)
if args.pdfs:
logger.info("Got --pdfs argument, going to add to the work queue")
-await populate_pdf_work_queue(args)
# Expand s3 paths
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
s3_work_paths = list(filter(None, (line.strip() for line in f)))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
s3_work_paths = set(s3_work_paths)
logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")
# Estimate average pages per pdf
sample_size = min(100, len(s3_work_paths))
sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
page_counts = []
for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
try:
# Download the PDF to a temp file
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
s3_bucket, s3_key = parse_s3_path(pdf)
pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")
if page_counts:
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
else:
logger.warning("Could not read any PDFs to estimate average page count.")
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
# Now call populate_queue
await work_queue.populate_queue(s3_work_paths, items_per_group)
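
For a sense of the arithmetic (the numbers below are hypothetical): with --pages_per_group at 500 and a sampled average of 18.7 pages per PDF, each work group ends up holding 26 PDFs.

# Worked example with hypothetical inputs:
#   pages_per_group = 500, sampled average = 18.7 pages per PDF
items_per_group = max(1, int(500 / 18.7))  # -> 26 PDFs per work group
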
if args.stats:
print_stats(args)
@ -839,6 +735,9 @@ async def main():
logger.info(f"Starting pipeline with PID {os.getpid()}")
# Initialize the work queue
await work_queue.initialize_queue()
# Create a semaphore to control worker access
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
@ -847,9 +746,6 @@ async def main():
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
-work_queue = await load_pdf_work_queue(args)
-logger.info(f"Work queue prepared with {work_queue.qsize()} items")
await sglang_server_ready()
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@ -860,16 +756,9 @@ async def main():
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
worker_tasks.append(task)
-# Wait for the queue to be fully processed
-await work_queue.join()
+# Wait for all worker tasks to finish
+await asyncio.gather(*worker_tasks)
-# Cancel our worker tasks.
-for task in worker_tasks:
-task.cancel()
-# Wait until all worker tasks are cancelled.
-await asyncio.gather(*worker_tasks, return_exceptions=True)
# Wait for server to stop
process_pool.shutdown(wait=False)
@ -877,11 +766,12 @@ async def main():
metrics_task.cancel()
logger.info("Work done")
if __name__ == "__main__":
asyncio.run(main())
# TODO
-# - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
+# X Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
+# X Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things
# - Add logging of failed pages and have the stats function read them
# X Add the page rotation check and mechanism

View File

@ -37,12 +37,6 @@ class TestS3WorkQueue(unittest.TestCase):
hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
self.assertEqual(hash1, hash2)
-# Verify hash is actually SHA1
-sha1 = hashlib.sha1()
-for path in sorted(paths):
-sha1.update(path.encode('utf-8'))
-self.assertEqual(hash1, sha1.hexdigest())
def test_init(self):
"""Test initialization of S3WorkQueue"""
client = Mock()
@ -51,7 +45,6 @@ class TestS3WorkQueue(unittest.TestCase):
self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
-self.assertIsInstance(queue._queue, asyncio.Queue)
def asyncSetUp(self):
"""Set up async test fixtures"""