Small corrections

This commit is contained in:
Jake Poznanski 2025-04-24 20:31:59 +00:00
parent df71dc38ce
commit 2d5e1838f4

View File

@ -31,7 +31,7 @@ from olmocr.check import (
check_sglang_version,
check_torch_gpu_available,
)
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.metrics import MetricsKeeper
from olmocr.s3_utils import (
download_directory,
expand_s3_glob,
@ -68,7 +68,6 @@ SGLANG_SERVER_PORT = 30024
# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()
class PIIClassification(BaseModel):
@ -79,8 +78,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
"""Helper function to process a single document or page."""
text = page_text
# Count the attempt up-front
metrics.add_metrics(sglang_documents=1)
metrics.add_metrics(sglang_requests=1)
query = {
"model": "google/gemma-3-4b-it",
@ -330,7 +328,6 @@ async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, work
file_uri = work_item.work_paths[0]
logger.info(f"Worker {worker_id} processing {file_uri}")
await tracker.clear_work(worker_id)
try:
# ------------------------------------------------------------------
@ -551,7 +548,6 @@ async def metrics_reporter(work_queue):
# Leading newlines preserve table formatting in logs
logger.info(f"Queue remaining: {work_queue.size}")
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10)
@ -692,15 +688,17 @@ async def main():
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
if "BEAKER_JOB_ID" in os.environ:
sglang_logger.addHandler(console_handler)
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
if "AWS_CREDENTIALS_FILE" in os.environ:
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
workspace_s3 = boto3.client("s3")
# Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time