mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-14 17:52:53 +00:00
Adding some small changes to the tagging pipeline
This commit is contained in:
parent
66d293c178
commit
f8808478bd
48
scripts/beaker/Dockerfile-tagging
Normal file
48
scripts/beaker/Dockerfile-tagging
Normal file
@ -0,0 +1,48 @@
|
||||
FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
|
||||
|
||||
RUN apt-get update -y && apt-get install -y software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get -y update
|
||||
|
||||
# Install requirements specific to pdfs
|
||||
RUN apt-get update && apt-get -y install python3-apt
|
||||
RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
|
||||
RUN apt-get update -y && apt-get install -y poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
|
||||
|
||||
RUN apt-get update -y && apt-get install -y --no-install-recommends \
|
||||
git \
|
||||
python3.11 \
|
||||
python3.11-dev \
|
||||
python3.11-distutils \
|
||||
ca-certificates \
|
||||
build-essential \
|
||||
curl \
|
||||
unzip
|
||||
|
||||
RUN rm -rf /var/lib/apt/lists/* \
|
||||
&& unlink /usr/bin/python3 \
|
||||
&& ln -s /usr/bin/python3.11 /usr/bin/python3 \
|
||||
&& ln -s /usr/bin/python3 /usr/bin/python \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python \
|
||||
&& pip3 install -U pip
|
||||
|
||||
RUN apt-get update && apt-get -y install python3.11-venv
|
||||
ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh
|
||||
RUN /install.sh && rm /install.sh
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
WORKDIR /root
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY olmocr/version.py olmocr/version.py
|
||||
|
||||
RUN /root/.local/bin/uv pip install --system --no-cache -e .
|
||||
|
||||
RUN /root/.local/bin/uv pip install --system --no-cache vllm==0.8.2
|
||||
|
||||
COPY olmocr olmocr
|
||||
|
||||
WORKDIR /root
|
||||
COPY olmocr olmocr
|
||||
|
||||
RUN python3 -m vllm --help
|
||||
RUN python3 -m olmocr.pipeline --help
|
@ -3,7 +3,7 @@
|
||||
Tagging pipeline for Dolma JSONL datasets.
|
||||
|
||||
For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
|
||||
this script issues a simple SGLang completion per record (e.g., "Is this document in English?"),
|
||||
this script issues a model prompt completion
|
||||
collects the yes/no answers, and writes corresponding Dolma attributes JSONL files under
|
||||
scratch/attributes/, mirroring the input structure.
|
||||
"""
|
||||
@ -28,7 +28,6 @@ from huggingface_hub import snapshot_download
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from olmocr.check import (
|
||||
check_sglang_version,
|
||||
check_torch_gpu_available,
|
||||
)
|
||||
from olmocr.metrics import MetricsKeeper
|
||||
@ -46,8 +45,8 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = False
|
||||
|
||||
sglang_logger = logging.getLogger("sglang")
|
||||
sglang_logger.propagate = False
|
||||
server_logger = logging.getLogger("vllm")
|
||||
server_logger.propagate = False
|
||||
|
||||
file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
@ -60,11 +59,11 @@ console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(level
|
||||
# Add handlers to the logger
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(console_handler)
|
||||
sglang_logger.addHandler(file_handler)
|
||||
server_logger.addHandler(file_handler)
|
||||
|
||||
|
||||
# Default port; overridden by --port
|
||||
SGLANG_SERVER_PORT = 30024
|
||||
SERVER_PORT = 30024
|
||||
|
||||
# Global variables for token statistics
|
||||
metrics = MetricsKeeper(window=60 * 5)
|
||||
@ -81,8 +80,6 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
||||
"""Helper function to process a single document or page."""
|
||||
text = page_text
|
||||
|
||||
metrics.add_metrics(sglang_requests=1)
|
||||
|
||||
query = {
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"messages": [
|
||||
@ -104,47 +101,49 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
||||
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
|
||||
}
|
||||
|
||||
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
|
||||
url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
|
||||
|
||||
# ---------- HTTP call ---------------------------------------------------
|
||||
try:
|
||||
status, body = await apost(url, json_data=query)
|
||||
except Exception as e:
|
||||
logger.warning(f"SGLang network error: {e!s}")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
logger.warning(f"Server network error: {e!s}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
metrics.add_metrics(server_requests=1)
|
||||
|
||||
if status != 200:
|
||||
logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
logger.warning(f"Server HTTP {status}: {body[:250]!r}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
# ---------- Parse base JSON --------------------------------------------
|
||||
try:
|
||||
base = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
# Token accounting if available
|
||||
if "usage" in base:
|
||||
metrics.add_metrics(
|
||||
sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
|
||||
sglang_output_tokens=base["usage"].get("completion_tokens", 0),
|
||||
server_input_tokens=base["usage"].get("prompt_tokens", 0),
|
||||
server_output_tokens=base["usage"].get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
# ---------- Extract the model message ----------------------------------
|
||||
try:
|
||||
content = base["choices"][0]["message"].get("content")
|
||||
except (KeyError, IndexError, AttributeError) as e:
|
||||
logger.warning(f"Missing fields in SGLang response: {e!s}")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
logger.warning(f"Missing fields in Server response: {e!s}")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
if not isinstance(content, str):
|
||||
logger.warning("SGLang `content` is not a string; treating as error.")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
logger.warning("Server `content` is not a string; treating as error.")
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
try:
|
||||
@ -152,7 +151,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
||||
return pii_classification
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Unable to parse pii classification object: {e!s}")
|
||||
metrics.add_metrics(sglang_errors=1)
|
||||
metrics.add_metrics(server_errors=1)
|
||||
return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
|
||||
|
||||
|
||||
@ -223,7 +222,7 @@ async def apost(url, json_data):
|
||||
|
||||
async def process_dolma_document(args, dolma_doc, sem):
|
||||
"""
|
||||
Query SGLang to detect PII, enforcing a JSON schema.
|
||||
Query model to detect PII, enforcing a JSON schema.
|
||||
|
||||
Resilient to:
|
||||
• Transport / HTTP errors
|
||||
@ -236,9 +235,10 @@ async def process_dolma_document(args, dolma_doc, sem):
|
||||
doc_id = dolma_doc.get("id")
|
||||
text = dolma_doc.get("text", "") or ""
|
||||
|
||||
key_name = f"{args.model.replace('/', '_')}_pii_classification"
|
||||
language_key_name = f"{args.model.replace('/', '_')}_language"
|
||||
resume_cv_key_name = f"{args.model.replace('/', '_')}_is_resume_cv"
|
||||
|
||||
result_attributes = {key_name: []}
|
||||
result_attributes = {resume_cv_key_name: [], language_key_name: []}
|
||||
|
||||
# If pdf_page_numbers is present, split the text and process each page separately
|
||||
if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
|
||||
@ -248,11 +248,15 @@ async def process_dolma_document(args, dolma_doc, sem):
|
||||
|
||||
# Filter pages down to actual real content
|
||||
selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]
|
||||
first_page_number = selected_page_numbers[0]
|
||||
|
||||
# Sample 3 pages max per document
|
||||
# Sample 3 pages max per document, but always include the first page, it's a good signal for CV classification
|
||||
random.shuffle(selected_page_numbers)
|
||||
selected_page_numbers = selected_page_numbers[:3]
|
||||
|
||||
if first_page_number not in selected_page_numbers:
|
||||
selected_page_numbers[0] = first_page_number
|
||||
|
||||
for start_pos, end_pos, page_num in page_numbers:
|
||||
if (start_pos, end_pos, page_num) in selected_page_numbers:
|
||||
page_text = text[start_pos:end_pos]
|
||||
@ -261,9 +265,11 @@ async def process_dolma_document(args, dolma_doc, sem):
|
||||
async with sem:
|
||||
pii_class = await _process_single_page(page_text)
|
||||
|
||||
result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
|
||||
result_attributes[resume_cv_key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
|
||||
result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language])
|
||||
else:
|
||||
result_attributes[key_name].append([start_pos, end_pos, None])
|
||||
result_attributes[resume_cv_key_name].append([start_pos, end_pos, None])
|
||||
result_attributes[language_key_name].append([start_pos, end_pos, None])
|
||||
|
||||
return result_attributes
|
||||
else:
|
||||
@ -272,7 +278,7 @@ async def process_dolma_document(args, dolma_doc, sem):
|
||||
|
||||
async def process_file(args, worker_id: int, file_uri: str):
|
||||
"""
|
||||
Download a JSONL file, query SGLang per record, and collect attributes.
|
||||
Download a JSONL file, query model per record, and collect attributes.
|
||||
"""
|
||||
# Fetch raw bytes (S3 or local)
|
||||
if file_uri.startswith("s3://"):
|
||||
@ -293,8 +299,8 @@ async def process_file(args, worker_id: int, file_uri: str):
|
||||
lines = file_bytes.decode("utf-8").splitlines()
|
||||
page_tasks = {}
|
||||
|
||||
# Send all records in parallel, max 500 queued at a time
|
||||
sem = asyncio.Semaphore(500)
|
||||
# Send all records in parallel, max N queued at a time
|
||||
sem = asyncio.Semaphore(args.parallel_requests)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for line in lines:
|
||||
@ -302,7 +308,7 @@ async def process_file(args, worker_id: int, file_uri: str):
|
||||
task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
|
||||
page_tasks[dolma_doc["id"]] = (task, dolma_doc)
|
||||
|
||||
logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
|
||||
logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")
|
||||
|
||||
# Collect results and build attributes
|
||||
attributes = []
|
||||
@ -389,21 +395,19 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
|
||||
semaphore.release()
|
||||
|
||||
|
||||
async def sglang_server_task(model_name_or_path, args, semaphore):
|
||||
async def server_task(model_name_or_path, args, semaphore):
|
||||
# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
|
||||
# mem_fraction_arg = ["--mem-fraction-static", "0.80"]
|
||||
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
"vllm",
|
||||
"serve",
|
||||
model_name_or_path,
|
||||
"--port",
|
||||
str(SGLANG_SERVER_PORT),
|
||||
"--log-level-http",
|
||||
str(SERVER_PORT),
|
||||
"--uvicorn-log-level",
|
||||
"warning",
|
||||
"--mem-fraction-static", "0.40"
|
||||
"--disable-log-requests",
|
||||
]
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
@ -425,34 +429,25 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
|
||||
|
||||
async def process_line(line):
|
||||
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
|
||||
sglang_logger.info(line)
|
||||
server_logger.info(line)
|
||||
|
||||
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
|
||||
# can see any warnings/errors more easily
|
||||
if not server_printed_ready_message:
|
||||
logger.info(line)
|
||||
|
||||
if "Detected errors during sampling" in line:
|
||||
logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
|
||||
sys.exit(1)
|
||||
|
||||
# TODO, need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
|
||||
if "IndexError: list index out of range" in line:
|
||||
logger.error("IndexError in model, restarting server")
|
||||
proc.terminate()
|
||||
|
||||
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
|
||||
server_printed_ready_message = True
|
||||
last_semaphore_release = time.time()
|
||||
|
||||
match = re.search(r"#running-req: (\d+)", line)
|
||||
match = re.search(r"Running: (\d+) reqs", line)
|
||||
if match:
|
||||
last_running_req = int(match.group(1))
|
||||
|
||||
match = re.search(r"#queue-req: (\d+)", line)
|
||||
match = re.search(r"Waiting: (\d+) reqs", line)
|
||||
if match:
|
||||
last_queue_req = int(match.group(1))
|
||||
logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")
|
||||
logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")
|
||||
|
||||
async def read_stream(stream):
|
||||
while True:
|
||||
@ -485,7 +480,7 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
|
||||
try:
|
||||
await proc.wait()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Got cancellation request for SGLang server")
|
||||
logger.info("Got cancellation request for server")
|
||||
proc.terminate()
|
||||
raise
|
||||
|
||||
@ -493,28 +488,26 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
|
||||
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
|
||||
|
||||
|
||||
async def sglang_server_host(model_name_or_path, args, semaphore):
|
||||
async def server_host(model_name_or_path, args, semaphore):
|
||||
MAX_RETRIES = 5
|
||||
retry = 0
|
||||
|
||||
await asyncio.sleep(1000000)
|
||||
|
||||
while retry < MAX_RETRIES:
|
||||
await sglang_server_task(model_name_or_path, args, semaphore)
|
||||
logger.warning("SGLang server task ended")
|
||||
await server_task(model_name_or_path, args, semaphore)
|
||||
logger.warning("Server task ended")
|
||||
retry += 1
|
||||
|
||||
if retry >= MAX_RETRIES:
|
||||
logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
|
||||
logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
|
||||
logger.error("")
|
||||
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
|
||||
logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
async def sglang_server_ready():
|
||||
async def check_server_ready():
|
||||
max_attempts = 300
|
||||
delay_sec = 1
|
||||
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"
|
||||
url = f"http://localhost:{SERVER_PORT}/v1/models"
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
@ -522,16 +515,16 @@ async def sglang_server_ready():
|
||||
response = await session.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info("sglang server is ready.")
|
||||
logger.info("server is ready.")
|
||||
return
|
||||
else:
|
||||
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
|
||||
except Exception:
|
||||
logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")
|
||||
logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")
|
||||
|
||||
await asyncio.sleep(delay_sec)
|
||||
|
||||
raise Exception("sglang server did not become ready after waiting.")
|
||||
raise Exception("model server did not become ready after waiting.")
|
||||
|
||||
|
||||
async def download_model(model_name_or_path: str):
|
||||
@ -669,7 +662,8 @@ async def main():
|
||||
parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
|
||||
parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
|
||||
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
|
||||
parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")
|
||||
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
|
||||
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
|
||||
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
|
||||
|
||||
# Beaker/job running stuff
|
||||
@ -683,17 +677,17 @@ async def main():
|
||||
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
|
||||
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
|
||||
|
||||
parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
|
||||
parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
|
||||
args = parser.parse_args()
|
||||
|
||||
global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
|
||||
SGLANG_SERVER_PORT = args.port
|
||||
global SERVER_PORT, workspace_s3, dataset_s3
|
||||
SERVER_PORT = args.port
|
||||
workspace_s3 = boto3.client("s3")
|
||||
dataset_s3 = boto3.client("s3")
|
||||
|
||||
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
|
||||
if "BEAKER_JOB_ID" in os.environ:
|
||||
sglang_logger.addHandler(console_handler)
|
||||
server_logger.addHandler(console_handler)
|
||||
if "AWS_CREDENTIALS_FILE" in os.environ:
|
||||
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
|
||||
@ -742,7 +736,6 @@ async def main():
|
||||
return
|
||||
|
||||
# If you get this far, then you are doing inference and need a GPU
|
||||
check_sglang_version()
|
||||
check_torch_gpu_available()
|
||||
|
||||
logger.info(f"Starting pipeline with PID {os.getpid()}")
|
||||
@ -763,9 +756,9 @@ async def main():
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore))
|
||||
model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))
|
||||
|
||||
await sglang_server_ready()
|
||||
await check_server_ready()
|
||||
|
||||
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
|
||||
|
||||
@ -778,7 +771,7 @@ async def main():
|
||||
# Wait for all worker tasks to finish
|
||||
await asyncio.gather(*worker_tasks)
|
||||
|
||||
sglang_server.cancel()
|
||||
model_server.cancel()
|
||||
metrics_task.cancel()
|
||||
logger.info("Work done")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user