Minor fixes

Jake Poznanski 2024-11-11 10:24:47 -08:00
parent 9ff107b7b5
commit da1b23fc47
2 changed files with 33 additions and 5 deletions

@@ -7,6 +7,7 @@ import sys
 import time
 import subprocess
 import hashlib
+import json
 import base64
 import atexit
 import asyncio
@@ -205,7 +206,7 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:
     try:
         base_response_data = await response.json()

-        model_response_json = orjson.loads(base_response_data["outputs"][0]["text"])
+        model_response_json = json.loads(base_response_data["outputs"][0]["text"])
         page_response = PageResponse(**model_response_json)
     except Exception as e:
         logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
@@ -239,6 +240,7 @@ async def process_pdf(args, pdf_s3_path):
     # If we failed to build a page, then this document is toast
     # TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
     if any(page is None for page in page_results):
+        logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
         return None

     # Build the document text and page spans
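The TODO above could be closed by failing fast instead of checking only after every page has finished. A minimal sketch, assuming the pages are launched as asyncio tasks over the same process_page coroutine (the wrapper name and task layout are hypothetical, not part of this commit):

    # Hypothetical fail-fast page loop: cancel the remaining page tasks as
    # soon as any page comes back as None, rather than awaiting them all.
    async def process_pdf_fail_fast(session, pdf_path, num_pages, args):
        tasks = [asyncio.create_task(process_page(session, pdf_path, page_num, args))
                 for page_num in range(1, num_pages + 1)]
        for finished in asyncio.as_completed(list(tasks)):
            if await finished is None:
                for t in tasks:
                    t.cancel()  # abandon the remaining pages of this document
                return None
        # All pages parsed; the tasks are already done, so gather() just
        # collects the PageResponse objects in page order.
        return await asyncio.gather(*tasks)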
@@ -305,6 +307,17 @@ async def sglang_server_task(args):
     # TODO cache locally
     #download_directory(args.model, model_cache_dir)

+    # Check the rope config and make sure it's got the proper key
+    with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
+        config_data = json.load(cfin)
+
+    if "rope_type" in config_data["rope_scaling"]:
+        del config_data["rope_scaling"]["rope_type"]
+        config_data["rope_scaling"]["type"] = "mrope"
+
+        with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
+            json.dump(config_data, cfout)
+
     proc = await asyncio.create_subprocess_exec(
         "python3",
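For context (an illustration, not taken from the commit): recent transformers releases save rope settings under a "rope_type" key, while the sglang server launched here still reads the legacy "type" key for mrope models, so the patch above rewrites config.json roughly like this:

    # before: "rope_scaling": {"rope_type": ..., "mrope_section": [...]}
    # after:  "rope_scaling": {"type": "mrope", "mrope_section": [...]}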
@@ -315,7 +328,12 @@ async def sglang_server_task(args):
     )

     # Make really sure we kill this subprocess on exit
-    atexit.register(lambda: proc.kill())
+    def _kill_proc():
+        proc.terminate()
+        time.sleep(3)
+        proc.kill()
+
+    atexit.register(_kill_proc)

     await proc.wait()
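One sharp edge worth noting (an assumption about runtime behavior, not something this commit addresses): if the sglang server has already exited by the time the atexit hook runs, terminate() and kill() raise ProcessLookupError. A more defensive variant of the hook might look like:

    # Hypothetical hardened shutdown hook: SIGTERM first so the server can
    # shut down cleanly, a short grace period, then SIGKILL, tolerating a
    # process that is already gone.
    def _kill_proc():
        if proc.returncode is not None:
            return  # server already exited on its own
        try:
            proc.terminate()   # polite stop (SIGTERM)
            time.sleep(3)      # grace period for cleanup
            proc.kill()        # hard stop (SIGKILL)
        except ProcessLookupError:
            pass               # process exited between the two signals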

@@ -13,6 +13,7 @@ from google.auth import compute_engine
 from google.cloud import storage
 from botocore.config import Config
 from botocore.exceptions import NoCredentialsError
+from boto3.s3.transfer import TransferConfig
 from typing import Optional
 from urllib.parse import urlparse
 import zstandard as zstd
@@ -209,7 +210,7 @@ def download_dir_from_gcs(gcs_path: str, local_dir: str):
 def download_dir_from_s3(s3_path: str, local_dir: str):
     """Download model files from S3 to a local directory."""
     boto3_config = Config(
-        max_pool_connections=50  # Adjust this number based on your requirements
+        max_pool_connections=500  # Adjust this number based on your requirements
     )
     s3_client = boto3.client('s3', config=boto3_config)
     bucket, prefix = parse_s3_path(s3_path)
@@ -251,10 +252,18 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
     # Configure the boto3 client for Weka
     weka_endpoint = "https://weka-aus.beaker.org:9000"
     boto3_config = Config(
-        max_pool_connections=50,  # Adjust this number based on your requirements
+        max_pool_connections=500,  # Adjust this number based on your requirements
         signature_version='s3v4',
         retries={'max_attempts': 10, 'mode': 'standard'}
     )
+
+    # Configure transfer settings for multipart download
+    transfer_config = TransferConfig(
+        multipart_threshold=8 * 1024 * 1024,  # 8MB threshold for multipart downloads
+        multipart_chunksize=8 * 1024 * 1024,  # 8MB per part
+        max_concurrency=100,  # Number of threads for each file download
+        use_threads=True  # Enable threading
+    )
+
     s3_client = boto3.client(
         's3',
         endpoint_url=weka_endpoint,
@@ -263,6 +272,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
         config=boto3_config
     )
     bucket, prefix = parse_s3_path(weka_path)
+    paginator = s3_client.get_paginator("list_objects_v2")

     try:
@@ -285,7 +295,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
                 relative_path = os.path.relpath(key, prefix)
                 local_file_path = os.path.join(local_dir, relative_path)
                 os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-                futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+                futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path, Config=transfer_config))

         # Use tqdm to display progress
         for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
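Taken together, the Weka changes trade memory for throughput: each download_file call can now use up to 100 threads issuing ranged GETs for one file, so the client pool is raised to 500 connections to keep the ThreadPoolExecutor's parallel file downloads from starving each other. A self-contained sketch of the same pattern (bucket, key, and local path are placeholders, not values from the commit):

    import boto3
    from boto3.s3.transfer import TransferConfig
    from botocore.config import Config

    # Large pool: many files download at once, each with many part-threads.
    client = boto3.client("s3", config=Config(max_pool_connections=500))

    # Files above 8MB are fetched as concurrent 8MB ranged GETs.
    transfer_config = TransferConfig(
        multipart_threshold=8 * 1024 * 1024,
        multipart_chunksize=8 * 1024 * 1024,
        max_concurrency=100,
        use_threads=True,
    )

    client.download_file("example-bucket", "models/weights.bin",
                         "/tmp/weights.bin", Config=transfer_config)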