olmocr/olmocr/s3_utils.py

import base64
import concurrent.futures
import glob
import hashlib
import logging
import os
import posixpath
import time
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse

import boto3
import requests  # type: ignore
import zstandard as zstd
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
from botocore.exceptions import ClientError
from google.cloud import storage
from tqdm import tqdm

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def parse_s3_path(s3_path: str) -> tuple[str, str]:
    if not (s3_path.startswith("s3://") or s3_path.startswith("gs://") or s3_path.startswith("weka://")):
        raise ValueError("s3_path must start with s3://, gs://, or weka://")
    parsed = urlparse(s3_path)
    bucket = parsed.netloc
    key = parsed.path.lstrip("/")
    return bucket, key


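# A quick illustration of parse_s3_path (the path below is a made-up example):
#
#     bucket, key = parse_s3_path("s3://my-bucket/some/prefix/file.pdf")
#     # bucket == "my-bucket", key == "some/prefix/file.pdf"

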
def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
    """
    Expand an S3 path that may or may not contain wildcards (e.g., *.pdf).
    Returns a dict of {'s3://bucket/key': etag} for each matching object.
    Raises a ValueError if nothing is found or if a bare prefix was provided by mistake.
    """
    parsed = urlparse(s3_glob)
    if not parsed.scheme.startswith("s3"):
        raise ValueError("Path must start with s3://")

    bucket = parsed.netloc
    raw_path = parsed.path.lstrip("/")
    prefix = posixpath.dirname(raw_path)
    pattern = posixpath.basename(raw_path)

    # Case 1: We have a wildcard
    if any(wc in pattern for wc in ["*", "?", "[", "]"]):
        if prefix and not prefix.endswith("/"):
            prefix += "/"

        paginator = s3_client.get_paginator("list_objects_v2")
        matched = {}
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):  # type: ignore
                    matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
        return matched

    # Case 2: No wildcard → single file or a bare prefix
    try:
        # Attempt to head a single file
        resp = s3_client.head_object(Bucket=bucket, Key=raw_path)

        if resp["ContentType"] == "application/x-directory":
            raise ValueError(f"'{s3_glob}' appears to be a folder. Use a wildcard (e.g., '{s3_glob.rstrip('/')}/*.pdf') to match files.")

        return {f"s3://{bucket}/{raw_path}": resp["ETag"].strip('"')}
    except ClientError as e:
        if e.response["Error"]["Code"] == "404":
            # Check if it's actually a folder with contents
            check_prefix = raw_path if raw_path.endswith("/") else raw_path + "/"
            paginator = s3_client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket, Prefix=check_prefix):
                if page.get("Contents"):
                    raise ValueError(f"'{s3_glob}' appears to be a folder. Use a wildcard (e.g., '{s3_glob.rstrip('/')}/*.pdf') to match files.")
            raise ValueError(f"No object or prefix found at '{s3_glob}'. Check your path or add a wildcard.")
        else:
            raise


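# Hedged usage sketch for expand_s3_glob: the bucket and prefix below are
# placeholders, and the call assumes your environment already has working
# AWS credentials.
#
#     import boto3
#     s3 = boto3.client("s3")
#     matches = expand_s3_glob(s3, "s3://my-bucket/pdfs/*.pdf")
#     # matches maps "s3://my-bucket/pdfs/<name>.pdf" -> ETag for every match;
#     # passing a bare folder prefix without a wildcard raises ValueError.

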
def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
    # Fall back for local files
    if os.path.exists(s3_path):
        assert start_index is None and end_index is None, "Range query not supported yet"
        with open(s3_path, "rb") as f:
            return f.read()

    bucket, key = parse_s3_path(s3_path)

    # Build the range header if start_index and/or end_index are specified
    range_header = None
    if start_index is not None and end_index is not None:
        # Range: bytes=start_index-end_index
        range_value = f"bytes={start_index}-{end_index}"
        range_header = {"Range": range_value}
    elif start_index is not None and end_index is None:
        # Range: bytes=start_index-
        range_value = f"bytes={start_index}-"
        range_header = {"Range": range_value}
    elif start_index is None and end_index is not None:
        # Range: bytes=-end_index (last end_index bytes)
        range_value = f"bytes=-{end_index}"
        range_header = {"Range": range_value}

    if range_header:
        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header["Range"])
    else:
        obj = s3_client.get_object(Bucket=bucket, Key=key)

    return obj["Body"].read()


def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, backoff_factor: int = 2):
    attempt = 0

    while attempt < max_retries:
        try:
            return get_s3_bytes(s3_client, pdf_s3_path)
        except ClientError as e:
            # Some error codes (AccessDenied, NoSuchKey) will never succeed on retry, so raise immediately
            if e.response["Error"]["Code"] in ("AccessDenied", "NoSuchKey"):
                logger.error(f"{e.response['Error']['Code']} error when trying to access {pdf_s3_path}: {e}")
                raise
            else:
                wait_time = backoff_factor**attempt
                logger.warning(f"Attempt {attempt+1} failed to get_s3_bytes for {pdf_s3_path}: {e}. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                attempt += 1
        except Exception as e:
            wait_time = backoff_factor**attempt
            logger.warning(f"Attempt {attempt+1} failed to get_s3_bytes for {pdf_s3_path}: {e}. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
            attempt += 1

    logger.error(f"Failed to get_s3_bytes for {pdf_s3_path} after {max_retries} retries.")
    raise Exception("Failed to get_s3_bytes after retries")


def put_s3_bytes(s3_client, s3_path: str, data: bytes):
    bucket, key = parse_s3_path(s3_path)
    s3_client.put_object(Bucket=bucket, Key=key, Body=data, ContentType="text/plain; charset=utf-8")


def parse_custom_id(custom_id: str) -> tuple[str, int]:
    s3_path = custom_id[: custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1 :])
    return s3_path, page_num


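# parse_custom_id expects an id of the form "<s3_path>-<page_num>" and splits
# on the last "-"; the value below is hypothetical:
#
#     parse_custom_id("s3://my-bucket/doc.pdf-3")  # -> ("s3://my-bucket/doc.pdf", 3)

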
def download_zstd_csv(s3_client, s3_path):
    """Download and decompress a .zstd CSV file from S3."""
    try:
        compressed_data = get_s3_bytes(s3_client, s3_path)
        dctx = zstd.ZstdDecompressor()
        decompressed = dctx.decompress(compressed_data)
        text_stream = TextIOWrapper(BytesIO(decompressed), encoding="utf-8")
        lines = text_stream.readlines()
        logger.info(f"Downloaded and decompressed {s3_path}")
        return lines
    except s3_client.exceptions.NoSuchKey:
        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
        return []


def upload_zstd_csv(s3_client, s3_path, lines):
    """Compress and upload a list of lines as a .zstd CSV file to S3."""
    joined_text = "\n".join(lines)
    compressor = zstd.ZstdCompressor()
    compressed = compressor.compress(joined_text.encode("utf-8"))
    put_s3_bytes(s3_client, s3_path, compressed)
    logger.info(f"Uploaded compressed {s3_path}")


def is_running_on_gcp():
    """Check if the script is running on a Google Cloud Platform (GCP) instance."""
    try:
        # GCP metadata server URL to check instance information
        response = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/",
            headers={"Metadata-Flavor": "Google"},
            timeout=1,  # Set a short timeout
        )
        return response.status_code == 200
    except requests.RequestException:
        return False


def download_directory(model_choices: List[str], local_dir: str):
    """
    Download the model to a specified local directory.
    The function will attempt to download from the first available source in the provided list.
    Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.

    Args:
        model_choices (List[str]): List of model paths (weka://, gs://, or s3://).
        local_dir (str): Local directory path where the model will be downloaded.

    Raises:
        ValueError: If no valid model path is found in the provided choices.
    """
    local_path = Path(os.path.expanduser(local_dir))
    local_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Local directory set to: {local_path}")

    # Reorder model_choices to prioritize weka:// links
    weka_choices = [path for path in model_choices if path.startswith("weka://")]

    # This is so hacky, but if you are on a Beaker pluto/augusta node, don't use Weka
    if os.environ.get("BEAKER_NODE_HOSTNAME", "").lower().startswith("pluto") or os.environ.get("BEAKER_NODE_HOSTNAME", "").lower().startswith("augusta"):
        weka_choices = []

    other_choices = [path for path in model_choices if not path.startswith("weka://")]
    prioritized_choices = weka_choices + other_choices

    for model_path in prioritized_choices:
        logger.info(f"Attempting to download from: {model_path}")
        try:
            if model_path.startswith("weka://"):
                download_dir_from_storage(model_path, str(local_path), storage_type="weka")
                logger.info(f"Successfully downloaded model from Weka: {model_path}")
                return
            elif model_path.startswith("gs://"):
                download_dir_from_storage(model_path, str(local_path), storage_type="gcs")
                logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                return
            elif model_path.startswith("s3://"):
                download_dir_from_storage(model_path, str(local_path), storage_type="s3")
                logger.info(f"Successfully downloaded model from S3: {model_path}")
                return
            else:
                logger.warning(f"Unsupported model path scheme: {model_path}")
        except Exception as e:
            logger.error(f"Failed to download from {model_path}: {e}")
            continue

    raise ValueError("Failed to download the model from all provided sources.")


def download_dir_from_storage(storage_path: str, local_dir: str, storage_type: str):
    """
    Generalized function to download model files from different storage services
    to a local directory, syncing using MD5 hashes where possible.

    Args:
        storage_path (str): The path to the storage location (weka://, gs://, or s3://).
        local_dir (str): The local directory where files will be downloaded.
        storage_type (str): Type of storage ('weka', 'gcs', or 's3').

    Raises:
        ValueError: If the storage type is unsupported or credentials are missing.
    """
    bucket_name, prefix = parse_s3_path(storage_path)
    total_files = 0
    objects = []

    if storage_type == "gcs":
        client = storage.Client()
        bucket = client.bucket(bucket_name)
        blobs = list(bucket.list_blobs(prefix=prefix))
        total_files = len(blobs)
        logger.info(f"Found {total_files} files in GCS bucket '{bucket_name}' with prefix '{prefix}'.")

        def should_download(blob, local_file_path):
            return compare_hashes_gcs(blob, local_file_path)

        def download_blob(blob, local_file_path):
            try:
                blob.download_to_filename(local_file_path)
                logger.info(f"Successfully downloaded {blob.name} to {local_file_path}")
            except Exception as e:
                logger.error(f"Failed to download {blob.name} to {local_file_path}: {e}")
                raise

        items = blobs
    elif storage_type in ("s3", "weka"):
        if storage_type == "weka":
            weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
            weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
            if not weka_access_key or not weka_secret_key:
                raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY must be set for Weka access.")
            endpoint_url = "https://weka-aus.beaker.org:9000"
            boto3_config = Config(max_pool_connections=500, signature_version="s3v4", retries={"max_attempts": 10, "mode": "standard"})
            s3_client = boto3.client(
                "s3", endpoint_url=endpoint_url, aws_access_key_id=weka_access_key, aws_secret_access_key=weka_secret_key, config=boto3_config
            )
        else:
            s3_client = boto3.client("s3", config=Config(max_pool_connections=500))

        paginator = s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
        for page in pages:
            if "Contents" in page:
                objects.extend(page["Contents"])
            else:
                logger.warning(f"No contents found in page: {page}")

        total_files = len(objects)
        logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")

        transfer_config = TransferConfig(
            multipart_threshold=8 * 1024 * 1024, multipart_chunksize=8 * 1024 * 1024, max_concurrency=10, use_threads=True  # Reduced for WekaFS compatibility
        )

        def should_download(obj, local_file_path):
            return compare_hashes_s3(obj, local_file_path, storage_type)

        def download_blob(obj, local_file_path):
            logger.info(f"Starting download of {obj['Key']} to {local_file_path}")
            try:
                with open(local_file_path, "wb") as f:
                    s3_client.download_fileobj(bucket_name, obj["Key"], f, Config=transfer_config)
                logger.info(f"Successfully downloaded {obj['Key']} to {local_file_path}")
            except Exception as e:
                logger.error(f"Failed to download {obj['Key']} to {local_file_path}: {e}")
                raise

        items = objects
    else:
        raise ValueError(f"Unsupported storage type: {storage_type}")

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for item in items:
            if storage_type == "gcs":
                relative_path = os.path.relpath(item.name, prefix)
            else:
                relative_path = os.path.relpath(item["Key"], prefix)
            local_file_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            if should_download(item, local_file_path):
                futures.append(executor.submit(download_blob, item, local_file_path))
            else:
                total_files -= 1  # Decrement total_files as we're skipping this file

        if total_files > 0:
            for future in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
                try:
                    future.result()
                except Exception as e:
                    logger.error(f"Error occurred during download: {e}")
        else:
            logger.info("All files are up-to-date. No downloads needed.")

    logger.info(f"Downloaded model from {storage_type.upper()} to {local_dir}")


def compare_hashes_gcs(blob, local_file_path: str) -> bool:
    """Compare MD5 hashes for GCS blobs."""
    if os.path.exists(local_file_path):
        remote_md5_base64 = blob.md5_hash
        hash_md5 = hashlib.md5()
        with open(local_file_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                hash_md5.update(chunk)
        local_md5 = hash_md5.digest()
        remote_md5 = base64.b64decode(remote_md5_base64)

        if remote_md5 == local_md5:
            logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
            return False
        else:
            logger.info(f"File '{local_file_path}' differs from GCS. Downloading.")
            return True
    else:
        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
        return True


def compare_hashes_s3(obj, local_file_path: str, storage_type: str) -> bool:
    """Compare MD5 hashes or sizes for S3 objects (including Weka)."""
    if os.path.exists(local_file_path):
        if storage_type == "weka":
            # For Weka, skip the comparison entirely and always re-download the file
            return True
        else:
            etag = obj["ETag"].strip('"')
            if "-" in etag:
                # Multipart upload, compare sizes
                remote_size = obj["Size"]
                local_size = os.path.getsize(local_file_path)
                if remote_size == local_size:
                    logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
                    return False
                else:
                    logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
                    return True
            else:
                hash_md5 = hashlib.md5()
                with open(local_file_path, "rb") as f:
                    for chunk in iter(lambda: f.read(8192), b""):
                        hash_md5.update(chunk)
                local_md5 = hash_md5.hexdigest()

                if etag == local_md5:
                    logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
                    return False
                else:
                    logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
                    return True
    else:
        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
        return True