Mirror of https://github.com/allenai/olmocr.git (synced 2025-12-13 16:21:16 +00:00)

Downloads from s3 based on hash
commit 910c2ebcfc (parent 6598e2dc45)
@@ -398,8 +398,7 @@ async def worker(args, queue, semaphore, worker_id):
 
 async def sglang_server_task(args, semaphore):
     model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    # TODO cache locally
-    #download_directory(args.model, model_cache_dir)
+    download_directory(args.model, model_cache_dir)
 
     # Check the rope config and make sure it's got the proper key
     with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
@@ -484,6 +483,7 @@ async def sglang_server_task(args, semaphore):
 async def sglang_server_host(args, semaphore):
     while True:
         await sglang_server_task(args, semaphore)
+        logger.warning("SGLang server task ended")
 
 
 async def sglang_server_ready():
@@ -525,7 +525,7 @@ async def main():
     parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
     parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
-    parser.add_argument('--workers', type=int, default=3, help='Number of workers to run at a time')
+    parser.add_argument('--workers', type=int, default=5, help='Number of workers to run at a time')
 
     parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                         default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
@@ -6,6 +6,7 @@ import tempfile
 import boto3
 import requests
 import concurrent.futures
+import hashlib # Added for MD5 hash computation
 
 from urllib.parse import urlparse
 from pathlib import Path
@@ -14,7 +15,7 @@ from google.cloud import storage
 from botocore.config import Config
 from botocore.exceptions import NoCredentialsError
 from boto3.s3.transfer import TransferConfig
-from typing import Optional
+from typing import Optional, List
 from urllib.parse import urlparse
 import zstandard as zstd
 from io import BytesIO, TextIOWrapper
@@ -133,21 +134,19 @@ def is_running_on_gcp():
     except requests.RequestException:
         return False
 
-
-def download_directory(model_choices: list[str], local_dir: str):
+def download_directory(model_choices: List[str], local_dir: str):
     """
     Download the model to a specified local directory.
     The function will attempt to download from the first available source in the provided list.
     Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
 
     Args:
-        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
+        model_choices (List[str]): List of model paths (weka://, gs://, or s3://).
         local_dir (str): Local directory path where the model will be downloaded.
 
     Raises:
         ValueError: If no valid model path is found in the provided choices.
     """
-    # Ensure the local directory exists
     local_path = Path(os.path.expanduser(local_dir))
     local_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"Local directory set to: {local_path}")
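As the next hunk shows, `download_directory` keeps the contract described in the docstring above: sources are tried in priority order (weka:// paths first), the first one that succeeds wins, and a ValueError is raised only after every source has failed. A minimal usage sketch follows; the import path is an assumption (this view does not show the file name being modified), and the s3:// fallback is hypothetical, while the weka:// path and cache directory are the defaults visible in the first file's hunks.

import os

# Assumed module path; the diff does not show which file defines download_directory.
from pdelfin.s3_utils import download_directory

model_sources = [
    # Default weka:// source from the --model argument; tried first because Weka paths are prioritized.
    "weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
    # Hypothetical fallback; any gs:// or s3:// prefix is also accepted.
    "s3://example-bucket/models/qwen2-vl-7b/best_bf16/",
]

# Mirrors the call added in the first file: files already present locally with a matching
# hash (or size, for multipart uploads) are skipped on re-runs.
download_directory(model_sources, os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model'))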
@@ -157,148 +156,180 @@ def download_directory(model_choices: list[str], local_dir: str):
     other_choices = [path for path in model_choices if not path.startswith("weka://")]
     prioritized_choices = weka_choices + other_choices
 
-    # Iterate through the provided choices and attempt to download from the first available source
     for model_path in prioritized_choices:
         logger.info(f"Attempting to download from: {model_path}")
         try:
             if model_path.startswith("weka://"):
-                download_dir_from_weka(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='weka')
                 logger.info(f"Successfully downloaded model from Weka: {model_path}")
                 return
             elif model_path.startswith("gs://"):
-                download_dir_from_gcs(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='gcs')
                 logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                 return
             elif model_path.startswith("s3://"):
-                download_dir_from_s3(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='s3')
                 logger.info(f"Successfully downloaded model from S3: {model_path}")
                 return
             else:
                 logger.warning(f"Unsupported model path scheme: {model_path}")
         except Exception as e:
             logger.error(f"Failed to download from {model_path}: {e}")
-            continue # Try the next available source
+            continue
 
     raise ValueError("Failed to download the model from all provided sources.")
 
 
-def download_dir_from_gcs(gcs_path: str, local_dir: str):
-    """Download model files from Google Cloud Storage to a local directory."""
-    client = storage.Client()
-    bucket_name, prefix = parse_s3_path(gcs_path.replace("gs://", "s3://"))
-    bucket = client.bucket(bucket_name)
+def download_dir_from_storage(storage_path: str, local_dir: str, storage_type: str):
+    """
+    Generalized function to download model files from different storage services
+    to a local directory, syncing using MD5 hashes where possible.
+
+    Args:
+        storage_path (str): The path to the storage location (weka://, gs://, or s3://).
+        local_dir (str): The local directory where files will be downloaded.
+        storage_type (str): Type of storage ('weka', 'gcs', or 's3').
+
+    Raises:
+        ValueError: If the storage type is unsupported or credentials are missing.
+    """
+    bucket_name, prefix = parse_s3_path(storage_path)
+    total_files = 0
+    objects = []
+
+    if storage_type == 'gcs':
+        client = storage.Client()
+        bucket = client.bucket(bucket_name)
         blobs = list(bucket.list_blobs(prefix=prefix))
         total_files = len(blobs)
         logger.info(f"Found {total_files} files in GCS bucket '{bucket_name}' with prefix '{prefix}'.")
 
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for blob in blobs:
-            relative_path = os.path.relpath(blob.name, prefix)
-            local_file_path = os.path.join(local_dir, relative_path)
-            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(blob.download_to_filename, local_file_path))
-
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from GCS"):
-            pass
-
-    logger.info(f"Downloaded model from Google Cloud Storage to {local_dir}")
+        def should_download(blob, local_file_path):
+            return compare_hashes_gcs(blob, local_file_path)
+
+        def download_blob(blob, local_file_path):
+            blob.download_to_filename(local_file_path)
+
+        items = blobs
+    elif storage_type in ('s3', 'weka'):
+        if storage_type == 'weka':
 
-
-def download_dir_from_s3(s3_path: str, local_dir: str):
-    """Download model files from S3 to a local directory."""
-    boto3_config = Config(
-        max_pool_connections=500 # Adjust this number based on your requirements
-    )
-    s3_client = boto3.client('s3', config=boto3_config)
-    bucket, prefix = parse_s3_path(s3_path)
-    paginator = s3_client.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
-
-    objects = []
-    for page in pages:
-        if 'Contents' in page:
-            objects.extend(page['Contents'])
-
-    total_files = len(objects)
-    logger.info(f"Found {total_files} files in S3 bucket '{bucket}' with prefix '{prefix}'.")
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for obj in objects:
-            key = obj["Key"]
-            relative_path = os.path.relpath(key, prefix)
-            local_file_path = os.path.join(local_dir, relative_path)
-            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
-
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from S3"):
-            pass
-
-    logger.info(f"Downloaded model from S3 to {local_dir}")
-
-
-def download_dir_from_weka(weka_path: str, local_dir: str):
-    """Download model files from Weka to a local directory."""
-    # Retrieve Weka credentials from environment variables
             weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
             weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
             if not weka_access_key or not weka_secret_key:
-        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
+                raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY must be set for Weka access.")
 
-    # Configure the boto3 client for Weka
-    weka_endpoint = "https://weka-aus.beaker.org:9000"
+            endpoint_url = "https://weka-aus.beaker.org:9000"
             boto3_config = Config(
-        max_pool_connections=500, # Adjust this number based on your requirements
+                max_pool_connections=500,
                 signature_version='s3v4',
                 retries={'max_attempts': 10, 'mode': 'standard'}
             )
-    # Configure transfer settings for multipart download
-    transfer_config = TransferConfig(
-        multipart_threshold=8 * 1024 * 1024, # 8MB threshold for multipart downloads
-        multipart_chunksize=8 * 1024 * 1024, # 8MB per part
-        max_concurrency=100, # Number of threads for each file download
-        use_threads=True # Enable threading
-    )
 
             s3_client = boto3.client(
                 's3',
-        endpoint_url=weka_endpoint,
+                endpoint_url=endpoint_url,
                 aws_access_key_id=weka_access_key,
                 aws_secret_access_key=weka_secret_key,
                 config=boto3_config
             )
+        else:
+            s3_client = boto3.client('s3', config=Config(max_pool_connections=500))
 
-    bucket, prefix = parse_s3_path(weka_path)
         paginator = s3_client.get_paginator("list_objects_v2")
-    try:
-        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
-    except s3_client.exceptions.NoSuchBucket:
-        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
+        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
-    objects = []
         for page in pages:
             if 'Contents' in page:
                 objects.extend(page['Contents'])
 
         total_files = len(objects)
-    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
+        logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")
 
+        transfer_config = TransferConfig(
+            multipart_threshold=8 * 1024 * 1024,
+            multipart_chunksize=8 * 1024 * 1024,
+            max_concurrency=100,
+            use_threads=True
+        )
+
+        def should_download(obj, local_file_path):
+            return compare_hashes_s3(obj, local_file_path)
+
+        def download_blob(obj, local_file_path):
+            s3_client.download_file(bucket_name, obj['Key'], local_file_path, Config=transfer_config)
+
+        items = objects
+    else:
+        raise ValueError(f"Unsupported storage type: {storage_type}")
 
     with concurrent.futures.ThreadPoolExecutor() as executor:
         futures = []
-        for obj in objects:
-            key = obj["Key"]
-            relative_path = os.path.relpath(key, prefix)
+        for item in items:
+            if storage_type == 'gcs':
+                relative_path = os.path.relpath(item.name, prefix)
+            else:
+                relative_path = os.path.relpath(item['Key'], prefix)
             local_file_path = os.path.join(local_dir, relative_path)
             os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path, Config=transfer_config))
+            if should_download(item, local_file_path):
+                futures.append(executor.submit(download_blob, item, local_file_path))
+            else:
+                total_files -= 1 # Decrement total_files as we're skipping this file
 
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
-            pass
+        if total_files > 0:
+            for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
+                pass
+        else:
+            logger.info("All files are up-to-date. No downloads needed.")
 
-    logger.info(f"Downloaded model from Weka to {local_dir}")
+    logger.info(f"Downloaded model from {storage_type.upper()} to {local_dir}")
+
+
+def compare_hashes_gcs(blob, local_file_path: str) -> bool:
+    """Compare MD5 hashes for GCS blobs."""
+    if os.path.exists(local_file_path):
+        remote_md5_base64 = blob.md5_hash
+        hash_md5 = hashlib.md5()
+        with open(local_file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                hash_md5.update(chunk)
+        local_md5 = hash_md5.digest()
+        remote_md5 = base64.b64decode(remote_md5_base64)
+        if remote_md5 == local_md5:
+            logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+            return False
+        else:
+            logger.info(f"File '{local_file_path}' differs from GCS. Downloading.")
+            return True
+    else:
+        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
+        return True
+
+
+def compare_hashes_s3(obj, local_file_path: str) -> bool:
+    """Compare MD5 hashes or sizes for S3 objects (including Weka)."""
+    if os.path.exists(local_file_path):
+        etag = obj['ETag'].strip('"')
+        if '-' in etag:
+            remote_size = obj['Size']
+            local_size = os.path.getsize(local_file_path)
+            if remote_size == local_size:
+                logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
+                return False
+            else:
+                logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
+                return True
+        else:
+            hash_md5 = hashlib.md5()
+            with open(local_file_path, "rb") as f:
+                for chunk in iter(lambda: f.read(8192), b""):
+                    hash_md5.update(chunk)
+            local_md5 = hash_md5.hexdigest()
+            if etag == local_md5:
+                logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+                return False
+            else:
+                logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
+                return True
+    else:
+        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
+        return True
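For reference, the skip-or-download decision added in this commit can be exercised against a single S3 object in isolation. The sketch below follows the same rules as compare_hashes_s3: a missing local file is always downloaded, a plain ETag is compared against the local file's MD5, and a multipart ETag (one containing '-') falls back to a size comparison, since multipart ETags are not MD5 digests. The bucket, key, and local path are hypothetical placeholders.

import hashlib
import os

import boto3


def local_md5(path: str) -> str:
    # Hex MD5 of a local file, read in 8 KB chunks like the code in this commit.
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()


def needs_download(bucket: str, key: str, local_path: str) -> bool:
    # True if the local copy is missing or stale, using the commit's ETag/size rules.
    if not os.path.exists(local_path):
        return True
    head = boto3.client("s3").head_object(Bucket=bucket, Key=key)
    etag = head["ETag"].strip('"')
    if "-" in etag:
        # Multipart upload: the ETag is not an MD5, so fall back to comparing sizes.
        return head["ContentLength"] != os.path.getsize(local_path)
    return etag != local_md5(local_path)


if __name__ == "__main__":
    # Hypothetical object; substitute a real bucket/key/path to try it.
    print(needs_download("example-bucket", "models/config.json", "/tmp/config.json"))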