Formatting fix

Jake Poznanski 2025-02-14 19:50:19 +00:00
parent 0dcdbcc61a
commit 32aa359458
3 changed files with 73 additions and 69 deletions

View File

@@ -8,23 +8,24 @@
 # The url will be the result of get_uri_from_db
 # Rresponse will be NormalizedEntry.text
 import argparse
+import concurrent.futures
 import glob
 import json
 import multiprocessing
+import os
 import re
+import shutil
 import sqlite3
 import tempfile
-import os
-import shutil
 from dataclasses import dataclass
-from typing import Optional, List, Tuple, Dict, Set
-import concurrent.futures
+from typing import Dict, List, Optional, Set, Tuple
 from urllib.parse import urlparse
 import boto3
-from tqdm import tqdm
 import pandas as pd
 from pypdf import PdfReader, PdfWriter
+from tqdm import tqdm
 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     """
@@ -44,7 +45,7 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
         return urlparse(pretty_pdf_path).path.split("/")[-1]
     else:
         raise NotImplementedError()
 def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     """
@@ -58,6 +59,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     conn.close()
     return result[0].strip() if result and result[0] else None
 @dataclass(frozen=True)
 class NormalizedEntry:
     s3_path: str
@@ -70,7 +72,7 @@ class NormalizedEntry:
     def from_goldkey(goldkey: str, **kwargs):
         """
         Constructs a NormalizedEntry from a goldkey string.
         The goldkey is expected to be of the format:
         <s3_path>-<page_number>
         """
         s3_path = goldkey[: goldkey.rindex("-")]
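As an aside on the goldkey convention this hunk documents, a minimal sketch of the parse and round-trip; the example key is made up, but the split logic mirrors the rindex call above:

# Sketch only: "<s3_path>-<page_number>", split on the *last* dash because the
# s3_path itself may contain dashes. The goldkey value below is hypothetical.
goldkey = "s3://bucket/prefix/doc-0a1b2c.pdf-14"
s3_path = goldkey[: goldkey.rindex("-")]           # "s3://bucket/prefix/doc-0a1b2c.pdf"
pagenum = int(goldkey[goldkey.rindex("-") + 1 :])  # 14
assert f"{s3_path}-{pagenum}" == goldkey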
@@ -81,6 +83,7 @@ class NormalizedEntry:
     def goldkey(self):
         return f"{self.s3_path}-{self.pagenum}"
 def normalize_json_entry(data: dict) -> NormalizedEntry:
     """
     Normalizes a JSON entry from any of the supported formats.
@@ -117,6 +120,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
     else:
         raise ValueError("Unsupported JSON format")
 def parse_s3_url(s3_url: str) -> Tuple[str, str]:
     """
     Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
@@ -127,6 +131,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
     bucket, key = s3_path.split("/", 1)
     return bucket, key
 def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
     """
     Downloads the PDF from the given S3 URL into the specified cache directory.
@ -135,11 +140,11 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
""" """
try: try:
bucket, key = parse_s3_url(s3_url) bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client('s3') s3_client = boto3.client("s3")
pdf_hash = parse_pdf_hash(s3_url) pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash: if not pdf_hash:
# Fallback: use a sanitized version of the s3_url # Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r'\W+', '_', s3_url) pdf_hash = re.sub(r"\W+", "_", s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf") dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if already exists # Avoid re-downloading if already exists
if not os.path.exists(dest_path): if not os.path.exists(dest_path):
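For context, the caching idea in this hunk reduces to deriving a stable local filename and skipping the S3 fetch when the file already exists; a minimal standalone sketch (the function name and the plain \W+ cache key are illustrative, not the file's exact code):

import os
import re

import boto3


def cached_download(s3_url: str, cache_dir: str) -> str:
    # Derive a filesystem-safe cache key from the URL (the fallback used above when no hash is available).
    cache_key = re.sub(r"\W+", "_", s3_url)
    dest_path = os.path.join(cache_dir, f"{cache_key}.pdf")
    if not os.path.exists(dest_path):
        bucket, key = s3_url.replace("s3://", "", 1).split("/", 1)
        boto3.client("s3").download_file(bucket, key, dest_path)
    return dest_path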
@@ -149,6 +154,7 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
         print(f"Error downloading {s3_url}: {e}")
         return None
 def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
     """
     Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
@@ -178,6 +184,7 @@ def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf
         print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
         return None
 def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
     """
     Process a single file and return a tuple:
@@ -215,8 +222,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
             # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
             response_text = normalized.text if normalized.text else ""
-            if (re.search(r"resume", response_text, re.IGNORECASE) and
-                (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
+            if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
                 print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                 continue
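The filter this hunk collapses onto one line can be read in isolation as follows; note that email_regex and phone_regex are defined elsewhere in the file and do not appear in the diff, so the patterns below are stand-ins rather than the real ones:

import re

# Placeholder patterns; the actual email_regex / phone_regex live outside this hunk.
email_regex = r"[\w.+-]+@[\w-]+\.[\w.-]+"
phone_regex = r"\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b"


def is_resume_with_contact_info(response_text: str) -> bool:
    # Same predicate as above: mentions "resume" (any case) AND contains an email or phone number.
    return bool(
        re.search(r"resume", response_text, re.IGNORECASE)
        and (re.search(email_regex, response_text) or re.search(phone_regex, response_text))
    )


assert is_resume_with_contact_info("My resume: call 555-123-4567")
assert not is_resume_with_contact_info("Resume of operations after maintenance")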
@@ -254,6 +260,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
         print(f"Error processing file {file_path}: {e}")
     return rows, missing_count
 def scan_file_for_s3_urls(file_path: str) -> Set[str]:
     """
     Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
@@ -276,18 +283,15 @@ def scan_file_for_s3_urls(file_path: str) -> Set[str]:
         print(f"Error reading file {file_path}: {e}")
     return urls
 def main():
-    parser = argparse.ArgumentParser(
-        description="Generate a Parquet dataset file for HuggingFace upload."
-    )
+    parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
     parser.add_argument(
         "input_dataset",
         help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
     )
     parser.add_argument("db_path", help="Path to the SQLite database file.")
-    parser.add_argument(
-        "--output", default="output.parquet", help="Output Parquet file path."
-    )
+    parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")
     args = parser.parse_args()
@@ -303,7 +307,7 @@ def main():
     # Create a temporary directory for caching PDFs.
     pdf_cache_dir = "/tmp/pdf_cache"
     os.makedirs(pdf_cache_dir, exist_ok=True)
     print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
     # ---------------------------------------------------------------------
@@ -323,12 +327,8 @@ def main():
     pdf_cache: Dict[str, str] = {}
     print("Caching PDFs from S3...")
     with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
-        future_to_url = {
-            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
-            for s3_url in unique_s3_urls
-        }
-        for future in tqdm(concurrent.futures.as_completed(future_to_url),
-                           total=len(future_to_url), desc="Downloading PDFs"):
+        future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
+        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
             s3_url = future_to_url[future]
             try:
                 local_path = future.result()
@@ -345,12 +345,8 @@ def main():
     total_missing = 0
     print("Processing files...")
     with concurrent.futures.ProcessPoolExecutor() as executor:
-        futures = {
-            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
-            for file_path in files
-        }
-        for future in tqdm(concurrent.futures.as_completed(futures),
-                           total=len(futures), desc="Processing files"):
+        futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
            file_path = futures[future]
            try:
                rows, missing_count = future.result()
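Both loops reformatted above follow the same submit / as_completed / tqdm pattern; a self-contained sketch with a toy task, where _square is only a stand-in for download_pdf_to_cache or process_file:

import concurrent.futures

from tqdm import tqdm


def _square(x: int) -> int:
    # Toy stand-in for the real per-item work (download_pdf_to_cache / process_file).
    return x * x


if __name__ == "__main__":
    items = list(range(100))
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # Map each future back to its input so failures can be attributed to a specific item.
        future_to_item = {executor.submit(_square, i): i for i in items}
        for future in tqdm(concurrent.futures.as_completed(future_to_item), total=len(future_to_item), desc="Working"):
            item = future_to_item[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"Item {item} failed: {e}")

as_completed yields futures in completion order rather than submission order, which is what lets the progress bar advance as soon as any worker finishes.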

View File

@@ -1,23 +1,21 @@
+import logging
 import os
 import tarfile
-import logging
-from math import ceil
 from concurrent.futures import ProcessPoolExecutor, as_completed
-from tqdm import tqdm
+from math import ceil
 from huggingface_hub import HfApi
+from tqdm import tqdm
 # Configuration
 pdf_dir = "pdfs"  # Directory with PDF files (flat structure)
 tarball_dir = "tarballs"  # Directory where tar.gz files will be saved
 os.makedirs(tarball_dir, exist_ok=True)
 repo_id = "allenai/olmOCR-mix-0225"  # Hugging Face dataset repo ID
 # Set up logging to file
-logging.basicConfig(
-    filename='upload.log',
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 def process_chunk(args):
     """
@@ -27,7 +25,7 @@ def process_chunk(args):
     chunk_index, chunk_files = args
     tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
     tarball_path = os.path.join(tarball_dir, tarball_name)
     try:
         with tarfile.open(tarball_path, "w:gz") as tar:
             for pdf_filename in chunk_files:
@@ -41,10 +39,11 @@ def process_chunk(args):
         logging.error(error_msg)
         return chunk_index, False, error_msg
 def main():
     # List all PDF files (assuming a flat directory)
     try:
-        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
+        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
     except Exception as e:
         logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
         return
@@ -61,7 +60,7 @@ def main():
     # end = start + chunk_size
     # chunk_files = pdf_files[start:end]
     # chunks.append((idx, chunk_files))
     # # Create tarballs in parallel
     # results = []
     # with ProcessPoolExecutor() as executor:
@@ -90,10 +89,11 @@ def main():
     api.upload_large_folder(
         folder_path=tarball_dir,
         repo_id=repo_id,
-        #path_in_repo="pdf_tarballs",
-        repo_type="dataset"
+        # path_in_repo="pdf_tarballs",
+        repo_type="dataset",
     )
     logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")
 if __name__ == "__main__":
     main()
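For context on the chunking that process_chunk consumes (most of it is commented out in this file), a rough sketch; chunk_size below is an assumed value, not one taken from the script:

import os
import tarfile
from math import ceil

pdf_dir = "pdfs"
tarball_dir = "tarballs"
chunk_size = 10_000  # assumed for illustration; the real value is not visible in this diff

os.makedirs(tarball_dir, exist_ok=True)
pdf_files = sorted(f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf"))
num_chunks = ceil(len(pdf_files) / chunk_size)

for idx in range(num_chunks):
    chunk_files = pdf_files[idx * chunk_size : (idx + 1) * chunk_size]
    tarball_path = os.path.join(tarball_dir, f"pdf_chunk_{idx:04d}.tar.gz")
    with tarfile.open(tarball_path, "w:gz") as tar:
        for pdf_filename in chunk_files:
            # arcname keeps the archive flat, matching the flat layout of pdf_dir.
            tar.add(os.path.join(pdf_dir, pdf_filename), arcname=pdf_filename)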

View File

@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 import argparse
 import sqlite3
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 import boto3
 from tqdm import tqdm
 from warcio.archiveiterator import ArchiveIterator
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 def parse_s3_path(s3_path):
     """
@@ -19,23 +21,25 @@ def parse_s3_path(s3_path):
     prefix = parts[1] if len(parts) > 1 else ""
     return bucket, prefix
-def list_s3_warc_objects(s3_path, suffix='.warc.gz'):
+def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
     """
     Lists all objects under the given S3 path that end with the provided suffix.
     Uses a paginator to handle large result sets.
     """
     bucket, prefix = parse_s3_path(s3_path)
-    s3_client = boto3.client('s3')
-    paginator = s3_client.get_paginator('list_objects_v2')
+    s3_client = boto3.client("s3")
+    paginator = s3_client.get_paginator("list_objects_v2")
     warc_keys = []
     for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
-        if 'Contents' in page:
-            for obj in page['Contents']:
-                key = obj['Key']
+        if "Contents" in page:
+            for obj in page["Contents"]:
+                key = obj["Key"]
                 if key.endswith(suffix):
                     warc_keys.append(key)
     return bucket, warc_keys, s3_client
 def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
     """
     Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
@@ -43,8 +47,8 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
     """
     target_uri = None
     try:
-        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes=0-{head_bytes-1}')
-        stream = response['Body']
+        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
+        stream = response["Body"]
         for record in ArchiveIterator(stream):
             for name, value in record.rec_headers.headers:
                 if name == "WARC-Target-URI":
@@ -56,6 +60,7 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
         tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
     return target_uri
 def create_db(db_path):
     """
     Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
@@ -63,18 +68,23 @@ def create_db(db_path):
     """
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
-    cursor.execute('''
+    cursor.execute(
+        """
         CREATE TABLE IF NOT EXISTS pdf_mapping (
             pdf_hash TEXT PRIMARY KEY,
             uri TEXT
         )
-    ''')
-    cursor.execute('''
+    """
+    )
+    cursor.execute(
+        """
         CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
-    ''')
+    """
+    )
     conn.commit()
     return conn
 def process_warc_file(key, bucket, s3_client):
     """
     Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
@@ -83,19 +93,20 @@ def process_warc_file(key, bucket, s3_client):
     uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
     if uri:
         # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
-        pdf_hash = key.split('/')[-1].replace('.warc.gz', '.pdf')
+        pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
         return (pdf_hash, uri)
     else:
         tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
         return None
 def process_s3_folder(s3_path, db_path):
     """
     Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
     to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's
     basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
     """
-    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix='.warc.gz')
+    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
     conn = create_db(db_path)
     cursor = conn.cursor()
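The pdf_mapping table written here is the same one the first script reads through get_uri_from_db; a minimal round-trip sketch using an in-memory database and a made-up hash and URI:

import sqlite3

conn = sqlite3.connect(":memory:")  # stand-in for the real db_path
cursor = conn.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")

# Hypothetical row: basename of a .warc.gz with the extension swapped to .pdf, plus its target URI.
cursor.execute(
    "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
    ("0a1b2c3d.pdf", "https://example.com/paper.pdf"),
)
conn.commit()

cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("0a1b2c3d.pdf",))
row = cursor.fetchone()
print(row[0].strip() if row and row[0] else None)  # mirrors get_uri_from_db's return expression
conn.close()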
@@ -110,21 +121,18 @@ def process_s3_folder(s3_path, db_path):
     # Bulk insert into the database.
     conn.execute("BEGIN")
     for pdf_hash, uri in results:
-        cursor.execute(
-            "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
-            (pdf_hash, uri)
-        )
+        cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
     conn.commit()
     conn.close()
 def main():
-    parser = argparse.ArgumentParser(
-        description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files."
-    )
+    parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
     parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
     parser.add_argument("db_file", help="Path for the output SQLite database file")
     args = parser.parse_args()
     process_s3_folder(args.s3_path, args.db_file)
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
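Putting the pieces of extract_target_uri_s3 together, a condensed sketch of the ranged read plus warcio scan; the get_header call is an equivalent shortcut for the header loop in the diff, and the bucket/key values would come from list_s3_warc_objects:

import boto3
from warcio.archiveiterator import ArchiveIterator


def first_target_uri(bucket: str, key: str, head_bytes: int = 1048576):
    # Fetch only the first ~1 MB of the WARC via an HTTP Range request; the first
    # record's headers are normally well within that window.
    s3 = boto3.client("s3")
    response = s3.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes - 1}")
    try:
        for record in ArchiveIterator(response["Body"]):
            uri = record.rec_headers.get_header("WARC-Target-URI")
            if uri:
                return uri
    except Exception:
        # The truncated stream may raise once the range runs out; by then the URI is usually found.
        pass
    return None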