From e0afb935fab4be61ab5a0d59de5a2859ddb893ab Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Tue, 28 Jan 2025 13:56:00 -0800 Subject: [PATCH] Better check for separate sglang installation step --- .gitignore | 2 +- olmocr/check.py | 11 ++++++++++- olmocr/pipeline.py | 10 +++++++--- pyproject.toml | 6 +++--- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 3eec6d3..d38a9b9 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,7 @@ scoreelo.csv debug.log birrpipeline-debug.log beakerpipeline-debug.log - +olmocr-pipeline-debug.log # build artifacts diff --git a/olmocr/check.py b/olmocr/check.py index e08a57c..0abdf2e 100644 --- a/olmocr/check.py +++ b/olmocr/check.py @@ -1,6 +1,7 @@ import sys import subprocess import logging +import importlib.util logger = logging.getLogger(__name__) @@ -17,5 +18,13 @@ def check_poppler_version(): logger.error("Check the README in the https://github.com/allenai/olmocr/blob/main/README.md for installation instructions") sys.exit(1) +def check_sglang_version(): + if importlib.util.find_spec("sglang") is None: + logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html") + logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.") + sys.exit(1) + + if __name__ == "__main__": - check_poppler_version() \ No newline at end of file + check_poppler_version() + check_sglang_version() \ No newline at end of file diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index 11c8388..e4d1966 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -38,7 +38,7 @@ from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.filter.filter import PdfFilter, Language from olmocr.prompts import build_finetuning_prompt, PageResponse from olmocr.prompts.anchor import get_anchor_text -from olmocr.check import check_poppler_version +from olmocr.check import check_poppler_version, check_sglang_version from olmocr.metrics import MetricsKeeper, WorkerTracker from olmocr.version import VERSION @@ -50,7 +50,7 @@ logger.propagate = False sglang_logger = logging.getLogger("sglang") sglang_logger.propagate = False -file_handler = logging.FileHandler('beakerpipeline-debug.log', mode='a') +file_handler = logging.FileHandler('olmocr-pipeline-debug.log', mode='a') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) @@ -466,6 +466,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id): async def sglang_server_task(args, semaphore): model_name_or_path = args.model + # if "://" in model_name_or_path: # # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it # model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model') @@ -589,7 +590,9 @@ async def sglang_server_host(args, semaphore): retry += 1 if retry >= MAX_RETRIES: - logger.error(f"Ended up restarting the sglang server more than {retry} times, cancelling") + logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline") + logger.error(f"") + logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html") sys.exit(1) @@ -892,6 +895,7 @@ async def main(): pdf_s3 = pdf_session.client("s3") check_poppler_version() + check_sglang_version() # Create work queue if args.workspace.startswith("s3://"): diff --git a/pyproject.toml b/pyproject.toml index a78a89e..ddd9749 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ authors = [ {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"} ] -requires-python = ">=3.8" +requires-python = ">=3.11" dependencies = [ "cached-path", "smart_open", @@ -35,9 +35,9 @@ dependencies = [ "requests", "zstandard", "boto3", - "torch==2.5.1", + "httpx", + "torch>=2.5.1", "transformers>=4.46.2", - "sglang[all]==0.4.1", "beaker-py", ] license = {file = "LICENSE"}