Formatting fix

Jake Poznanski 2025-02-14 19:50:19 +00:00
parent 0dcdbcc61a
commit 32aa359458
3 changed files with 73 additions and 69 deletions

View File

@@ -8,23 +8,24 @@
# The url will be the result of get_uri_from_db
# Response will be NormalizedEntry.text
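# Illustrative only (not part of this commit): a single output row is expected to
# look roughly like {"url": "https://example.com/doc.pdf", "response": "extracted page text"};
# any additional columns are assembled later in process_file.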
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import shutil
import sqlite3
import tempfile
import os
import shutil
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
import boto3
from tqdm import tqdm
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
"""
@@ -44,7 +45,7 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
return urlparse(pretty_pdf_path).path.split("/")[-1]
else:
raise NotImplementedError()
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
@@ -58,6 +59,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
conn.close()
return result[0].strip() if result and result[0] else None
@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
@@ -70,7 +72,7 @@ class NormalizedEntry:
def from_goldkey(goldkey: str, **kwargs):
"""
Constructs a NormalizedEntry from a goldkey string.
The goldkey is expected to be of the format:
<s3_path>-<page_number>
"""
s3_path = goldkey[: goldkey.rindex("-")]
@@ -81,6 +83,7 @@ class NormalizedEntry:
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
"""
Normalizes a JSON entry from any of the supported formats.
@@ -117,6 +120,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
else:
raise ValueError("Unsupported JSON format")
def parse_s3_url(s3_url: str) -> Tuple[str, str]:
"""
Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
@@ -127,6 +131,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
bucket, key = s3_path.split("/", 1)
return bucket, key
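# Illustrative (not part of this commit):
#   parse_s3_url("s3://my-bucket/some/prefix/doc.pdf") == ("my-bucket", "some/prefix/doc.pdf")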
def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
Downloads the PDF from the given S3 URL into the specified cache directory.
@@ -135,11 +140,11 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
try:
bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client('s3')
s3_client = boto3.client("s3")
pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash:
# Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r'\W+', '_', s3_url)
pdf_hash = re.sub(r"\W+", "_", s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if already exists
if not os.path.exists(dest_path):
@@ -149,6 +154,7 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
print(f"Error downloading {s3_url}: {e}")
return None
def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
"""
Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
@@ -178,6 +184,7 @@ def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf
print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
return None
def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
"""
Process a single file and return a tuple:
@@ -215,8 +222,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
# Apply filter: skip if response contains "resume" (any case) and an email or phone number.
response_text = normalized.text if normalized.text else ""
if (re.search(r"resume", response_text, re.IGNORECASE) and
(re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
continue
@@ -254,6 +260,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
print(f"Error processing file {file_path}: {e}")
return rows, missing_count
def scan_file_for_s3_urls(file_path: str) -> Set[str]:
"""
Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
@@ -276,18 +283,15 @@ def scan_file_for_s3_urls(file_path: str) -> Set[str]:
print(f"Error reading file {file_path}: {e}")
return urls
def main():
parser = argparse.ArgumentParser(
description="Generate a Parquet dataset file for HuggingFace upload."
)
parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
parser.add_argument(
"input_dataset",
help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
)
parser.add_argument("db_path", help="Path to the SQLite database file.")
parser.add_argument(
"--output", default="output.parquet", help="Output Parquet file path."
)
parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")
args = parser.parse_args()
@@ -303,7 +307,7 @@ def main():
# Create a temporary directory for caching PDFs.
pdf_cache_dir = "/tmp/pdf_cache"
os.makedirs(pdf_cache_dir, exist_ok=True)
print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
# ---------------------------------------------------------------------
@@ -323,12 +327,8 @@ def main():
pdf_cache: Dict[str, str] = {}
print("Caching PDFs from S3...")
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
future_to_url = {
executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
for s3_url in unique_s3_urls
}
for future in tqdm(concurrent.futures.as_completed(future_to_url),
total=len(future_to_url), desc="Downloading PDFs"):
future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
s3_url = future_to_url[future]
try:
local_path = future.result()
@@ -345,12 +345,8 @@ def main():
total_missing = 0
print("Processing files...")
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = {
executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
for file_path in files
}
for future in tqdm(concurrent.futures.as_completed(futures),
total=len(futures), desc="Processing files"):
futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
file_path = futures[future]
try:
rows, missing_count = future.result()
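A minimal invocation sketch for this first script (the script filename and database path are hypothetical; the arguments follow the argparse definitions above):

python generate_parquet.py '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json' mapping.db --output output.parquet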

View File

@@ -1,23 +1,21 @@
import logging
import os
import tarfile
import logging
from math import ceil
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from math import ceil
from huggingface_hub import HfApi
from tqdm import tqdm
# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID
# Set up logging to file
logging.basicConfig(
filename='upload.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def process_chunk(args):
"""
@@ -27,7 +25,7 @@ def process_chunk(args):
chunk_index, chunk_files = args
tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
tarball_path = os.path.join(tarball_dir, tarball_name)
try:
with tarfile.open(tarball_path, "w:gz") as tar:
for pdf_filename in chunk_files:
@@ -41,10 +39,11 @@ def process_chunk(args):
logging.error(error_msg)
return chunk_index, False, error_msg
def main():
# List all PDF files (assuming a flat directory)
try:
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
except Exception as e:
logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
return
@@ -61,7 +60,7 @@ def main():
# end = start + chunk_size
# chunk_files = pdf_files[start:end]
# chunks.append((idx, chunk_files))
# # Create tarballs in parallel
# results = []
# with ProcessPoolExecutor() as executor:
@@ -90,10 +89,11 @@ def main():
api.upload_large_folder(
folder_path=tarball_dir,
repo_id=repo_id,
#path_in_repo="pdf_tarballs",
repo_type="dataset"
# path_in_repo="pdf_tarballs",
repo_type="dataset",
)
logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")
if __name__ == "__main__":
main()
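For reference, a minimal sketch of the chunking step that the commented-out block above performs (chunk_size here is an assumed value, not taken from this commit):

from math import ceil

def split_into_chunks(pdf_files, chunk_size=10000):
    # Assumed chunk size; pair each chunk with its index, as expected by process_chunk.
    num_chunks = ceil(len(pdf_files) / chunk_size)
    return [(idx, pdf_files[idx * chunk_size:(idx + 1) * chunk_size]) for idx in range(num_chunks)]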

View File

@@ -1,11 +1,13 @@
#!/usr/bin/env python3
import argparse
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator
from concurrent.futures import ThreadPoolExecutor
from functools import partial
def parse_s3_path(s3_path):
"""
@@ -19,23 +21,25 @@ def parse_s3_path(s3_path):
prefix = parts[1] if len(parts) > 1 else ""
return bucket, prefix
def list_s3_warc_objects(s3_path, suffix='.warc.gz'):
def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
"""
Lists all objects under the given S3 path that end with the provided suffix.
Uses a paginator to handle large result sets.
"""
bucket, prefix = parse_s3_path(s3_path)
s3_client = boto3.client('s3')
paginator = s3_client.get_paginator('list_objects_v2')
s3_client = boto3.client("s3")
paginator = s3_client.get_paginator("list_objects_v2")
warc_keys = []
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
if 'Contents' in page:
for obj in page['Contents']:
key = obj['Key']
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
if key.endswith(suffix):
warc_keys.append(key)
return bucket, warc_keys, s3_client
def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
"""
Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
@@ -43,8 +47,8 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
"""
target_uri = None
try:
response = s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes=0-{head_bytes-1}')
stream = response['Body']
response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
stream = response["Body"]
for record in ArchiveIterator(stream):
for name, value in record.rec_headers.headers:
if name == "WARC-Target-URI":
@@ -56,6 +60,7 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
return target_uri
def create_db(db_path):
"""
Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
@@ -63,18 +68,23 @@ def create_db(db_path):
"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pdf_mapping (
pdf_hash TEXT PRIMARY KEY,
uri TEXT
)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
''')
"""
)
conn.commit()
return conn
def process_warc_file(key, bucket, s3_client):
"""
Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
@@ -83,19 +93,20 @@ def process_warc_file(key, bucket, s3_client):
uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
if uri:
# Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
pdf_hash = key.split('/')[-1].replace('.warc.gz', '.pdf')
pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
return (pdf_hash, uri)
else:
tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
return None
def process_s3_folder(s3_path, db_path):
"""
Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's
basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
"""
bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix='.warc.gz')
bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
conn = create_db(db_path)
cursor = conn.cursor()
@@ -110,21 +121,18 @@ def process_s3_folder(s3_path, db_path):
# Bulk insert into the database.
conn.execute("BEGIN")
for pdf_hash, uri in results:
cursor.execute(
"INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
(pdf_hash, uri)
)
cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
conn.commit()
conn.close()
def main():
parser = argparse.ArgumentParser(
description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files."
)
parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
parser.add_argument("db_file", help="Path for the output SQLite database file")
args = parser.parse_args()
process_s3_folder(args.s3_path, args.db_file)
if __name__ == '__main__':
if __name__ == "__main__":
main()
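For reference, a minimal sketch of reading the mapping back out of the resulting database (the filename and hash value are placeholders; the table and columns match the CREATE TABLE statement above):

import sqlite3

conn = sqlite3.connect("pdf_mapping.db")
row = conn.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("example.pdf",)).fetchone()
print(row[0] if row else None)
conn.close()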