mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-27 15:14:43 +00:00
Minor fixes
This commit is contained in:
parent
9ff107b7b5
commit
da1b23fc47
@ -7,6 +7,7 @@ import sys
|
||||
import time
|
||||
import subprocess
|
||||
import hashlib
|
||||
import json
|
||||
import base64
|
||||
import atexit
|
||||
import asyncio
|
||||
@ -205,7 +206,7 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
|
||||
|
||||
try:
|
||||
base_response_data = await response.json()
|
||||
model_response_json = orjson.loads(base_response_data["outputs"][0]["text"])
|
||||
model_response_json = json.loads(base_response_data["outputs"][0]["text"])
|
||||
page_response = PageResponse(**model_response_json)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
|
||||
@ -239,6 +240,7 @@ async def process_pdf(args, pdf_s3_path):
|
||||
# If we failed to build a page, then this document is toast
|
||||
# TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
|
||||
if any(page is None for page in page_results):
|
||||
logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
|
||||
return None
|
||||
|
||||
# Build the document text and page spans
|
||||
@ -305,6 +307,17 @@ async def sglang_server_task(args):
|
||||
# TODO cache locally
|
||||
#download_directory(args.model, model_cache_dir)
|
||||
|
||||
# Check the rope config and make sure it's got the proper key
|
||||
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
|
||||
config_data = json.load(cfin)
|
||||
|
||||
if "rope_type" in config_data["rope_scaling"]:
|
||||
del config_data["rope_scaling"]["rope_type"]
|
||||
config_data["rope_scaling"]["type"] = "mrope"
|
||||
|
||||
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
|
||||
json.dump(config_data, cfout)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3",
|
||||
|
||||
@ -315,7 +328,12 @@ async def sglang_server_task(args):
|
||||
)
|
||||
|
||||
# Make really sure we kill this subprocess on exit
|
||||
atexit.register(lambda: proc.kill())
|
||||
def _kill_proc():
|
||||
proc.terminate()
|
||||
time.sleep(3)
|
||||
proc.kill()
|
||||
|
||||
atexit.register(_kill_proc)
|
||||
|
||||
await proc.wait()
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from google.auth import compute_engine
|
||||
from google.cloud import storage
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from boto3.s3.transfer import TransferConfig
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
import zstandard as zstd
|
||||
@ -209,7 +210,7 @@ def download_dir_from_gcs(gcs_path: str, local_dir: str):
|
||||
def download_dir_from_s3(s3_path: str, local_dir: str):
|
||||
"""Download model files from S3 to a local directory."""
|
||||
boto3_config = Config(
|
||||
max_pool_connections=50 # Adjust this number based on your requirements
|
||||
max_pool_connections=500 # Adjust this number based on your requirements
|
||||
)
|
||||
s3_client = boto3.client('s3', config=boto3_config)
|
||||
bucket, prefix = parse_s3_path(s3_path)
|
||||
@ -251,10 +252,18 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
|
||||
# Configure the boto3 client for Weka
|
||||
weka_endpoint = "https://weka-aus.beaker.org:9000"
|
||||
boto3_config = Config(
|
||||
max_pool_connections=50, # Adjust this number based on your requirements
|
||||
max_pool_connections=500, # Adjust this number based on your requirements
|
||||
signature_version='s3v4',
|
||||
retries={'max_attempts': 10, 'mode': 'standard'}
|
||||
)
|
||||
# Configure transfer settings for multipart download
|
||||
transfer_config = TransferConfig(
|
||||
multipart_threshold=8 * 1024 * 1024, # 8MB threshold for multipart downloads
|
||||
multipart_chunksize=8 * 1024 * 1024, # 8MB per part
|
||||
max_concurrency=100, # Number of threads for each file download
|
||||
use_threads=True # Enable threading
|
||||
)
|
||||
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
endpoint_url=weka_endpoint,
|
||||
@ -263,6 +272,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
|
||||
config=boto3_config
|
||||
)
|
||||
|
||||
|
||||
bucket, prefix = parse_s3_path(weka_path)
|
||||
paginator = s3_client.get_paginator("list_objects_v2")
|
||||
try:
|
||||
@ -285,7 +295,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
|
||||
relative_path = os.path.relpath(key, prefix)
|
||||
local_file_path = os.path.join(local_dir, relative_path)
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
|
||||
futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path, Config=transfer_config))
|
||||
|
||||
# Use tqdm to display progress
|
||||
for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user