hfupload scripts

Jake Poznanski 2025-02-14 17:36:00 +00:00
parent 8297955290
commit 6583fb641a
4 changed files with 229 additions and 0 deletions

View File

@@ -0,0 +1,99 @@
import os
import tarfile
import logging
from math import ceil
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from huggingface_hub import HfApi
# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID
# Set up logging to file
logging.basicConfig(
    filename='upload.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def process_chunk(args):
    """
    Worker function to create a tar.gz file for a given chunk.
    Returns a tuple: (chunk_index, success (bool), message).
    """
    chunk_index, chunk_files = args
    tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
    tarball_path = os.path.join(tarball_dir, tarball_name)
    try:
        with tarfile.open(tarball_path, "w:gz") as tar:
            for pdf_filename in chunk_files:
                pdf_path = os.path.join(pdf_dir, pdf_filename)
                # Add the file with its basename to maintain a flat structure
                tar.add(pdf_path, arcname=pdf_filename)
        logging.info(f"Chunk {chunk_index:04d}: Created '{tarball_name}' with {len(chunk_files)} PDFs.")
        return chunk_index, True, "Success"
    except Exception as e:
        error_msg = f"Chunk {chunk_index:04d}: Error creating '{tarball_name}': {e}"
        logging.error(error_msg)
        return chunk_index, False, error_msg

def main():
    # List all PDF files (assuming a flat directory)
    try:
        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
    except Exception as e:
        logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
        return

    total_files = len(pdf_files)
    chunk_size = 5000
    total_chunks = ceil(total_files / chunk_size)
    logging.info(f"Found {total_files} PDFs; dividing into {total_chunks} chunks of up to {chunk_size} files each.")

    # # Enumerate chunks (starting at 0000)
    # chunks = []
    # for idx in range(total_chunks):
    #     start = idx * chunk_size
    #     end = start + chunk_size
    #     chunk_files = pdf_files[start:end]
    #     chunks.append((idx, chunk_files))

    # # Create tarballs in parallel
    # results = []
    # with ProcessPoolExecutor() as executor:
    #     futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
    #     for future in tqdm(as_completed(futures), total=len(futures), desc="Creating tarballs"):
    #         try:
    #             result = future.result()
    #             results.append(result)
    #             chunk_index, success, message = result
    #             if not success:
    #                 logging.error(f"Chunk {chunk_index:04d} failed: {message}")
    #         except Exception as e:
    #             logging.error(f"Unexpected error processing a chunk: {e}")

    # # Abort upload if any tarball creation failed
    # failed_chunks = [r for r in results if not r[1]]
    # if failed_chunks:
    #     logging.error(f"{len(failed_chunks)} chunk(s) failed to create. Aborting upload.")
    #     return

    # All tarballs created successfully; now upload the entire tarball directory
    api = HfApi()
    logging.info("Starting upload of tarballs folder to Hugging Face Hub...")
    # Upload every file in tarball_dir to the dataset repo
    api.upload_large_folder(
        folder_path=tarball_dir,
        repo_id=repo_id,
        #path_in_repo="pdf_tarballs",
        repo_type="dataset"
    )
    logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")

if __name__ == "__main__":
    main()
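
For a quick spot-check after the upload, one of the chunk tarballs can be pulled back down with hf_hub_download and unpacked locally. A minimal sketch, assuming the chunks land at the repo root (path_in_repo is commented out above) and using an illustrative chunk name and scratch directory:

import tarfile

from huggingface_hub import hf_hub_download

# Fetch one chunk from the dataset repo (the chunk name follows the
# pdf_chunk_NNNN.tar.gz pattern used by the script above; adjust as needed).
local_path = hf_hub_download(
    repo_id="allenai/olmOCR-mix-0225",
    filename="pdf_chunk_0000.tar.gz",
    repo_type="dataset",
)

# Unpack into a scratch directory for inspection.
with tarfile.open(local_path, "r:gz") as tar:
    tar.extractall("check_pdfs")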

View File

@@ -0,0 +1,130 @@
#!/usr/bin/env python3
import argparse
import sqlite3
import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def parse_s3_path(s3_path):
    """
    Parses an S3 path of the form s3://bucket/prefix and returns the bucket and prefix.
    """
    if not s3_path.startswith("s3://"):
        raise ValueError("S3 path must start with s3://")
    without_prefix = s3_path[5:]
    parts = without_prefix.split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix

def list_s3_warc_objects(s3_path, suffix='.warc.gz'):
    """
    Lists all objects under the given S3 path that end with the provided suffix.
    Uses a paginator to handle large result sets.
    """
    bucket, prefix = parse_s3_path(s3_path)
    s3_client = boto3.client('s3')
    paginator = s3_client.get_paginator('list_objects_v2')
    warc_keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                if key.endswith(suffix):
                    warc_keys.append(key)
    return bucket, warc_keys, s3_client

def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
    """
    Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
    and extracts the WARC-Target-URI header from the first record that carries one.
    """
    target_uri = None
    try:
        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes=0-{head_bytes-1}')
        stream = response['Body']
        for record in ArchiveIterator(stream):
            for name, value in record.rec_headers.headers:
                if name == "WARC-Target-URI":
                    target_uri = value
                    break
            if target_uri:
                break  # Only use the first record that has a target URI
    except Exception as e:
        tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
    return target_uri

def create_db(db_path):
    """
    Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
    including an index on pdf_hash.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS pdf_mapping (
            pdf_hash TEXT PRIMARY KEY,
            uri TEXT
        )
    ''')
    cursor.execute('''
        CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
    ''')
    conn.commit()
    return conn

def process_warc_file(key, bucket, s3_client):
    """
    Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
    if successful, otherwise returns None.
    """
    uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
    if uri:
        # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
        pdf_hash = key.split('/')[-1].replace('.warc.gz', '.pdf')
        return (pdf_hash, uri)
    else:
        tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
        return None

def process_s3_folder(s3_path, db_path):
    """
    Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
    to extract the target URI from the WARC headers. The resulting mapping (keyed by the file's
    basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
    """
    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix='.warc.gz')
    conn = create_db(db_path)
    cursor = conn.cursor()

    # Process WARC files concurrently using ThreadPoolExecutor.
    results = []
    func = partial(process_warc_file, bucket=bucket, s3_client=s3_client)
    with ThreadPoolExecutor() as executor:
        for result in tqdm(executor.map(func, warc_keys), total=len(warc_keys), desc="Processing S3 WARC files"):
            if result is not None:
                results.append(result)

    # Bulk insert into the database.
    conn.execute("BEGIN")
    for pdf_hash, uri in results:
        cursor.execute(
            "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
            (pdf_hash, uri)
        )
    conn.commit()
    conn.close()

def main():
    parser = argparse.ArgumentParser(
        description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files."
    )
    parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
    parser.add_argument("db_file", help="Path for the output SQLite database file")
    args = parser.parse_args()
    process_s3_folder(args.s3_path, args.db_file)


if __name__ == '__main__':
    main()
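
Once the mapping database exists, recovering the source URI for a given PDF is a single keyed lookup against pdf_mapping. A minimal sketch, with the database path and the hash-named file both illustrative:

import sqlite3

conn = sqlite3.connect("pdf_mapping.db")  # db_file path is an assumption
row = conn.execute(
    "SELECT uri FROM pdf_mapping WHERE pdf_hash = ?",
    ("0a1b2c3d4e5f.pdf",),  # illustrative hash-named PDF derived from a WARC basename
).fetchone()
print(row[0] if row else "no mapping found")
conn.close()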