Better check for separate sglang installation step

This commit is contained in:
Jake Poznanski 2025-01-28 13:56:00 -08:00
parent 00e3aac058
commit e0afb935fa
4 changed files with 21 additions and 8 deletions

2
.gitignore vendored
View File

@ -15,7 +15,7 @@ scoreelo.csv
debug.log
birrpipeline-debug.log
beakerpipeline-debug.log
olmocr-pipeline-debug.log
# build artifacts

View File

@ -1,6 +1,7 @@
import sys
import subprocess
import logging
import importlib.util
logger = logging.getLogger(__name__)
@ -17,5 +18,13 @@ def check_poppler_version():
logger.error("Check the README in the https://github.com/allenai/olmocr/blob/main/README.md for installation instructions")
sys.exit(1)
def check_sglang_version():
    """Abort the process when the ``sglang`` package cannot be located.

    Only the presence of the package is verified, using
    ``importlib.util.find_spec``; despite the function name, no version
    comparison is performed.

    Raises:
        SystemExit: With exit status 1 when no module spec for ``sglang``
            is found.
    """
    if importlib.util.find_spec("sglang") is not None:
        return

    # Nothing is interpolated into these messages, so plain string literals
    # are used rather than placeholder-less f-strings.
    logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
    logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.")
    sys.exit(1)
if __name__ == "__main__":
check_poppler_version()
check_poppler_version()
check_sglang_version()

View File

@ -38,7 +38,7 @@ from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import PdfFilter, Language
from olmocr.prompts import build_finetuning_prompt, PageResponse
from olmocr.prompts.anchor import get_anchor_text
from olmocr.check import check_poppler_version
from olmocr.check import check_poppler_version, check_sglang_version
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.version import VERSION
@ -50,7 +50,7 @@ logger.propagate = False
sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False
file_handler = logging.FileHandler('beakerpipeline-debug.log', mode='a')
file_handler = logging.FileHandler('olmocr-pipeline-debug.log', mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
@ -466,6 +466,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
async def sglang_server_task(args, semaphore):
model_name_or_path = args.model
# if "://" in model_name_or_path:
# # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
# model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
@ -589,7 +590,9 @@ async def sglang_server_host(args, semaphore):
retry += 1
if retry >= MAX_RETRIES:
logger.error(f"Ended up restarting the sglang server more than {retry} times, cancelling")
logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
logger.error(f"")
logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
sys.exit(1)
@ -892,6 +895,7 @@ async def main():
pdf_s3 = pdf_session.client("s3")
check_poppler_version()
check_sglang_version()
# Create work queue
if args.workspace.startswith("s3://"):

View File

@ -17,7 +17,7 @@ classifiers = [
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"}
]
requires-python = ">=3.8"
requires-python = ">=3.11"
dependencies = [
"cached-path",
"smart_open",
@ -35,9 +35,9 @@ dependencies = [
"requests",
"zstandard",
"boto3",
"torch==2.5.1",
"httpx",
"torch>=2.5.1",
"transformers>=4.46.2",
"sglang[all]==0.4.1",
"beaker-py",
]
license = {file = "LICENSE"}