From f8808478bdefafb218a1a91d7efd8921e4bf1314 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 29 Apr 2025 11:12:03 -0700
Subject: [PATCH] Adding some small changes to the tagging pipeline

---
 scripts/beaker/Dockerfile-tagging |  48 ++++++++++
 scripts/tagging_pipeline.py       | 143 ++++++++++++++----------------
 2 files changed, 116 insertions(+), 75 deletions(-)
 create mode 100644 scripts/beaker/Dockerfile-tagging

diff --git a/scripts/beaker/Dockerfile-tagging b/scripts/beaker/Dockerfile-tagging
new file mode 100644
index 0000000..142e131
--- /dev/null
+++ b/scripts/beaker/Dockerfile-tagging
@@ -0,0 +1,48 @@
+FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
+
+RUN apt-get update -y && apt-get install -y software-properties-common \
+    && add-apt-repository ppa:deadsnakes/ppa \
+    && apt-get -y update
+
+# Install requirements specific to pdfs
+RUN apt-get update && apt-get -y install python3-apt
+RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
+RUN apt-get update -y && apt-get install -y poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends \
+    git \
+    python3.11 \
+    python3.11-dev \
+    python3.11-distutils \
+    ca-certificates \
+    build-essential \
+    curl \
+    unzip
+
+RUN rm -rf /var/lib/apt/lists/* \
+    && unlink /usr/bin/python3 \
+    && ln -s /usr/bin/python3.11 /usr/bin/python3 \
+    && ln -s /usr/bin/python3 /usr/bin/python \
+    && curl -sS https://bootstrap.pypa.io/get-pip.py | python \
+    && pip3 install -U pip
+
+RUN apt-get update && apt-get -y install python3.11-venv
+ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh
+RUN /install.sh && rm /install.sh
+
+ENV PYTHONUNBUFFERED=1
+WORKDIR /root
+COPY pyproject.toml pyproject.toml
+COPY olmocr/version.py olmocr/version.py
+
+RUN /root/.local/bin/uv pip install --system --no-cache -e .
+
+RUN /root/.local/bin/uv pip install --system --no-cache vllm==0.8.2
+
+COPY olmocr olmocr
+
+WORKDIR /root
+COPY olmocr olmocr
+
+RUN python3 -m vllm --help
+RUN python3 -m olmocr.pipeline --help
\ No newline at end of file
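
Note: the COPY directives above assume the image is built from the repository root, so that pyproject.toml and the olmocr/ package are in the build context. A build invocation along these lines should work (the image tag here is only an example, not part of the patch):

    docker build -f scripts/beaker/Dockerfile-tagging -t olmocr-tagging .
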
""" @@ -28,7 +28,6 @@ from huggingface_hub import snapshot_download from pydantic import BaseModel, Field, ValidationError from olmocr.check import ( - check_sglang_version, check_torch_gpu_available, ) from olmocr.metrics import MetricsKeeper @@ -46,8 +45,8 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) logger.propagate = False -sglang_logger = logging.getLogger("sglang") -sglang_logger.propagate = False +server_logger = logging.getLogger("vllm") +server_logger.propagate = False file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a") file_handler.setLevel(logging.DEBUG) @@ -60,11 +59,11 @@ console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(level # Add handlers to the logger logger.addHandler(file_handler) logger.addHandler(console_handler) -sglang_logger.addHandler(file_handler) +server_logger.addHandler(file_handler) # Default port; overridden by --port -SGLANG_SERVER_PORT = 30024 +SERVER_PORT = 30024 # Global variables for token statistics metrics = MetricsKeeper(window=60 * 5) @@ -81,8 +80,6 @@ async def _process_single_page(page_text: str) -> PIIClassification: """Helper function to process a single document or page.""" text = page_text - metrics.add_metrics(sglang_requests=1) - query = { "model": "google/gemma-3-4b-it", "messages": [ @@ -104,47 +101,49 @@ async def _process_single_page(page_text: str) -> PIIClassification: "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}}, } - url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions" + url = f"http://localhost:{SERVER_PORT}/v1/chat/completions" # ---------- HTTP call --------------------------------------------------- try: status, body = await apost(url, json_data=query) except Exception as e: - logger.warning(f"SGLang network error: {e!s}") - metrics.add_metrics(sglang_errors=1) + logger.warning(f"Server network error: {e!s}") + metrics.add_metrics(server_errors=1) return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None) + metrics.add_metrics(server_requests=1) + if status != 200: - logger.warning(f"SGLang HTTP {status}: {body[:250]!r}") - metrics.add_metrics(sglang_errors=1) + logger.warning(f"Server HTTP {status}: {body[:250]!r}") + metrics.add_metrics(server_errors=1) return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None) # ---------- Parse base JSON -------------------------------------------- try: base = json.loads(body) except json.JSONDecodeError: - logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}") - metrics.add_metrics(sglang_errors=1) + logger.warning(f"Server response is not valid JSON: {body[:250]!r}") + metrics.add_metrics(server_errors=1) return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None) # Token accounting if available if "usage" in base: metrics.add_metrics( - sglang_input_tokens=base["usage"].get("prompt_tokens", 0), - sglang_output_tokens=base["usage"].get("completion_tokens", 0), + server_input_tokens=base["usage"].get("prompt_tokens", 0), + server_output_tokens=base["usage"].get("completion_tokens", 0), ) # ---------- Extract the model message ---------------------------------- try: content = base["choices"][0]["message"].get("content") except (KeyError, IndexError, AttributeError) as e: - logger.warning(f"Missing fields in SGLang response: {e!s}") - 
@@ -223,7 +222,7 @@ async def apost(url, json_data):

 async def process_dolma_document(args, dolma_doc, sem):
     """
-    Query SGLang to detect PII, enforcing a JSON schema.
+    Query the model to detect PII, enforcing a JSON schema.

     Resilient to:
       • Transport / HTTP errors
@@ -236,9 +235,10 @@
     doc_id = dolma_doc.get("id")
     text = dolma_doc.get("text", "") or ""

-    key_name = f"{args.model.replace('/', '_')}_pii_classification"
+    language_key_name = f"{args.model.replace('/', '_')}_language"
+    resume_cv_key_name = f"{args.model.replace('/', '_')}_is_resume_cv"

-    result_attributes = {key_name: []}
+    result_attributes = {resume_cv_key_name: [], language_key_name: []}

     # If pdf_page_numbers is present, split the text and process each page separately
     if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
@@ -248,11 +248,15 @@

         # Filter pages down to actual real content
         selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]
+        first_page_number = selected_page_numbers[0]

-        # Sample 3 pages max per document
+        # Sample 3 pages max per document, but always include the first page, as it's a good signal for CV classification
         random.shuffle(selected_page_numbers)
         selected_page_numbers = selected_page_numbers[:3]

+        if first_page_number not in selected_page_numbers:
+            selected_page_numbers[0] = first_page_number
+
         for start_pos, end_pos, page_num in page_numbers:
             if (start_pos, end_pos, page_num) in selected_page_numbers:
                 page_text = text[start_pos:end_pos]
@@ -261,9 +265,11 @@
                 async with sem:
                     pii_class = await _process_single_page(page_text)

-                result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
+                result_attributes[resume_cv_key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
+                result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language])
             else:
-                result_attributes[key_name].append([start_pos, end_pos, None])
+                result_attributes[resume_cv_key_name].append([start_pos, end_pos, None])
+                result_attributes[language_key_name].append([start_pos, end_pos, None])

         return result_attributes
     else:
@@ -272,7 +278,7 @@

 async def process_file(args, worker_id: int, file_uri: str):
     """
-    Download a JSONL file, query SGLang per record, and collect attributes.
+    Download a JSONL file, query the model per record, and collect attributes.
     """
     # Fetch raw bytes (S3 or local)
     if file_uri.startswith("s3://"):
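
Note: with the key split above, the attributes collected for a single document end up looking roughly like this for the default google/gemma-3-4b-it model (a made-up example; the offsets and values are illustrative only, and None marks pages that were not sampled):

    result_attributes = {
        "google_gemma-3-4b-it_is_resume_cv": [[0, 1893, False], [1893, 3544, None]],
        "google_gemma-3-4b-it_language": [[0, 1893, "en"], [1893, 3544, None]],
    }

Each span is [start_pos, end_pos, value] over the document text, mirroring the pdf_page_numbers ranges.
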
@@ -293,8 +299,8 @@
     lines = file_bytes.decode("utf-8").splitlines()
     page_tasks = {}

-    # Send all records in parallel, max 500 queued at a time
-    sem = asyncio.Semaphore(500)
+    # Send all records in parallel, max N queued at a time
+    sem = asyncio.Semaphore(args.parallel_requests)

     async with asyncio.TaskGroup() as tg:
         for line in lines:
@@ -302,7 +308,7 @@
             task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
             page_tasks[dolma_doc["id"]] = (task, dolma_doc)

-    logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
+    logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")

     # Collect results and build attributes
     attributes = []
@@ -389,21 +395,19 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
     semaphore.release()


-async def sglang_server_task(model_name_or_path, args, semaphore):
+async def server_task(model_name_or_path, args, semaphore):
     # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
     # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

     cmd = [
-        "python3",
-        "-m",
-        "sglang.launch_server",
-        "--model-path",
+        "vllm",
+        "serve",
         model_name_or_path,
         "--port",
-        str(SGLANG_SERVER_PORT),
-        "--log-level-http",
+        str(SERVER_PORT),
+        "--uvicorn-log-level",
         "warning",
-        "--mem-fraction-static", "0.40"
+        "--disable-log-requests",
     ]

     proc = await asyncio.create_subprocess_exec(
@@ -425,34 +429,25 @@

     async def process_line(line):
         nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
-        sglang_logger.info(line)
+        server_logger.info(line)

         # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
         # can see any warnings/errors more easily
         if not server_printed_ready_message:
             logger.info(line)

-        if "Detected errors during sampling" in line:
-            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
-            sys.exit(1)
-
-        # TODO, need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
-        if "IndexError: list index out of range" in line:
-            logger.error("IndexError in model, restarting server")
-            proc.terminate()
-
         if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
             server_printed_ready_message = True
             last_semaphore_release = time.time()

-        match = re.search(r"#running-req: (\d+)", line)
+        match = re.search(r"Running: (\d+) reqs", line)
         if match:
             last_running_req = int(match.group(1))

-        match = re.search(r"#queue-req: (\d+)", line)
+        match = re.search(r"Waiting: (\d+) reqs", line)
         if match:
             last_queue_req = int(match.group(1))
-            logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")
+            logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")

     async def read_stream(stream):
         while True:
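
Note: the command assembled in server_task above corresponds to running roughly the following by hand with this script's defaults (the model and port values come from this file; nothing else is implied):

    vllm serve google/gemma-3-4b-it --port 30024 --uvicorn-log-level warning --disable-log-requests

The Running:/Waiting: patterns are presumably meant to match vLLM's periodic engine-stats log lines, replacing the old #running-req/#queue-req counters from sglang.
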
@@ -485,7 +480,7 @@ async def sglang_server_task(model_name_or_path, args, semaphore):

     try:
         await proc.wait()
     except asyncio.CancelledError:
-        logger.info("Got cancellation request for SGLang server")
+        logger.info("Got cancellation request for server")
         proc.terminate()
         raise
@@ -493,28 +488,26 @@
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


-async def sglang_server_host(model_name_or_path, args, semaphore):
+async def server_host(model_name_or_path, args, semaphore):
     MAX_RETRIES = 5
     retry = 0

-    await asyncio.sleep(1000000)
-
     while retry < MAX_RETRIES:
-        await sglang_server_task(model_name_or_path, args, semaphore)
-        logger.warning("SGLang server task ended")
+        await server_task(model_name_or_path, args, semaphore)
+        logger.warning("Server task ended")
         retry += 1

     if retry >= MAX_RETRIES:
-        logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
+        logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
         logger.error("")
-        logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
+        logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
         sys.exit(1)


-async def sglang_server_ready():
+async def check_server_ready():
     max_attempts = 300
     delay_sec = 1
-    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"
+    url = f"http://localhost:{SERVER_PORT}/v1/models"

     for attempt in range(1, max_attempts + 1):
         try:
@@ -522,16 +515,16 @@
             response = await session.get(url)

             if response.status_code == 200:
-                logger.info("sglang server is ready.")
+                logger.info("server is ready.")
                 return
             else:
                 logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
         except Exception:
-            logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")
+            logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")

         await asyncio.sleep(delay_sec)

-    raise Exception("sglang server did not become ready after waiting.")
+    raise Exception("model server did not become ready after waiting.")


 async def download_model(model_name_or_path: str):
@@ -669,7 +662,8 @@ async def main():
     parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
     parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
     parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
-    parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")
+    parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to the model")
+    parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model name or path, in Hugging Face or local path format")
     parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")

     # Beaker/job running stuff
@@ -683,17 +677,17 @@
     parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

-    parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
+    parser.add_argument("--port", type=int, default=30024, help="Port for the model server")
     args = parser.parse_args()

-    global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
-    SGLANG_SERVER_PORT = args.port
+    global SERVER_PORT, workspace_s3, dataset_s3
+    SERVER_PORT = args.port
     workspace_s3 = boto3.client("s3")
     dataset_s3 = boto3.client("s3")

     # setup the job to work in beaker environment, load secrets, adjust logging, etc.
     if "BEAKER_JOB_ID" in os.environ:
-        sglang_logger.addHandler(console_handler)
+        server_logger.addHandler(console_handler)
         if "AWS_CREDENTIALS_FILE" in os.environ:
             cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
             os.makedirs(os.path.dirname(cred_path), exist_ok=True)
local path format") parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming") # Beaker/job running stuff @@ -683,17 +677,17 @@ async def main(): parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run") parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job") - parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server") + parser.add_argument("--port", type=int, default=30024, help="Port for Model server") args = parser.parse_args() - global SGLANG_SERVER_PORT, workspace_s3, dataset_s3 - SGLANG_SERVER_PORT = args.port + global SERVER_PORT, workspace_s3, dataset_s3 + SERVER_PORT = args.port workspace_s3 = boto3.client("s3") dataset_s3 = boto3.client("s3") # setup the job to work in beaker environment, load secrets, adjust logging, etc. if "BEAKER_JOB_ID" in os.environ: - sglang_logger.addHandler(console_handler) + server_logger.addHandler(console_handler) if "AWS_CREDENTIALS_FILE" in os.environ: cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") os.makedirs(os.path.dirname(cred_path), exist_ok=True) @@ -742,7 +736,6 @@ async def main(): return # If you get this far, then you are doing inference and need a GPU - check_sglang_version() check_torch_gpu_available() logger.info(f"Starting pipeline with PID {os.getpid()}") @@ -763,9 +756,9 @@ async def main(): # As soon as one worker is no longer saturating the gpu, the next one can start sending requests semaphore = asyncio.Semaphore(1) - sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore)) + model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore)) - await sglang_server_ready() + await check_server_ready() metrics_task = asyncio.create_task(metrics_reporter(work_queue)) @@ -778,7 +771,7 @@ async def main(): # Wait for all worker tasks to finish await asyncio.gather(*worker_tasks) - sglang_server.cancel() + model_server.cancel() metrics_task.cancel() logger.info("Work done")