Mirror of https://github.com/allenai/olmocr.git
Starting up server and workers async now

parent a39350e074
commit ee72b3601e
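The change below replaces the blocking server startup with an event-loop-driven layout: the sglang server runs as one asyncio task, a pool of workers drains an asyncio.Queue, and main() joins the queue before cancelling everything. As orientation, here is a minimal, self-contained sketch of that shape; the names in it are illustrative, not taken from the repo (a POSIX `sleep` stands in for the sglang child process).

import asyncio

async def server_task():
    # Stand-in for launching sglang via asyncio.create_subprocess_exec
    proc = await asyncio.create_subprocess_exec("sleep", "3600")
    try:
        await proc.wait()
    except asyncio.CancelledError:
        proc.terminate()          # shut the child down when the task is cancelled
        raise

async def worker(name, queue):
    while True:
        item = await queue.get()
        await asyncio.sleep(0.1)  # stand-in for per-PDF inference work
        print(f"{name} finished {item}")
        queue.task_done()

async def main():
    server = asyncio.create_task(server_task())
    queue = asyncio.Queue()
    for i in range(10):
        await queue.put(f"group-{i}")
    workers = [asyncio.create_task(worker(f"w{j}", queue)) for j in range(3)]
    await queue.join()            # wait until every queued item is processed
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)
    server.cancel()
    await asyncio.gather(server, return_exceptions=True)

if __name__ == "__main__":
    asyncio.run(main())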
					
@@ -74,17 +74,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))

     return sha1.hexdigest()

-async def start_sglang_server(args):
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-            "python3", "-m", "sglang.launch_server",
-            "--model-path", model_cache_dir,
-            "--chat-template", args.model_chat_template,
-            "--context-length", str(args.model_max_context),
-        ])

 async def populate_pdf_work_queue(args):
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
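The hunk above deletes the old Popen-based starter; its replacement, added further down as sglang_server_task, uses asyncio.create_subprocess_exec, which returns a process whose exit can be awaited without blocking the event loop. A tiny illustration of that difference (not repo code, POSIX `echo` used as the child):

import asyncio

async def run_child():
    # Unlike subprocess.Popen(...).wait(), this await yields control to the
    # event loop while the child process runs.
    proc = await asyncio.create_subprocess_exec("echo", "hello")
    return await proc.wait()

print(asyncio.run(run_child()))   # prints "hello" from the child, then 0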
@@ -152,44 +141,73 @@ async def populate_pdf_work_queue(args):

 async def load_pdf_work_queue(args) -> asyncio.Queue:
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-
-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
+    output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"

-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
+    # Define the two blocking I/O operations
+    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
+    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)

-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+    # Run both tasks concurrently
+    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
+
+    # Process the work queue lines
+    work_queue = {
+        parts[0]: parts[1:]
+        for line in work_queue_lines
+        if (parts := line.strip().split(",")) and line.strip()
+    }
+
+    # Extract done work hashes
+    done_work_hashes = {
+        os.path.basename(item)[len('output_'):-len('.jsonl')]
+        for item in done_work_items
+        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
+    }
+
+    # Determine remaining work
+    remaining_work_hashes = set(work_queue) - done_work_hashes
+    remaining_work_queue = {
+        hash_: work_queue[hash_]
+        for hash_ in remaining_work_hashes
+    }
+
+    # Populate the asyncio.Queue with remaining work
     queue = asyncio.Queue()
-    for work in remaining_work_queue:
-        await queue.put((work, remaining_work_queue[work]))
+    for work, pdfs in remaining_work_queue.items():
+        await queue.put((work, pdfs))

     return queue

+
+async def process_pdf(args, pdf_s3_path):
+    await asyncio.sleep(1)
+    return f"pdf: {pdf_s3_path}"
+

 async def worker(args, queue):
     while True:
-        work = await queue.get()
+        [work_hash, pdfs] = await queue.get()
+
+        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        logger.info(f"Completed {completed_pdfs}")

-        logger.info(f"Got work to do for {work}")
         queue.task_done()

+
+async def sglang_server_task(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    #download_directory(args.model, model_cache_dir)
+
+    proc = await asyncio.create_subprocess_exec(
+        "python3",
+        "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+        )
+
+    await proc.wait()
+
+
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
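The rewritten load_pdf_work_queue above pushes its two blocking S3 calls onto worker threads with asyncio.to_thread and awaits them together via asyncio.gather, so the index download and the output-glob listing overlap instead of running back to back. A self-contained sketch of that trick, using placeholder functions rather than the repo's S3 helpers:

import asyncio
import time

def fetch_index():
    time.sleep(1)     # placeholder for download_zstd_csv
    return ["hash1,a.pdf,b.pdf", "hash2,c.pdf"]

def list_done_outputs():
    time.sleep(1)     # placeholder for expand_s3_glob
    return ["s3://bucket/dolma_documents/output_hash1.jsonl"]

async def main():
    start = time.perf_counter()
    index_lines, done_items = await asyncio.gather(
        asyncio.to_thread(fetch_index),
        asyncio.to_thread(list_done_outputs),
    )
    # Both calls ran in threads concurrently: roughly 1s total instead of 2s.
    print(len(index_lines), len(done_items), f"{time.perf_counter() - start:.1f}s")

asyncio.run(main())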
@@ -222,25 +240,31 @@ async def main():
     if args.pdfs:
         await populate_pdf_work_queue(args)

+    sglang_server = asyncio.create_task(sglang_server_task(args))
+
     work_queue = await load_pdf_work_queue(args)
     logger.info(f"Work queue prepared with {work_queue.qsize()} items")


     # Create worker tasks to process the queue concurrently.
-    tasks = []
+    worker_tasks = []
     for i in range(args.workers):
         task = asyncio.create_task(worker(args, work_queue))
-        tasks.append(task)
+        worker_tasks.append(task)

     # Wait for the queue to be fully processed
     await work_queue.join()

     # Cancel our worker tasks.
-    for task in tasks:
+    for task in worker_tasks:
         task.cancel()

     # Wait until all worker tasks are cancelled.
-    await asyncio.gather(*tasks, return_exceptions=True)
+    await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+    # Wait for server to stop
+    sglang_server.cancel()
+    await sglang_server


 if __name__ == "__main__":
(The remaining hunks are from a second file in this commit, the S3 path and model-download helpers; its path is not preserved in the mirrored page.)

@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO)


 def parse_s3_path(s3_path: str) -> tuple[str, str]:
-    if not s3_path.startswith('s3://'):
-        raise ValueError('s3_path must start with s3://')
+    if not (s3_path.startswith('s3://') or s3_path.startswith('gs://') or s3_path.startswith('weka://')):
+        raise ValueError('s3_path must start with s3://, gs://, or weka://')
     parsed = urlparse(s3_path)
     bucket = parsed.netloc
     key = parsed.path.lstrip('/')
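For context on the parse_s3_path change above: urllib.parse.urlparse splits gs:// and weka:// URLs the same way it splits s3:// ones, so only the prefix check and the error message need to change. A quick check:

from urllib.parse import urlparse

for url in ("s3://bucket/models/m1", "gs://bucket/models/m1", "weka://bucket/models/m1"):
    parsed = urlparse(url)
    print(parsed.scheme, parsed.netloc, parsed.path.lstrip('/'))
# s3 bucket models/m1
# gs bucket models/m1
# weka bucket models/m1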
@@ -137,10 +137,10 @@ def download_directory(model_choices: list[str], local_dir: str):
     """
     Download the model to a specified local directory.
     The function will attempt to download from the first available source in the provided list.
-    Supports Google Cloud Storage (gs://) and Amazon S3 (s3://) links.
+    Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.

     Args:
-        model_choices (list[str]): List of model paths (gs:// or s3://).
+        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
         local_dir (str): Local directory path where the model will be downloaded.

     Raises:
@@ -151,11 +151,20 @@ def download_directory(model_choices: list[str], local_dir: str):
     local_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"Local directory set to: {local_path}")

+    # Reorder model_choices to prioritize weka:// links
+    weka_choices = [path for path in model_choices if path.startswith("weka://")]
+    other_choices = [path for path in model_choices if not path.startswith("weka://")]
+    prioritized_choices = weka_choices + other_choices
+
     # Iterate through the provided choices and attempt to download from the first available source
-    for model_path in model_choices:
+    for model_path in prioritized_choices:
         logger.info(f"Attempting to download from: {model_path}")
         try:
-            if model_path.startswith("gs://"):
+            if model_path.startswith("weka://"):
+                download_dir_from_weka(model_path, str(local_path))
+                logger.info(f"Successfully downloaded model from Weka: {model_path}")
+                return
+            elif model_path.startswith("gs://"):
                 download_dir_from_gcs(model_path, str(local_path))
                 logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                 return
@@ -229,3 +238,57 @@ def download_dir_from_s3(s3_path: str, local_dir: str):
             pass

     logger.info(f"Downloaded model from S3 to {local_dir}")
+
+
+def download_dir_from_weka(weka_path: str, local_dir: str):
+    """Download model files from Weka to a local directory."""
+    # Retrieve Weka credentials from environment variables
+    weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
+    weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
+    if not weka_access_key or not weka_secret_key:
+        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
+
+    # Configure the boto3 client for Weka
+    weka_endpoint = "https://weka-aus.beaker.org:9000"
+    boto3_config = Config(
+        max_pool_connections=50,  # Adjust this number based on your requirements
+        signature_version='s3v4',
+        retries={'max_attempts': 10, 'mode': 'standard'}
+    )
+    s3_client = boto3.client(
+        's3',
+        endpoint_url=weka_endpoint,
+        aws_access_key_id=weka_access_key,
+        aws_secret_access_key=weka_secret_key,
+        config=boto3_config
+    )
+
+    bucket, prefix = parse_s3_path(weka_path)
+    paginator = s3_client.get_paginator("list_objects_v2")
+    try:
+        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    except s3_client.exceptions.NoSuchBucket:
+        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
+
+    objects = []
+    for page in pages:
+        if 'Contents' in page:
+            objects.extend(page['Contents'])
+
+    total_files = len(objects)
+    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = []
+        for obj in objects:
+            key = obj["Key"]
+            relative_path = os.path.relpath(key, prefix)
+            local_file_path = os.path.join(local_dir, relative_path)
+            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+
+        # Use tqdm to display progress
+        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
+            pass
+
+    logger.info(f"Downloaded model from Weka to {local_dir}")