#!/usr/bin/env python3
"""
Tagging pipeline for Dolma JSONL datasets.

For each .jsonl, .jsonl.gz, or zstd-compressed .jsonl file under the dataset's documents/
folder, this script issues a simple SGLang chat completion per record (currently a PII
classification prompt), collects the boolean answers, and writes corresponding Dolma
attributes JSONL files under scratch/attributes/, mirroring the input structure.
"""
import argparse
import asyncio
import atexit
import gzip
import json
import logging
import multiprocessing
import os
import random
import re
import shutil
import sys
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from io import BytesIO
from urllib.parse import urlparse

import zstandard as zstd

import boto3
import httpx
import torch
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader
from tqdm import tqdm

from olmocr.check import (
    check_poppler_version,
    check_sglang_version,
    check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import (
    download_directory,
    download_zstd_csv,
    expand_s3_glob,
    get_s3_bytes,
    get_s3_bytes_with_backoff,
    parse_s3_path,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False

file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)


# Default port; overridden by --port
SGLANG_SERVER_PORT = 30024

# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()


# Process pool for offloading CPU-bound work (e.g., computing anchor texts); capped at 32
# workers so it does not spawn far too many processes on a large machine.
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))


async def process_dolma_document(dolma_doc):
    """
    Send the text to the SGLang server to classify PII presence.

    Returns a tuple (doc_id, contains_pii, text_length).
    """
    query = {
        # NOTE: keep this model name in sync with the --model argument.
        "model": "google/gemma-3-4b-it",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            f"{dolma_doc['text']}\n\n-----------\n"
                            "Given the text above, does it contain any Personally Identifiable Information (PII)? "
                            "Answer in a single JSON object with a single field named 'contains_pii' that's a bool."
                        ),
                    }
                ],
            }
        ],
        "temperature": 0.0,
    }
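    # Illustrative reply we expect the model to produce, given the instruction above
    # (assuming it follows the requested format): {"contains_pii": false}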
    # Generous timeout: model generations can easily exceed httpx's 5-second default
    async with httpx.AsyncClient(timeout=600) as client:
        url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
        resp = await client.post(url, json=query)
        resp.raise_for_status()
        response_json = resp.json()

    # Extract the JSON content from the model's response
    content = (
        response_json.get('choices', [])[0]
        .get('message', {})
        .get('content', '')
    )
    try:
        result = json.loads(content)
        contains_pii = bool(result.get('contains_pii', False))
    except json.JSONDecodeError:
        logger.warning(f"Failed to parse JSON from SGLang response: {content}")
        contains_pii = False

    text_length = len(dolma_doc.get('text', ''))
    return dolma_doc.get('id'), contains_pii, text_length


async def process_file(args, worker_id: int, file_uri: str):
    """
    Download a JSONL file, query SGLang per record, and collect attributes.
    """
    # Fetch raw bytes (S3 or local)
    if file_uri.startswith("s3://"):
        raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
    else:
        with open(file_uri, 'rb') as f:
            raw = f.read()

    # Decompress if needed
    if file_uri.endswith('.gz'):
        file_bytes = gzip.decompress(raw)
    elif file_uri.endswith('.ztd') or file_uri.endswith('.zst') or file_uri.endswith('.zstd'):
        dctx = zstd.ZstdDecompressor()
        file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
    else:
        file_bytes = raw

    lines = file_bytes.decode('utf-8').splitlines()
    page_tasks = {}

    # Send all records in parallel
    async with asyncio.TaskGroup() as tg:
        for line in lines:
            data = json.loads(line)
            task = tg.create_task(process_dolma_document(data))
            page_tasks[data['id']] = (task, data)

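    # Illustrative shape of a single attributes entry produced below (placeholder id and
    # length; the key name is derived from --model):
    #   {"id": "doc-123", "attributes": {"google_gemma-3-4b-it_pii_classification": [[0, 4096, 0.0]]}}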
    # Collect results and build attributes
    attributes = []
    key_name = f"{args.model.replace('/', '_')}_pii_classification"
    for doc_id, (task, data) in page_tasks.items():
        _, contains_pii, text_length = task.result()
        # Score the whole document span 1.0 when PII is detected, 0.0 otherwise
        score = 1.0 if contains_pii else 0.0
        span = [0, text_length, score]
        attributes.append({
            "id": doc_id,
            "attributes": {key_name: [span]},
        })

    return attributes


async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
    while True:
        await semaphore.acquire()
        work_item = await work_queue.get_work()
        if work_item is None:
            logger.info(f"Worker {worker_id} exiting due to empty queue")
            semaphore.release()
            break

        file_uri = work_item.work_paths[0]
        logger.info(f"Worker {worker_id} processing work item {file_uri}")
        await tracker.clear_work(worker_id)

        try:
            attrs = await process_file(args, worker_id, file_uri)

            logger.info("Got attrs: %s", attrs)

            # Write out attributes JSONL to scratch/attributes/... mirroring the input structure
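            # Illustrative mapping (placeholder names):
            #   <dataset>/documents/books/shard_0001.jsonl.gz -> <scratch>/attributes/books/shard_0001.jsonl.gz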
            if file_uri.startswith('s3://'):
                _, key = parse_s3_path(file_uri)
                # assume args.dataset is s3://bucket/prefix
                _, docs_prefix = parse_s3_path(args.dataset)
                rel_path = key[len(os.path.join(docs_prefix, 'documents/')):]
            else:
                docs_root = os.path.join(args.dataset, 'documents')
                rel_path = os.path.relpath(file_uri, docs_root)

            out_rel = os.path.join('attributes', rel_path)
            out_jsonl = '\n'.join(json.dumps(x) for x in attrs) + '\n'

            if args.scratch.startswith('s3://'):
                out_bucket, out_prefix = parse_s3_path(args.scratch)
                out_key = os.path.join(out_prefix, out_rel)
                workspace_s3.put_object(Bucket=out_bucket, Key=out_key, Body=out_jsonl.encode('utf-8'))
            else:
                out_path = os.path.join(args.scratch, out_rel)
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                with open(out_path, 'w', encoding='utf-8') as f:
                    f.write(out_jsonl)

            await work_queue.mark_done(work_item)
        except Exception as e:
            logger.exception(f"Exception occurred while processing work item {work_item.hash}: {e}")
        finally:
            semaphore.release()


async def sglang_server_task(model_name_or_path, args, semaphore):
    # Check GPU memory; lower-memory devices need a bit less KV cache space because the VLM takes additional memory
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert to GB
    mem_fraction_arg = ["--mem-fraction-static", "0.80"] if gpu_memory < 60 else []

    # NOTE: args.model_chat_template is referenced here, but main() does not currently
    # define a --model_chat_template argument; this code path is only reachable if the
    # server launch is re-enabled and that argument is added.
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model_name_or_path,
        "--chat-template",
        args.model_chat_template,
        # "--context-length", str(args.model_max_context),  # Commented out due to crashes
        "--port",
        str(SGLANG_SERVER_PORT),
        "--log-level-http",
        "warning",
    ]
    cmd.extend(mem_fraction_arg)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        sglang_logger.info(line)

        # If the server hasn't initialized yet, log all the lines to the main logger as well,
        # so that the user can see any warnings/errors more easily
        if not server_printed_ready_message:
            logger.info(line)

        if "Detected errors during sampling" in line:
            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
            sys.exit(1)

        # TODO: need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
        if "IndexError: list index out of range" in line:
            logger.error("IndexError in model, restarting server")
            proc.terminate()

        if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r"#running-req: (\d+)", line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r"#queue-req: (\d+)", line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            try:
                line = line.decode("utf-8").rstrip()
                await process_line(line)
            except Exception as ex:
                logger.warning(f"Got {ex} when reading log line from inference server, skipping")

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for SGLang server")
        proc.terminate()
        raise

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


async def sglang_server_host(model_name_or_path, args, semaphore):
    MAX_RETRIES = 5
    retry = 0

    while retry < MAX_RETRIES:
        await sglang_server_task(model_name_or_path, args, semaphore)
        logger.warning("SGLang server task ended")
        retry += 1

    if retry >= MAX_RETRIES:
        logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
        logger.error("")
        logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
        sys.exit(1)


async def sglang_server_ready():
    max_attempts = 300
    delay_sec = 1
    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("sglang server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception:
            logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")

        await asyncio.sleep(delay_sec)

    raise Exception("sglang server did not become ready after waiting.")


async def download_model(model_name_or_path: str):
    if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
        logger.info(f"Downloading model directory from '{model_name_or_path}'")
        model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
        download_directory([model_name_or_path], model_cache_dir)
        return model_cache_dir
    elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
        logger.info(f"Using local model path at '{model_name_or_path}'")
        return model_name_or_path
    else:
        logger.info(f"Downloading model '{model_name_or_path}' from Hugging Face")
        snapshot_download(repo_id=model_name_or_path)
        return model_name_or_path


async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)


def submit_beaker_job(args):
    from beaker import (  # type: ignore
        Beaker,
        Constraints,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/olmocr-inference-{VERSION}"

    task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"

    # Take out the --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(
            f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
        )

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(
            f"{owner}-AWS_CREDENTIALS_FILE",
            open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
            args.beaker_workspace,
        )

    env_var_secrets = [
        EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
        EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
        EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
    ]

    try:
        b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
    except SecretNotFound:
        pass

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
            env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                # Run this same script inside the container; "-m" is not used because the
                # target is a file path rather than an importable module name.
                command=["python", "scripts/tagging_pipeline.py"] + args_list,
                env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


async def main():
    parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
    parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with a documents/ folder")
    parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
    parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
    parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")

    # Beaker/job running stuff
    parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
    parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
    parser.add_argument(
        "--beaker_cluster",
        help="Beaker clusters you want to run on",
        default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
    )
    parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
    parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

    parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
    args = parser.parse_args()

    global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
    SGLANG_SERVER_PORT = args.port
    workspace_s3 = boto3.client("s3")
    dataset_s3 = boto3.client("s3")

    # Set up the job to work in the beaker environment: load secrets, adjust logging, etc.
    if "BEAKER_JOB_NAME" in os.environ:
        sglang_logger.addHandler(console_handler)

        cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

        cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

        # Recreate the S3 clients now that the credentials file is in place
        workspace_s3 = boto3.client("s3")
        dataset_s3 = boto3.client("s3")

        # Wait a little bit so that not all beaker jobs in a task start (and download the model) at the same time
        replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
        interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
        sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
        logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
        await asyncio.sleep(sleep_time)

    # Initialize work queue
    if args.scratch.startswith("s3://"):
        work_queue = S3WorkQueue(workspace_s3, args.scratch)
    else:
        work_queue = LocalWorkQueue(args.scratch)

    # Discover input files
    files = set()
    if args.dataset.startswith("s3://"):
        pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
        matched = expand_s3_glob(dataset_s3, pattern)
        files = set(matched.keys())
    else:
        docs_dir = os.path.join(args.dataset, "documents")
        for root, _, fns in os.walk(docs_dir):
            for fn in fns:
                if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.zst", ".jsonl.zstd", ".jsonl.ztd")):
                    files.add(os.path.join(root, fn))

    # Populate the work queue if needed
    await work_queue.populate_queue(list(files), items_per_group=1)

    if args.beaker:
        submit_beaker_job(args)
        return

    # If you get this far, then you are doing inference and need a GPU
    # check_sglang_version()
    # check_torch_gpu_available()

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Download the model before you do anything else
    model_name_or_path = await download_model(args.model)

    # Initialize the work queue
    qsize = await work_queue.initialize_queue()

    if qsize == 0:
        logger.info("No work to do, exiting")
        return

    # Create a semaphore to control worker access.
    # We only allow one worker to move forward with requests until the server has no more requests in its queue.
    # This lets us get full utilization by having many workers, while still outputting dolma docs as soon as possible.
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests.
    semaphore = asyncio.Semaphore(1)

    # NOTE: the SGLang server launch below is currently commented out, so an SGLang server
    # is expected to already be listening on --port.
    # sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore))

    # await sglang_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    # Wait for server to stop
    process_pool.shutdown(wait=False)

    # sglang_server.cancel()
    metrics_task.cancel()
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())