Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-16 18:52:50 +00:00)
Adding some small changes to the tagging pipeline
commit f8808478bd (parent 66d293c178)
scripts/beaker/Dockerfile-tagging (new file, 48 lines)
@@ -0,0 +1,48 @@
+FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
+
+RUN apt-get update -y && apt-get install -y software-properties-common \
+    && add-apt-repository ppa:deadsnakes/ppa \
+    && apt-get -y update
+
+# Install requirements specific to pdfs
+RUN apt-get update && apt-get -y install python3-apt
+RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
+RUN apt-get update -y && apt-get install -y poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends \
+    git \
+    python3.11 \
+    python3.11-dev \
+    python3.11-distutils \
+    ca-certificates \
+    build-essential \
+    curl \
+    unzip
+
+RUN rm -rf /var/lib/apt/lists/* \
+    && unlink /usr/bin/python3 \
+    && ln -s /usr/bin/python3.11 /usr/bin/python3 \
+    && ln -s /usr/bin/python3 /usr/bin/python \
+    && curl -sS https://bootstrap.pypa.io/get-pip.py | python \
+    && pip3 install -U pip
+
+RUN apt-get update && apt-get -y install python3.11-venv
+ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh
+RUN /install.sh && rm /install.sh
+
+ENV PYTHONUNBUFFERED=1
+WORKDIR /root
+COPY pyproject.toml pyproject.toml
+COPY olmocr/version.py olmocr/version.py
+
+RUN /root/.local/bin/uv pip install --system --no-cache -e .
+
+RUN /root/.local/bin/uv pip install --system --no-cache vllm==0.8.2
+
+COPY olmocr olmocr
+
+WORKDIR /root
+COPY olmocr olmocr
+
+RUN python3 -m vllm --help
+RUN python3 -m olmocr.pipeline --help
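The last two RUN lines double as build-time smoke tests for the two entry points the image needs. A rough runtime equivalent, as a hypothetical check script (not part of the repo):

```python
# Hypothetical smoke test mirroring the final RUN checks: the image
# should ship Python 3.11 plus importable vllm and olmocr packages.
import importlib
import sys

assert sys.version_info[:2] == (3, 11), sys.version
for mod in ("vllm", "olmocr.pipeline"):
    importlib.import_module(mod)
    print(f"{mod}: importable")
```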
@@ -3,7 +3,7 @@
 Tagging pipeline for Dolma JSONL datasets.

 For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
-this script issues a simple SGLang completion per record (e.g., "Is this document in English?"),
+this script issues a model prompt completion
 collects the yes/no answers, and writes corresponding Dolma attributes JSONL files under
 scratch/attributes/, mirroring the input structure.
 """
@@ -28,7 +28,6 @@ from huggingface_hub import snapshot_download
 from pydantic import BaseModel, Field, ValidationError

 from olmocr.check import (
-    check_sglang_version,
     check_torch_gpu_available,
 )
 from olmocr.metrics import MetricsKeeper
@@ -46,8 +45,8 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 logger.propagate = False

-sglang_logger = logging.getLogger("sglang")
-sglang_logger.propagate = False
+server_logger = logging.getLogger("vllm")
+server_logger.propagate = False

 file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
 file_handler.setLevel(logging.DEBUG)
@@ -60,11 +59,11 @@ console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(level
 # Add handlers to the logger
 logger.addHandler(file_handler)
 logger.addHandler(console_handler)
-sglang_logger.addHandler(file_handler)
+server_logger.addHandler(file_handler)


 # Default port; overridden by --port
-SGLANG_SERVER_PORT = 30024
+SERVER_PORT = 30024

 # Global variables for token statistics
 metrics = MetricsKeeper(window=60 * 5)
@@ -81,8 +80,6 @@ async def _process_single_page(page_text: str) -> PIIClassification:
     """Helper function to process a single document or page."""
     text = page_text

-    metrics.add_metrics(sglang_requests=1)
-
     query = {
         "model": "google/gemma-3-4b-it",
         "messages": [
@@ -104,47 +101,49 @@ async def _process_single_page(page_text: str) -> PIIClassification:
         "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
     }

-    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
+    url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"

     # ---------- HTTP call ---------------------------------------------------
     try:
         status, body = await apost(url, json_data=query)
     except Exception as e:
-        logger.warning(f"SGLang network error: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
+        logger.warning(f"Server network error: {e!s}")
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

+    metrics.add_metrics(server_requests=1)
+
     if status != 200:
-        logger.warning(f"SGLang HTTP {status}: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
+        logger.warning(f"Server HTTP {status}: {body[:250]!r}")
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

     # ---------- Parse base JSON --------------------------------------------
     try:
         base = json.loads(body)
     except json.JSONDecodeError:
-        logger.warning(f"SGLang response is not valid JSON: {body[:250]!r}")
-        metrics.add_metrics(sglang_errors=1)
+        logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

     # Token accounting if available
     if "usage" in base:
         metrics.add_metrics(
-            sglang_input_tokens=base["usage"].get("prompt_tokens", 0),
-            sglang_output_tokens=base["usage"].get("completion_tokens", 0),
+            server_input_tokens=base["usage"].get("prompt_tokens", 0),
+            server_output_tokens=base["usage"].get("completion_tokens", 0),
         )

     # ---------- Extract the model message ----------------------------------
     try:
         content = base["choices"][0]["message"].get("content")
     except (KeyError, IndexError, AttributeError) as e:
-        logger.warning(f"Missing fields in SGLang response: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
+        logger.warning(f"Missing fields in Server response: {e!s}")
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

     if not isinstance(content, str):
-        logger.warning("SGLang `content` is not a string; treating as error.")
-        metrics.add_metrics(sglang_errors=1)
+        logger.warning("Server `content` is not a string; treating as error.")
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)

     try:
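PIIClassification itself is defined outside this hunk. Judging from the constructor calls and the response_format above, a minimal sketch of the model might look like this (field names taken from the diff; everything else is an assumption):

```python
from typing import Optional

from pydantic import BaseModel


class PIIClassification(BaseModel):
    # Fields inferred from the calls in this diff; the real olmocr model
    # may constrain these further (enums, descriptions, validators).
    primary_language: str
    document_type: str
    is_resume_cv: Optional[bool]
    contains_pii: Optional[bool]
```

Its model_json_schema() output is what the request's response_format sends, so the server is forced to return schema-conformant JSON.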
@@ -152,7 +151,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
         return pii_classification
     except ValidationError as e:
         logger.warning(f"Unable to parse pii classification object: {e!s}")
-        metrics.add_metrics(sglang_errors=1)
+        metrics.add_metrics(server_errors=1)
         return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)


@@ -223,7 +222,7 @@ async def apost(url, json_data):

 async def process_dolma_document(args, dolma_doc, sem):
     """
-    Query SGLang to detect PII, enforcing a JSON schema.
+    Query model to detect PII, enforcing a JSON schema.

     Resilient to:
       • Transport / HTTP errors
@@ -236,9 +235,10 @@ async def process_dolma_document(args, dolma_doc, sem):
     doc_id = dolma_doc.get("id")
     text = dolma_doc.get("text", "") or ""

-    key_name = f"{args.model.replace('/', '_')}_pii_classification"
+    language_key_name = f"{args.model.replace('/', '_')}_language"
+    resume_cv_key_name = f"{args.model.replace('/', '_')}_is_resume_cv"

-    result_attributes = {key_name: []}
+    result_attributes = {resume_cv_key_name: [], language_key_name: []}

     # If pdf_page_numbers is present, split the text and process each page separately
     if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
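With the default --model google/gemma-3-4b-it, the two keys come out as google_gemma-3-4b-it_is_resume_cv and google_gemma-3-4b-it_language. An illustrative (made-up) payload for a two-page document, in the [start, end, value] span format built below:

```python
# Each entry spans a character range of the document text.
result_attributes = {
    "google_gemma-3-4b-it_is_resume_cv": [[0, 1532, False], [1532, 2997, None]],
    "google_gemma-3-4b-it_language": [[0, 1532, "en"], [1532, 2997, None]],
}
```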
@@ -248,11 +248,15 @@ async def process_dolma_document(args, dolma_doc, sem):

         # Filter pages down to actual real content
         selected_page_numbers = [tuple(p) for p in page_numbers if p[0] < p[1]]
+        first_page_number = selected_page_numbers[0]

-        # Sample 3 pages max per document
+        # Sample 3 pages max per document, but always include the first page, it's a good signal for CV classification
         random.shuffle(selected_page_numbers)
         selected_page_numbers = selected_page_numbers[:3]

+        if first_page_number not in selected_page_numbers:
+            selected_page_numbers[0] = first_page_number
+
         for start_pos, end_pos, page_num in page_numbers:
             if (start_pos, end_pos, page_num) in selected_page_numbers:
                 page_text = text[start_pos:end_pos]
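The new sampling keeps at most three non-empty page spans but pins the first one, on the theory that page one is the strongest resume/CV signal. The same logic as a standalone sketch (which, like the diff, assumes at least one span survives the p[0] < p[1] filter):

```python
import random


def sample_page_spans(page_numbers, k=3):
    # Keep spans with real content, sample up to k of them, and make sure
    # the document's first content span is always in the sample.
    spans = [tuple(p) for p in page_numbers if p[0] < p[1]]
    first = spans[0]  # assumes at least one non-empty span, as in the diff
    random.shuffle(spans)
    sample = spans[:k]
    if first not in sample:
        sample[0] = first
    return sample
```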
@@ -261,9 +265,11 @@ async def process_dolma_document(args, dolma_doc, sem):
                 async with sem:
                     pii_class = await _process_single_page(page_text)

-                result_attributes[key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
+                result_attributes[resume_cv_key_name].append([start_pos, end_pos, pii_class.is_resume_cv])
+                result_attributes[language_key_name].append([start_pos, end_pos, pii_class.primary_language])
             else:
-                result_attributes[key_name].append([start_pos, end_pos, None])
+                result_attributes[resume_cv_key_name].append([start_pos, end_pos, None])
+                result_attributes[language_key_name].append([start_pos, end_pos, None])

         return result_attributes
     else:
@@ -272,7 +278,7 @@ async def process_dolma_document(args, dolma_doc, sem):

 async def process_file(args, worker_id: int, file_uri: str):
     """
-    Download a JSONL file, query SGLang per record, and collect attributes.
+    Download a JSONL file, query model per record, and collect attributes.
     """
     # Fetch raw bytes (S3 or local)
     if file_uri.startswith("s3://"):
@@ -293,8 +299,8 @@ async def process_file(args, worker_id: int, file_uri: str):
     lines = file_bytes.decode("utf-8").splitlines()
     page_tasks = {}

-    # Send all records in parallel, max 500 queued at a time
-    sem = asyncio.Semaphore(500)
+    # Send all records in parallel, max N queued at a time
+    sem = asyncio.Semaphore(args.parallel_requests)

     async with asyncio.TaskGroup() as tg:
         for line in lines:
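The TaskGroup still fans out one task per record; the semaphore, acquired inside process_dolma_document, caps how many model requests are in flight at once. The pattern in isolation:

```python
import asyncio


async def fan_out_bounded(coros, limit: int):
    # TaskGroup starts every task eagerly; the semaphore caps how many
    # may run their critical section concurrently.
    sem = asyncio.Semaphore(limit)

    async def bounded(coro):
        async with sem:
            return await coro

    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(bounded(c)) for c in coros]
    return [t.result() for t in tasks]
```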
@@ -302,7 +308,7 @@ async def process_file(args, worker_id: int, file_uri: str):
             task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
             page_tasks[dolma_doc["id"]] = (task, dolma_doc)

-    logger.info(f"Started taskgroup with {len(page_tasks)} items for {file_uri}")
+    logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")

     # Collect results and build attributes
     attributes = []
@@ -389,21 +395,19 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
         semaphore.release()


-async def sglang_server_task(model_name_or_path, args, semaphore):
+async def server_task(model_name_or_path, args, semaphore):
     # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
     # mem_fraction_arg = ["--mem-fraction-static", "0.80"]

     cmd = [
-        "python3",
-        "-m",
-        "sglang.launch_server",
-        "--model-path",
+        "vllm",
+        "serve",
         model_name_or_path,
         "--port",
-        str(SGLANG_SERVER_PORT),
-        "--log-level-http",
+        str(SERVER_PORT),
+        "--uvicorn-log-level",
         "warning",
-        "--mem-fraction-static", "0.40"
+        "--disable-log-requests",
     ]

     proc = await asyncio.create_subprocess_exec(
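For the defaults, the assembled command is equivalent to `vllm serve google/gemma-3-4b-it --port 30024 --uvicorn-log-level warning --disable-log-requests`, replacing the former `python3 -m sglang.launch_server` invocation. The sglang-specific `--mem-fraction-static` flag is dropped rather than translated; vLLM's rough analogue would be `--gpu-memory-utilization`, which this commit leaves at its default.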
@@ -425,34 +429,25 @@ async def sglang_server_task(model_name_or_path, args, semaphore):

     async def process_line(line):
         nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
-        sglang_logger.info(line)
+        server_logger.info(line)

         # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
         # can see any warnings/errors more easily
         if not server_printed_ready_message:
             logger.info(line)

-        if "Detected errors during sampling" in line:
-            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
-            sys.exit(1)
-
-        # TODO, need to trace down this issue in sglang itself, but it will otherwise cause the server to lock up
-        if "IndexError: list index out of range" in line:
-            logger.error("IndexError in model, restarting server")
-            proc.terminate()
-
         if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
             server_printed_ready_message = True
             last_semaphore_release = time.time()

-        match = re.search(r"#running-req: (\d+)", line)
+        match = re.search(r"Running: (\d+) reqs", line)
         if match:
             last_running_req = int(match.group(1))

-        match = re.search(r"#queue-req: (\d+)", line)
+        match = re.search(r"Waiting: (\d+) reqs", line)
         if match:
             last_queue_req = int(match.group(1))
-            logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")
+            logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")

     async def read_stream(stream):
         while True:
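The two regexes now track vLLM's periodic stats log instead of sglang's #running-req/#queue-req counters. Assuming a stats line roughly shaped like the ones vLLM 0.8.x emits (illustrative; the exact prefix varies by version), the parse works like:

```python
import re

# Illustrative vLLM stats line, not a verbatim log capture:
line = "INFO ... Running: 7 reqs, Waiting: 2 reqs, GPU KV cache usage: 41.0%"

running = re.search(r"Running: (\d+) reqs", line)
waiting = re.search(r"Waiting: (\d+) reqs", line)
if running and waiting:
    print(int(running.group(1)), int(waiting.group(1)))  # -> 7 2
```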
@@ -485,7 +480,7 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
     try:
         await proc.wait()
     except asyncio.CancelledError:
-        logger.info("Got cancellation request for SGLang server")
+        logger.info("Got cancellation request for server")
         proc.terminate()
         raise

@@ -493,28 +488,26 @@ async def sglang_server_task(model_name_or_path, args, semaphore):
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


-async def sglang_server_host(model_name_or_path, args, semaphore):
+async def server_host(model_name_or_path, args, semaphore):
     MAX_RETRIES = 5
     retry = 0

-    await asyncio.sleep(1000000)
-
     while retry < MAX_RETRIES:
-        await sglang_server_task(model_name_or_path, args, semaphore)
-        logger.warning("SGLang server task ended")
+        await server_task(model_name_or_path, args, semaphore)
+        logger.warning("Server task ended")
         retry += 1

     if retry >= MAX_RETRIES:
-        logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
+        logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
         logger.error("")
-        logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
+        logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
         sys.exit(1)


-async def sglang_server_ready():
+async def check_server_ready():
     max_attempts = 300
     delay_sec = 1
-    url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"
+    url = f"http://localhost:{SERVER_PORT}/v1/models"

     for attempt in range(1, max_attempts + 1):
         try:
@@ -522,16 +515,16 @@ async def sglang_server_ready():
             response = await session.get(url)

             if response.status_code == 200:
-                logger.info("sglang server is ready.")
+                logger.info("server is ready.")
                 return
             else:
                 logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
         except Exception:
-            logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")
+            logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")

         await asyncio.sleep(delay_sec)

-    raise Exception("sglang server did not become ready after waiting.")
+    raise Exception("model server did not become ready after waiting.")


 async def download_model(model_name_or_path: str):
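The poll relies on an async HTTP client whose responses expose .status_code (httpx-style). The loop in isolation, as a sketch assuming httpx:

```python
import asyncio

import httpx


async def wait_until_ready(url: str, max_attempts: int = 300, delay_sec: float = 1.0) -> None:
    # Poll the OpenAI-compatible /v1/models endpoint until it answers 200.
    async with httpx.AsyncClient() as session:
        for attempt in range(1, max_attempts + 1):
            try:
                response = await session.get(url)
                if response.status_code == 200:
                    return
            except Exception:
                pass  # server not accepting connections yet
            await asyncio.sleep(delay_sec)
    raise Exception("model server did not become ready after waiting.")
```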
@@ -669,7 +662,8 @@ async def main():
     parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
     parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
     parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
-    parser.add_argument("--model", default="google/gemma-3-4b-it", help="SGLang model path or name")
+    parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
+    parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
     parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")

     # Beaker/job running stuff
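With the new flag, a typical invocation would look something like `python scripts/tag_pipeline.py s3://bucket/dataset s3://bucket/scratch --workers 4 --parallel_requests 800 --model google/gemma-3-4b-it` (the script path here is illustrative; the diff does not show the file name).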
@@ -683,17 +677,17 @@ async def main():
     parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

-    parser.add_argument("--port", type=int, default=30024, help="Port for SGLang server")
+    parser.add_argument("--port", type=int, default=30024, help="Port for Model server")
     args = parser.parse_args()

-    global SGLANG_SERVER_PORT, workspace_s3, dataset_s3
-    SGLANG_SERVER_PORT = args.port
+    global SERVER_PORT, workspace_s3, dataset_s3
+    SERVER_PORT = args.port
     workspace_s3 = boto3.client("s3")
     dataset_s3 = boto3.client("s3")

     # setup the job to work in beaker environment, load secrets, adjust logging, etc.
     if "BEAKER_JOB_ID" in os.environ:
-        sglang_logger.addHandler(console_handler)
+        server_logger.addHandler(console_handler)
         if "AWS_CREDENTIALS_FILE" in os.environ:
             cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
             os.makedirs(os.path.dirname(cred_path), exist_ok=True)
@@ -742,7 +736,6 @@ async def main():
         return

     # If you get this far, then you are doing inference and need a GPU
-    check_sglang_version()
     check_torch_gpu_available()

     logger.info(f"Starting pipeline with PID {os.getpid()}")
@@ -763,9 +756,9 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)

-    sglang_server = asyncio.create_task(sglang_server_host(model_name_or_path, args, semaphore))
+    model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))

-    await sglang_server_ready()
+    await check_server_ready()

     metrics_task = asyncio.create_task(metrics_reporter(work_queue))

@@ -778,7 +771,7 @@ async def main():
     # Wait for all worker tasks to finish
     await asyncio.gather(*worker_tasks)

-    sglang_server.cancel()
+    model_server.cancel()
     metrics_task.cancel()
     logger.info("Work done")
