Starting up server and workers async now

Jake Poznanski 2024-11-08 09:14:00 -08:00
parent a39350e074
commit ee72b3601e
2 changed files with 133 additions and 46 deletions

View File

@@ -74,17 +74,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()

-async def start_sglang_server(args):
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-        "python3", "-m", "sglang.launch_server",
-        "--model-path", model_cache_dir,
-        "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
-    ])
-
async def populate_pdf_work_queue(args):
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
@@ -152,44 +141,73 @@ async def populate_pdf_work_queue(args):

async def load_pdf_work_queue(args) -> asyncio.Queue:
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+    output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"

-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
-
-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
-
-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+    # Define the two blocking I/O operations
+    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
+    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
+
+    # Run both tasks concurrently
+    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
+
+    # Process the work queue lines
+    work_queue = {
+        parts[0]: parts[1:]
+        for line in work_queue_lines
+        if (parts := line.strip().split(",")) and line.strip()
+    }
+
+    # Extract done work hashes
+    done_work_hashes = {
+        os.path.basename(item)[len('output_'):-len('.jsonl')]
+        for item in done_work_items
+        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
+    }
+
+    # Determine remaining work
+    remaining_work_hashes = set(work_queue) - done_work_hashes
+    remaining_work_queue = {
+        hash_: work_queue[hash_]
+        for hash_ in remaining_work_hashes
+    }
+
+    # Populate the asyncio.Queue with remaining work
    queue = asyncio.Queue()
-    for work in remaining_work_queue:
-        await queue.put((work, remaining_work_queue[work]))
+    for work, pdfs in remaining_work_queue.items():
+        await queue.put((work, pdfs))

    return queue

+async def process_pdf(args, pdf_s3_path):
+    await asyncio.sleep(1)
+    return f"pdf: {pdf_s3_path}"
+
async def worker(args, queue):
    while True:
-        work = await queue.get()
+        [work_hash, pdfs] = await queue.get()

-        logger.info(f"Got work to do for {work}")
+        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        logger.info(f"Completed {completed_pdfs}")

        queue.task_done()

+async def sglang_server_task(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    #download_directory(args.model, model_cache_dir)
+
+    proc = await asyncio.create_subprocess_exec(
+        "python3",
+        "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    )
+
+    await proc.wait()
+
async def main():
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
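The rewritten load_pdf_work_queue above pushes its two blocking S3 calls onto threads and awaits them together. Below is a minimal sketch of that pattern (not code from this repo); fetch_index and list_outputs are made-up, sleep-based stand-ins for download_zstd_csv and expand_s3_glob:

import asyncio
import time

def fetch_index() -> list[str]:
    time.sleep(1)  # stand-in for the blocking zstd CSV download
    return ["hash1,a.pdf,b.pdf"]

def list_outputs() -> list[str]:
    time.sleep(1)  # stand-in for the blocking S3 glob expansion
    return ["output_hash0.jsonl"]

async def load() -> None:
    # asyncio.to_thread runs each blocking call on the default thread pool,
    # so the two operations overlap instead of running back to back.
    index_lines, outputs = await asyncio.gather(
        asyncio.to_thread(fetch_index),
        asyncio.to_thread(list_outputs),
    )
    print(index_lines, outputs)

asyncio.run(load())

With both calls off-loaded, the sketch finishes in roughly one second instead of two, and the event loop stays free for other tasks in the meantime.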
@@ -222,25 +240,31 @@ async def main():

    if args.pdfs:
        await populate_pdf_work_queue(args)

+    sglang_server = asyncio.create_task(sglang_server_task(args))
+
    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")

    # Create worker tasks to process the queue concurrently.
-    tasks = []
+    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue))
-        tasks.append(task)
+        worker_tasks.append(task)

    # Wait for the queue to be fully processed
    await work_queue.join()

    # Cancel our worker tasks.
-    for task in tasks:
+    for task in worker_tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
-    await asyncio.gather(*tasks, return_exceptions=True)
+    await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+    # Wait for server to stop
+    sglang_server.cancel()
+    await sglang_server

if __name__ == "__main__":
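main() now follows a common asyncio orchestration shape: the sglang server runs as a background task wrapping a subprocess, worker tasks drain an asyncio.Queue, queue.join() waits for all queued items, and workers and server are cancelled on the way out. Below is a self-contained sketch of that shape (not code from this repo); the "sleep" command is a placeholder for the real sglang.launch_server invocation:

import asyncio

async def server_task() -> None:
    # Placeholder long-running child process standing in for the sglang server.
    proc = await asyncio.create_subprocess_exec("sleep", "3600")
    try:
        await proc.wait()
    except asyncio.CancelledError:
        proc.terminate()      # shut the child down when the task is cancelled
        await proc.wait()
        raise

async def worker(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        await asyncio.sleep(0.1)  # pretend per-item work
        print(f"done: {item}")
        queue.task_done()

async def main() -> None:
    server = asyncio.create_task(server_task())

    queue: asyncio.Queue = asyncio.Queue()
    for i in range(10):
        queue.put_nowait(i)

    workers = [asyncio.create_task(worker(queue)) for _ in range(3)]

    await queue.join()  # every queued item has had task_done() called
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)

    server.cancel()
    await asyncio.gather(server, return_exceptions=True)

asyncio.run(main())

One difference from the diff above: the sketch awaits the cancelled server through gather(..., return_exceptions=True), so the resulting CancelledError is collected instead of propagating out of main().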

View File

@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO)

def parse_s3_path(s3_path: str) -> tuple[str, str]:
-    if not s3_path.startswith('s3://'):
-        raise ValueError('s3_path must start with s3://')
+    if not (s3_path.startswith('s3://') or s3_path.startswith('gs://') or s3_path.startswith('weka://')):
+        raise ValueError('s3_path must start with s3://, gs://, or weka://')
    parsed = urlparse(s3_path)
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')
@@ -137,10 +137,10 @@ def download_directory(model_choices: list[str], local_dir: str):
    """
    Download the model to a specified local directory.
    The function will attempt to download from the first available source in the provided list.
-    Supports Google Cloud Storage (gs://) and Amazon S3 (s3://) links.
+    Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.

    Args:
-        model_choices (list[str]): List of model paths (gs:// or s3://).
+        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
        local_dir (str): Local directory path where the model will be downloaded.

    Raises:
@@ -151,11 +151,20 @@ def download_directory(model_choices: list[str], local_dir: str):
    local_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Local directory set to: {local_path}")

+    # Reorder model_choices to prioritize weka:// links
+    weka_choices = [path for path in model_choices if path.startswith("weka://")]
+    other_choices = [path for path in model_choices if not path.startswith("weka://")]
+    prioritized_choices = weka_choices + other_choices
+
    # Iterate through the provided choices and attempt to download from the first available source
-    for model_path in model_choices:
+    for model_path in prioritized_choices:
        logger.info(f"Attempting to download from: {model_path}")
        try:
-            if model_path.startswith("gs://"):
+            if model_path.startswith("weka://"):
+                download_dir_from_weka(model_path, str(local_path))
+                logger.info(f"Successfully downloaded model from Weka: {model_path}")
+                return
+            elif model_path.startswith("gs://"):
                download_dir_from_gcs(model_path, str(local_path))
                logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                return
@@ -229,3 +238,57 @@ def download_dir_from_s3(s3_path: str, local_dir: str):
            pass

    logger.info(f"Downloaded model from S3 to {local_dir}")
+
+def download_dir_from_weka(weka_path: str, local_dir: str):
+    """Download model files from Weka to a local directory."""
+    # Retrieve Weka credentials from environment variables
+    weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
+    weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
+    if not weka_access_key or not weka_secret_key:
+        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
+
+    # Configure the boto3 client for Weka
+    weka_endpoint = "https://weka-aus.beaker.org:9000"
+    boto3_config = Config(
+        max_pool_connections=50,  # Adjust this number based on your requirements
+        signature_version='s3v4',
+        retries={'max_attempts': 10, 'mode': 'standard'}
+    )
+    s3_client = boto3.client(
+        's3',
+        endpoint_url=weka_endpoint,
+        aws_access_key_id=weka_access_key,
+        aws_secret_access_key=weka_secret_key,
+        config=boto3_config
+    )
+
+    bucket, prefix = parse_s3_path(weka_path)
+    paginator = s3_client.get_paginator("list_objects_v2")
+    try:
+        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    except s3_client.exceptions.NoSuchBucket:
+        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
+
+    objects = []
+    for page in pages:
+        if 'Contents' in page:
+            objects.extend(page['Contents'])
+
+    total_files = len(objects)
+    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = []
+        for obj in objects:
+            key = obj["Key"]
+            relative_path = os.path.relpath(key, prefix)
+            local_file_path = os.path.join(local_dir, relative_path)
+            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+
+        # Use tqdm to display progress
+        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
+            pass
+
+    logger.info(f"Downloaded model from Weka to {local_dir}")