diff --git a/olmocr/train/hf/convertjsontoparquet.py b/olmocr/train/hf/convertjsontoparquet.py
index 85d7462..7ffac5c 100644
--- a/olmocr/train/hf/convertjsontoparquet.py
+++ b/olmocr/train/hf/convertjsontoparquet.py
@@ -8,23 +8,24 @@
 # The url will be the result of get_uri_from_db
 # Rresponse will be NormalizedEntry.text
 import argparse
+import concurrent.futures
 import glob
 import json
 import multiprocessing
+import os
 import re
+import shutil
 import sqlite3
 import tempfile
-import os
-import shutil
 from dataclasses import dataclass
-from typing import Optional, List, Tuple, Dict, Set
-import concurrent.futures
+from typing import Dict, List, Optional, Set, Tuple
 from urllib.parse import urlparse
 
 import boto3
-from tqdm import tqdm
 import pandas as pd
 from pypdf import PdfReader, PdfWriter
+from tqdm import tqdm
+
 
 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     """
@@ -44,7 +45,7 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
         return urlparse(pretty_pdf_path).path.split("/")[-1]
     else:
         raise NotImplementedError()
-    
+
 
 def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     """
@@ -58,6 +59,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     conn.close()
     return result[0].strip() if result and result[0] else None
 
+
 @dataclass(frozen=True)
 class NormalizedEntry:
     s3_path: str
@@ -70,7 +72,7 @@ class NormalizedEntry:
     def from_goldkey(goldkey: str, **kwargs):
         """
         Constructs a NormalizedEntry from a goldkey string.
-        The goldkey is expected to be of the format: 
+        The goldkey is expected to be of the format:
         <s3_path>-<page_number>
         """
         s3_path = goldkey[: goldkey.rindex("-")]
@@ -81,6 +83,7 @@ class NormalizedEntry:
     def goldkey(self):
         return f"{self.s3_path}-{self.pagenum}"
 
+
 def normalize_json_entry(data: dict) -> NormalizedEntry:
     """
     Normalizes a JSON entry from any of the supported formats.
@@ -117,6 +120,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
     else:
         raise ValueError("Unsupported JSON format")
 
+
 def parse_s3_url(s3_url: str) -> Tuple[str, str]:
     """
     Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
@@ -127,6 +131,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
     bucket, key = s3_path.split("/", 1)
     return bucket, key
 
+
 def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
     """
     Downloads the PDF from the given S3 URL into the specified cache directory.
@@ -135,11 +140,11 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
     """
     try:
         bucket, key = parse_s3_url(s3_url)
-        s3_client = boto3.client('s3')
+        s3_client = boto3.client("s3")
        pdf_hash = parse_pdf_hash(s3_url)
         if not pdf_hash:
             # Fallback: use a sanitized version of the s3_url
-            pdf_hash = re.sub(r'\W+', '_', s3_url)
+            pdf_hash = re.sub(r"\W+", "_", s3_url)
         dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
         # Avoid re-downloading if already exists
         if not os.path.exists(dest_path):
@@ -149,6 +154,7 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
         print(f"Error downloading {s3_url}: {e}")
         return None
 
+
 def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
     """
     Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
@@ -178,6 +184,7 @@ def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf
         print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
         return None
 
+
 def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
     """
     Process a single file and return a tuple:
@@ -215,8 +222,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
 
                 # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
                 response_text = normalized.text if normalized.text else ""
-                if (re.search(r"resume", response_text, re.IGNORECASE) and
-                    (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
+                if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
                     print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                     continue
 
@@ -254,6 +260,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
         print(f"Error processing file {file_path}: {e}")
     return rows, missing_count
 
+
 def scan_file_for_s3_urls(file_path: str) -> Set[str]:
     """
     Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
@@ -276,18 +283,15 @@ def scan_file_for_s3_urls(file_path: str) -> Set[str]:
         print(f"Error reading file {file_path}: {e}")
     return urls
 
+
 def main():
-    parser = argparse.ArgumentParser(
-        description="Generate a Parquet dataset file for HuggingFace upload."
-    )
+    parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
     parser.add_argument(
         "input_dataset",
         help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
     )
     parser.add_argument("db_path", help="Path to the SQLite database file.")
-    parser.add_argument(
-        "--output", default="output.parquet", help="Output Parquet file path."
-    )
+    parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")
 
     args = parser.parse_args()
 
@@ -303,7 +307,7 @@ def main():
     # Create a temporary directory for caching PDFs.
     pdf_cache_dir = "/tmp/pdf_cache"
     os.makedirs(pdf_cache_dir, exist_ok=True)
-    
+
     print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
 
     # ---------------------------------------------------------------------
@@ -323,12 +327,8 @@ def main():
     pdf_cache: Dict[str, str] = {}
     print("Caching PDFs from S3...")
     with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
-        future_to_url = {
-            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
-            for s3_url in unique_s3_urls
-        }
-        for future in tqdm(concurrent.futures.as_completed(future_to_url),
-                           total=len(future_to_url), desc="Downloading PDFs"):
+        future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
+        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
             s3_url = future_to_url[future]
             try:
                 local_path = future.result()
@@ -345,12 +345,8 @@ def main():
     total_missing = 0
     print("Processing files...")
     with concurrent.futures.ProcessPoolExecutor() as executor:
-        futures = {
-            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
-            for file_path in files
-        }
-        for future in tqdm(concurrent.futures.as_completed(futures),
-                           total=len(futures), desc="Processing files"):
+        futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
             file_path = futures[future]
             try:
                 rows, missing_count = future.result()
diff --git a/olmocr/train/hf/hfhub_upload.py b/olmocr/train/hf/hfhub_upload.py
index b9eb710..3e7a6a8 100644
--- a/olmocr/train/hf/hfhub_upload.py
+++ b/olmocr/train/hf/hfhub_upload.py
@@ -1,23 +1,21 @@
+import logging
 import os
 import tarfile
-import logging
-from math import ceil
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from tqdm import tqdm
+from math import ceil
+
 from huggingface_hub import HfApi
+from tqdm import tqdm
 
 # Configuration
-pdf_dir = "pdfs"         # Directory with PDF files (flat structure)
-tarball_dir = "tarballs" # Directory where tar.gz files will be saved
+pdf_dir = "pdfs"  # Directory with PDF files (flat structure)
+tarball_dir = "tarballs"  # Directory where tar.gz files will be saved
 os.makedirs(tarball_dir, exist_ok=True)
 repo_id = "allenai/olmOCR-mix-0225"  # Hugging Face dataset repo ID
 
 # Set up logging to file
-logging.basicConfig(
-    filename='upload.log',
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
 
 def process_chunk(args):
     """
@@ -27,7 +25,7 @@ def process_chunk(args):
     chunk_index, chunk_files = args
     tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
     tarball_path = os.path.join(tarball_dir, tarball_name)
-    
+
     try:
         with tarfile.open(tarball_path, "w:gz") as tar:
             for pdf_filename in chunk_files:
@@ -41,10 +39,11 @@ def process_chunk(args):
         logging.error(error_msg)
         return chunk_index, False, error_msg
 
+
 def main():
     # List all PDF files (assuming a flat directory)
     try:
-        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
+        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
     except Exception as e:
         logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
         return
@@ -61,7 +60,7 @@ def main():
     #     end = start + chunk_size
     #     chunk_files = pdf_files[start:end]
     #     chunks.append((idx, chunk_files))
-    
+
     # # Create tarballs in parallel
     # results = []
     # with ProcessPoolExecutor() as executor:
@@ -90,10 +89,11 @@ def main():
     api.upload_large_folder(
         folder_path=tarball_dir,
         repo_id=repo_id,
-        #path_in_repo="pdf_tarballs",
-        repo_type="dataset"
+        # path_in_repo="pdf_tarballs",
+        repo_type="dataset",
     )
     logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")
 
+
 if __name__ == "__main__":
     main()
diff --git a/olmocr/train/hf/warc_parser.py b/olmocr/train/hf/warc_parser.py
index 7d0775b..715c5e4 100644
--- a/olmocr/train/hf/warc_parser.py
+++ b/olmocr/train/hf/warc_parser.py
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 import argparse
 import sqlite3
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
 import boto3
 from tqdm import tqdm
 from warcio.archiveiterator import ArchiveIterator
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
+
 
 def parse_s3_path(s3_path):
     """
@@ -19,23 +21,25 @@ def parse_s3_path(s3_path):
     prefix = parts[1] if len(parts) > 1 else ""
     return bucket, prefix
 
-def list_s3_warc_objects(s3_path, suffix='.warc.gz'):
+
+def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
     """
     Lists all objects under the given S3 path that end with the provided suffix.
     Uses a paginator to handle large result sets.
     """
     bucket, prefix = parse_s3_path(s3_path)
-    s3_client = boto3.client('s3')
-    paginator = s3_client.get_paginator('list_objects_v2')
+    s3_client = boto3.client("s3")
+    paginator = s3_client.get_paginator("list_objects_v2")
     warc_keys = []
     for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
-        if 'Contents' in page:
-            for obj in page['Contents']:
-                key = obj['Key']
+        if "Contents" in page:
+            for obj in page["Contents"]:
+                key = obj["Key"]
                 if key.endswith(suffix):
                     warc_keys.append(key)
     return bucket, warc_keys, s3_client
 
+
 def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
     """
     Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
@@ -43,8 +47,8 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
     """
     target_uri = None
     try:
-        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes=0-{head_bytes-1}')
-        stream = response['Body']
+        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
+        stream = response["Body"]
         for record in ArchiveIterator(stream):
             for name, value in record.rec_headers.headers:
                 if name == "WARC-Target-URI":
@@ -56,6 +60,7 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
         tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
     return target_uri
 
+
 def create_db(db_path):
     """
     Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
@@ -63,18 +68,23 @@ def create_db(db_path):
     """
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
-    cursor.execute('''
+    cursor.execute(
+        """
         CREATE TABLE IF NOT EXISTS pdf_mapping (
             pdf_hash TEXT PRIMARY KEY,
             uri TEXT
         )
-    ''')
-    cursor.execute('''
+    """
+    )
+    cursor.execute(
+        """
         CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
-    ''')
+    """
+    )
     conn.commit()
     return conn
 
+
 def process_warc_file(key, bucket, s3_client):
     """
     Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
@@ -83,19 +93,20 @@ def process_warc_file(key, bucket, s3_client):
     uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
     if uri:
         # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
-        pdf_hash = key.split('/')[-1].replace('.warc.gz', '.pdf')
+        pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
         return (pdf_hash, uri)
     else:
         tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
         return None
 
+
 def process_s3_folder(s3_path, db_path):
     """
     Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
     to extract the target URI from the HTTP headers. The resulting mapping (derived from the
     file's basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
     """
-    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix='.warc.gz')
+    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
     conn = create_db(db_path)
     cursor = conn.cursor()
 
@@ -110,21 +121,18 @@ def process_s3_folder(s3_path, db_path):
     # Bulk insert into the database.
     conn.execute("BEGIN")
     for pdf_hash, uri in results:
-        cursor.execute(
-            "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
-            (pdf_hash, uri)
-        )
+        cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
     conn.commit()
     conn.close()
 
+
 def main():
-    parser = argparse.ArgumentParser(
-        description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files."
-    )
+    parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
     parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
     parser.add_argument("db_file", help="Path for the output SQLite database file")
     args = parser.parse_args()
     process_s3_folder(args.s3_path, args.db_file)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
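
A note on how the three scripts above chain together, inferred from their argparse definitions rather than from anything stated in the diff: warc_parser.py builds the SQLite mapping that convertjsontoparquet.py consumes as db_path, and hfhub_upload.py tars and uploads the flat pdfs/ directory. The sketch below uses placeholder paths and bucket names, and the int() conversion of the page number is part of the sketch, not a claim about the original from_goldkey code.

# Plausible invocation order (all paths and the bucket are hypothetical):
#   python warc_parser.py s3://example-bucket/warc-prefix mapping.sqlite
#   python convertjsontoparquet.py '/data/openai_batch_data_train_done/*.json' mapping.sqlite --output output.parquet
#   python hfhub_upload.py  # expects a flat pdfs/ directory and writes tarballs/

# Goldkey convention used by NormalizedEntry.from_goldkey: "<s3_path>-<page_number>",
# split on the last "-" so that dashes inside the S3 path are preserved.
goldkey = "s3://example-bucket/documents/abc123.pdf-4"  # illustrative value
s3_path = goldkey[: goldkey.rindex("-")]
pagenum = int(goldkey[goldkey.rindex("-") + 1 :])
assert (s3_path, pagenum) == ("s3://example-bucket/documents/abc123.pdf", 4)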
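
The pdf_mapping table written by warc_parser.py is presumably what get_uri_from_db in convertjsontoparquet.py reads back (the diff only shows that function's signature and return line). A minimal sketch of querying it, assuming the schema created by create_db above and a hypothetical mapping.sqlite file already produced by warc_parser.py:

import sqlite3

conn = sqlite3.connect("mapping.sqlite")  # hypothetical path, built by warc_parser.py
cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("abc123.pdf",))
result = cursor.fetchone()
conn.close()
# Mirror get_uri_from_db: a missing row or empty value falls back to None.
uri = result[0].strip() if result and result[0] else None
print(uri)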