Adding some small changes to the tagging pipeline

Jake Poznanski 2025-04-29 11:12:03 -07:00
parent 66d293c178
commit f8808478bd
2 changed files with 116 additions and 75 deletions


@@ -0,0 +1,48 @@
FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
RUN apt-get update -y && apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get -y update
# Install requirements specific to pdfs
RUN apt-get update && apt-get -y install python3-apt
RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
RUN apt-get update -y && apt-get install -y poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
RUN apt-get update -y && apt-get install -y --no-install-recommends \
git \
python3.11 \
python3.11-dev \
python3.11-distutils \
ca-certificates \
build-essential \
curl \
unzip
RUN rm -rf /var/lib/apt/lists/* \
&& unlink /usr/bin/python3 \
&& ln -s /usr/bin/python3.11 /usr/bin/python3 \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python \
&& pip3 install -U pip
RUN apt-get update && apt-get -y install python3.11-venv
ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh
RUN /install.sh && rm /install.sh
ENV PYTHONUNBUFFERED=1
WORKDIR /root
COPY pyproject.toml pyproject.toml
COPY olmocr/version.py olmocr/version.py
RUN /root/.local/bin/uv pip install --system --no-cache -e .
RUN /root/.local/bin/uv pip install --system --no-cache vllm==0.8.2
COPY olmocr olmocr
WORKDIR /root
COPY olmocr olmocr
RUN python3 -m vllm --help
RUN python3 -m olmocr.pipeline --help


@@ -3,7 +3,7 @@
Tagging pipeline for Dolma JSONL datasets.
For each .jsonl, .jsonl.gz, or .jsonl.zst file under the dataset/documents folder,
this script issues a simple SGLang completion per record (e.g., "Is this document in English?"),
this script issues a model prompt completion per record,
collects the yes/no answers, and writes corresponding Dolma attributes JSONL files under
scratch/attributes/, mirroring the input structure.
"""
@@ -28,7 +28,6 @@ from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, ValidationError
from olmocr.check import (
check_sglang_version,
check_torch_gpu_available,
)
from olmocr.metrics import MetricsKeeper
@@ -46,8 +45,8 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False
sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False
server_logger = logging.getLogger("vllm")
server_logger.propagate = False
file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
@@ -60,11 +59,11 @@ console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(level
# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)
server_logger.addHandler(file_handler)
# Default port; overridden by --port
SGLANG_SERVER_PORT = 30024
SERVER_PORT = 30024
# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
@@ -81,8 +80,6 @@ async def _process_single_page(page_text: str) -> PIIClassification:
"""Helper function to process a single document or page."""
text = page_text
metrics.add_metrics(sglang_requests=1)
query = {
"model": "google/gemma-3-4b-it",
"messages": [
@@ -104,47 +101,49 @@ async def _process_single_page(page_text: str) -> PIIClassification:
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
}
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
# ---------- HTTP call ---------------------------------------------------
try:
status, body = await apost(url, json_data=query)
except Exception as e:
logger.warning(f"SGLang network error: {e!s}")
metrics.add_metrics(sglang_errors=1)
logger.warning(f"Server network error: {e!s}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
metrics.add_metrics(server_requests=1)
if status != 200:
logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
metrics.add_metrics(sglang_errors=1)
logger.warning(f"Server HTTP {status}: {body[:250]!r}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
# ---------- Parse base JSON --------------------------------------------
try:
base = json.loads(body)
except json.JSONDecodeError:
logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
metrics.add_metrics(sglang_errors=1)
logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
# Token accounting if available
if "usage" in base:
metrics.add_metrics(
sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
sglang_output_tokens=base["usage"].get("completion_tokens", 0),
server_input_tokens=base["usage"].get("prompt_tokens", 0),
server_output_tokens=base["usage"].get("completion_tokens", 0),
)
# ---------- Extract the model message ----------------------------------
try:
content = base["choices"][0]["message"].get("content")
except (KeyError, IndexError, AttributeError) as e:
logger.warning(f"Missing fields in SGLang response: {e!s}")
metrics.add_metrics(sglang_errors=1)
logger.warning(f"Missing fields in Server response: {e!s}")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
if not isinstance(content, str):
logger.warning("SGLang `content` is not a string; treating as error.")
metrics.add_metrics(sglang_errors=1)
logger.warning("Server `content` is not a string; treating as error.")
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
try:
@@ -152,7 +151,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
return pii_classification
except ValidationError as e:
logger.warning(f"Unable to parse pii classification object: {e!s}")
metrics.add_metrics(sglang_errors=1)
metrics.add_metrics(server_errors=1)
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
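
PIIClassification itself is defined outside these hunks; judging from the fallback constructor and the json_schema response_format above, a minimal sketch consistent with how it is used (field types and comments are assumptions) would be:

from typing import Optional
from pydantic import BaseModel

class PIIClassification(BaseModel):
    primary_language: str                  # e.g. "en" (assumed semantics)
    document_type: str                     # "unknown" when the model cannot tell
    is_resume_cv: Optional[bool] = None    # tri-state: True / False / no judgment
    contains_pii: Optional[bool] = None    # tri-state, same convention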
@@ -223,7 +222,7 @@ async def apost(url, json_data):
async def process_dolma_document(args, dolma_doc, sem):
"""
Query SGLang to detect PII, enforcing a JSON schema.
Query the model to detect PII, enforcing a JSON schema.
Resilient to:
Transport / HTTP errors
@@ -236,9 +235,10 @@ async def process_dolma_document(args, dolma_doc, sem):
doc_id = dolma_doc.get("id")
text = dolma_doc.get("text", "") or ""
key_name = f"{args.model.replace('/', '_')}_pii_classification"
language_key_name = f"{args.model.replace('/', '_')}_language"
resume_cv_key_name = f"{args.model.replace('/', '_')}_is_resume_cv"
result_attributes = {key_name: []}
result_attributes = {resume_cv_key_name: [], language_key_name: []}
# If pdf_page_numbers is present, split the text and process each page separately
if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
@@ -248,11 +248,15 @@ async def process_dolma_document(args, dolma_doc, sem):
# Filter down to pages that actually contain content
selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]
first_page_number = selected_page_numbers[0]
# Sample 3 pages max per document
# Sample 3 pages max per document, but always include the first page, since it's a good signal for CV classification
random.shuffle(selected_page_numbers)
selected_page_numbers = selected_page_numbers[:3]
if first_page_number not in selected_page_numbers:
selected_page_numbers[0] = first_page_number
for start_pos, end_pos, page_num in page_numbers:
if (start_pos, end_pos, page_num) in selected_page_numbers:
page_text = text[start_pos:end_pos]
@@ -261,9 +265,11 @@ async def process_dolma_document(args, dolma_doc, sem):
async with sem:
pii_class = await _process_single_page(page_text)
result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
result_attributes[resume_cv_key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language])
else:
result_attributes[key_name].append([start_pos, end_pos, None])
result_attributes[resume_cv_key_name].append([start_pos, end_pos, None])
result_attributes[language_key_name].append([start_pos, end_pos, None])
return result_attributes
else:
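
A self-contained sketch of the sampling rule above, using hypothetical page spans, shows that at most three pages are classified and the first page always survives the shuffle:

import random

page_numbers = [(0, 100, 1), (100, 250, 2), (250, 400, 3), (400, 480, 4)]  # hypothetical
selected = [p for p in page_numbers if p[0] < p[1]]  # keep pages with real content
first_page = selected[0]
random.shuffle(selected)
selected = selected[:3]
if first_page not in selected:
    selected[0] = first_page  # force-include page 1: the strongest resume/CV signal
assert first_page in selected and len(selected) <= 3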
@@ -272,7 +278,7 @@ async def process_dolma_document(args, dolma_doc, sem):
async def process_file(args, worker_id: int, file_uri: str):
"""
Download a JSONL file, query SGLang per record, and collect attributes.
Download a JSONL file, query model per record, and collect attributes.
"""
# Fetch raw bytes (S3 or local)
if file_uri.startswith("s3://"):
@@ -293,8 +299,8 @@ async def process_file(args, worker_id: int, file_uri: str):
lines = file_bytes.decode("utf-8").splitlines()
page_tasks = {}
# Send all records in parallel, max 500 queued at a time
sem = asyncio.Semaphore(500)
# Send all records in parallel, max N queued at a time
sem = asyncio.Semaphore(args.parallel_requests)
async with asyncio.TaskGroup() as tg:
for line in lines:
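
The cap works because each call into process_dolma_document acquires the shared semaphore around its model request; a stripped-down sketch of the pattern (names simplified, a sleep standing in for the HTTP call):

import asyncio

async def classify(page: str, sem: asyncio.Semaphore) -> str:
    async with sem:  # at most parallel_requests coroutines past this point
        await asyncio.sleep(0.01)  # placeholder for the chat-completions request
        return "ok"

async def run(pages: list[str], parallel_requests: int = 800) -> list[str]:
    sem = asyncio.Semaphore(parallel_requests)
    async with asyncio.TaskGroup() as tg:  # Python 3.11+, as in the Dockerfile
        tasks = [tg.create_task(classify(p, sem)) for p in pages]
    return [t.result() for t in tasks]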
@@ -302,7 +308,7 @@ async def process_file(args, worker_id: int, file_uri: str):
task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
page_tasks[dolma_doc["id"]] = (task, dolma_doc)
logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")
# Collect results and build attributes
attributes = []
@@ -389,21 +395,19 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
semaphore.release()
async def sglang_server_task(model_name_or_path, args, semaphore):
async def server_task(model_name_or_path, args, semaphore):
# Check GPU memory; lower-mem devices need a bit less KV cache space because the VLM takes additional memory
# mem_fraction_arg = ["--mem-fraction-static", "0.80"]
cmd = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
"vllm",
"serve",
model_name_or_path,
"--port",
str(SGLANG_SERVER_PORT),
"--log-level-http",
str(SERVER_PORT),
"--uvicorn-log-level",
"warning",
"--mem-fraction-static", "0.40"
"--disable-log-requests",
]
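
With the defaults above, the spawned subprocess is equivalent to running:

python3 -m vllm serve google/gemma-3-4b-it --port 30024 --uvicorn-log-level warning --disable-log-requests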
proc = await asyncio.create_subprocess_exec(
@@ -425,34 +429,25 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
async def process_line(line):
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
sglang_logger.info(line)
server_logger.info(line)
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
# can see any warnings/errors more easily
if not server_printed_ready_message:
logger.info(line)
if "Detected errors during sampling" in line:
logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
sys.exit(1)
# TODO: need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
if "IndexError: list index out of range" in line:
logger.error("IndexError in model, restarting server")
proc.terminate()
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
server_printed_ready_message = True
last_semaphore_release = time.time()
match = re.search(r"#running-req: (\d+)", line)
match = re.search(r"Running: (\d+) reqs", line)
if match:
last_running_req = int(match.group(1))
match = re.search(r"#queue-req: (\d+)", line)
match = re.search(r"Waiting: (\d+) reqs", line)
if match:
last_queue_req = int(match.group(1))
logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")
logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")
async def read_stream(stream):
while True:
@@ -485,7 +480,7 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
try:
await proc.wait()
except asyncio.CancelledError:
logger.info("Got cancellation request for SGLang server")
logger.info("Got cancellation request for server")
proc.terminate()
raise
@@ -493,28 +488,26 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
async def sglang_server_host(model_name_or_path, args, semaphore):
async def server_host(model_name_or_path, args, semaphore):
MAX_RETRIES = 5
retry = 0
while retry < MAX_RETRIES:
await sglang_server_task(model_name_or_path, args, semaphore)
logger.warning("SGLang server task ended")
await server_task(model_name_or_path, args, semaphore)
logger.warning("Server task ended")
retry += 1
if retry >= MAX_RETRIES:
logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
logger.error("")
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
sys.exit(1)
async def sglang_server_ready():
async def check_server_ready():
max_attempts = 300
delay_sec = 1
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"
url = f"http://localhost:{SERVER_PORT}/v1/models"
for attempt in range(1, max_attempts + 1):
try:
@@ -522,16 +515,16 @@ async def sglang_server_ready():
response = await session.get(url)
if response.status_code == 200:
logger.info("sglang server is ready.")
logger.info("server is ready.")
return
else:
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
except Exception:
logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")
logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")
await asyncio.sleep(delay_sec)
raise Exception("sglang server did not become ready after waiting.")
raise Exception("model server did not become ready after waiting.")
async def download_model(model_name_or_path: str):
@@ -669,7 +662,8 @@ async def main():
parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
# Beaker/job running stuff
@@ -683,17 +677,17 @@ async def main():
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
args = parser.parse_args()
global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
SGLANG_SERVER_PORT = args.port
global SERVER_PORT, workspace_s3, dataset_s3
SERVER_PORT = args.port
workspace_s3 = boto3.client("s3")
dataset_s3 = boto3.client("s3")
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
if "BEAKER_JOB_ID" in os.environ:
sglang_logger.addHandler(console_handler)
server_logger.addHandler(console_handler)
if "AWS_CREDENTIALS_FILE" in os.environ:
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
@@ -742,7 +736,6 @@ async def main():
return
# If you get this far, then you are doing inference and need a GPU
check_sglang_version()
check_torch_gpu_available()
logger.info(f"Starting pipeline with PID {os.getpid()}")
@@ -763,9 +756,9 @@ async def main():
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)
sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore))
model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))
await sglang_server_ready()
await check_server_ready()
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@@ -778,7 +771,7 @@ async def main():
# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)
sglang_server.cancel()
model_server.cancel()
metrics_task.cancel()
logger.info("Work done")