Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-01 11:21:41 +00:00)

Formatting fix
commit 32aa359458 (parent 0dcdbcc61a)
@@ -8,23 +8,24 @@
# The url will be the result of get_uri_from_db
# Response will be NormalizedEntry.text
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import shutil
import sqlite3
import tempfile
import os
import shutil
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import boto3
from tqdm import tqdm
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
@@ -44,7 +45,7 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
        return urlparse(pretty_pdf_path).path.split("/")[-1]
    else:
        raise NotImplementedError()
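
For reference, the urlparse branch above just keeps the last path component; a minimal sketch with a made-up path:

from urllib.parse import urlparse

pretty_pdf_path = "s3://example-bucket/pdfs/0a1b2c3d4e.pdf"  # hypothetical input
print(urlparse(pretty_pdf_path).path.split("/")[-1])  # -> "0a1b2c3d4e.pdf"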



def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
@@ -58,6 +59,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    conn.close()
    return result[0].strip() if result and result[0] else None
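
The SELECT itself is outside this hunk; a hedged sketch of the likely lookup, assuming the pdf_mapping table created by the WARC-indexing script further down (the query text is an assumption):

import sqlite3

def get_uri_sketch(db_path: str, pdf_hash: str):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    # Assumed query; only the close/return tail is shown in the diff.
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0].strip() if result and result[0] else None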


@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
@@ -70,7 +72,7 @@ class NormalizedEntry:
    def from_goldkey(goldkey: str, **kwargs):
        """
        Constructs a NormalizedEntry from a goldkey string.
        The goldkey is expected to be of the format:
        The goldkey is expected to be of the format:
        <s3_path>-<page_number>
        """
        s3_path = goldkey[: goldkey.rindex("-")]
@@ -81,6 +83,7 @@ class NormalizedEntry:
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
@@ -117,6 +120,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
    else:
        raise ValueError("Unsupported JSON format")


def parse_s3_url(s3_url: str) -> Tuple[str, str]:
    """
    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
@@ -127,6 +131,7 @@
    bucket, key = s3_path.split("/", 1)
    return bucket, key
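
Behaviour of the single-maxsplit split above on a hypothetical URL; stripping the s3:// scheme happens outside the hunk and is assumed here:

s3_url = "s3://my-bucket/some/prefix/doc.pdf"  # hypothetical
s3_path = s3_url[len("s3://"):]                # assumed scheme strip
bucket, key = s3_path.split("/", 1)
print(bucket, key)                             # -> my-bucket some/prefix/doc.pdf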


def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
    """
    Downloads the PDF from the given S3 URL into the specified cache directory.
@@ -135,11 +140,11 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
    """
    try:
        bucket, key = parse_s3_url(s3_url)
        s3_client = boto3.client('s3')
        s3_client = boto3.client("s3")
        pdf_hash = parse_pdf_hash(s3_url)
        if not pdf_hash:
            # Fallback: use a sanitized version of the s3_url
            pdf_hash = re.sub(r'\W+', '_', s3_url)
            pdf_hash = re.sub(r"\W+", "_", s3_url)
        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
        # Avoid re-downloading if already exists
        if not os.path.exists(dest_path):
@@ -149,6 +154,7 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
        print(f"Error downloading {s3_url}: {e}")
        return None
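
A sketch of the cache-path logic around the reformatted fallback; the actual transfer call sits outside the hunk, so the boto3 download_file line is an assumption:

import os
import re

import boto3

s3_url = "s3://bucket/path/doc-01.pdf"  # hypothetical
cache_dir = "/tmp/pdf_cache"
pdf_hash = re.sub(r"\W+", "_", s3_url)  # fallback name: "s3_bucket_path_doc_01_pdf"
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
if not os.path.exists(dest_path):
    # Assumed download step; only the surrounding logic appears in the diff.
    boto3.client("s3").download_file("bucket", "path/doc-01.pdf", dest_path)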


def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
    """
    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
@@ -178,6 +184,7 @@ def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf
        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
        return None
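
The extraction body is outside the hunk; a minimal pypdf sketch of pulling one 1-indexed page, consistent with the PdfReader/PdfWriter imports above (paths are hypothetical):

from pypdf import PdfReader, PdfWriter

def extract_page_sketch(cached_pdf_path: str, page_number: int, out_path: str) -> str:
    reader = PdfReader(cached_pdf_path)
    writer = PdfWriter()
    writer.add_page(reader.pages[page_number - 1])  # 1-indexed page -> 0-indexed list
    with open(out_path, "wb") as f:
        writer.write(f)
    return out_path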


def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
    """
    Process a single file and return a tuple:
@@ -215,8 +222,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D

            # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
            response_text = normalized.text if normalized.text else ""
            if (re.search(r"resume", response_text, re.IGNORECASE) and
                    (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
            if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
                print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                continue

@@ -254,6 +260,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
        print(f"Error processing file {file_path}: {e}")
    return rows, missing_count
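
The filter condition reformatted in this hunk, exercised on a made-up string; email_regex and phone_regex are defined elsewhere in the file, so the patterns below are stand-ins:

import re

email_regex = r"[\w.+-]+@[\w-]+\.[\w.-]+"            # stand-in pattern
phone_regex = r"\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b"   # stand-in pattern

response_text = "Resume of Jane Doe, contact jane@example.com"  # hypothetical
skip = re.search(r"resume", response_text, re.IGNORECASE) and (
    re.search(email_regex, response_text) or re.search(phone_regex, response_text)
)
print(bool(skip))  # True, so this entry would be skipped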


def scan_file_for_s3_urls(file_path: str) -> Set[str]:
    """
    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
@@ -276,18 +283,15 @@ def scan_file_for_s3_urls(file_path: str) -> Set[str]:
        print(f"Error reading file {file_path}: {e}")
    return urls


def main():
    parser = argparse.ArgumentParser(
        description="Generate a Parquet dataset file for HuggingFace upload."
    )
    parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
    parser.add_argument(
        "input_dataset",
        help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
    )
    parser.add_argument("db_path", help="Path to the SQLite database file.")
    parser.add_argument(
        "--output", default="output.parquet", help="Output Parquet file path."
    )
    parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")

    args = parser.parse_args()

@@ -303,7 +307,7 @@ def main():
    # Create a temporary directory for caching PDFs.
    pdf_cache_dir = "/tmp/pdf_cache"
    os.makedirs(pdf_cache_dir, exist_ok=True)


    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

    # ---------------------------------------------------------------------
@@ -323,12 +327,8 @@ def main():
    pdf_cache: Dict[str, str] = {}
    print("Caching PDFs from S3...")
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
        future_to_url = {
            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
            for s3_url in unique_s3_urls
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_url),
                           total=len(future_to_url), desc="Downloading PDFs"):
        future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
            s3_url = future_to_url[future]
            try:
                local_path = future.result()
@@ -345,12 +345,8 @@ def main():
    total_missing = 0
    print("Processing files...")
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
            for file_path in files
        }
        for future in tqdm(concurrent.futures.as_completed(futures),
                           total=len(futures), desc="Processing files"):
        futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
            file_path = futures[future]
            try:
                rows, missing_count = future.result()

@@ -1,23 +1,21 @@
import logging
import os
import tarfile
import logging
from math import ceil
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from math import ceil

from huggingface_hub import HfApi
from tqdm import tqdm

# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
pdf_dir = "pdfs"  # Directory with PDF files (flat structure)
tarball_dir = "tarballs"  # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID

# Set up logging to file
logging.basicConfig(
    filename='upload.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def process_chunk(args):
    """
@@ -27,7 +25,7 @@ def process_chunk(args):
    chunk_index, chunk_files = args
    tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
    tarball_path = os.path.join(tarball_dir, tarball_name)


    try:
        with tarfile.open(tarball_path, "w:gz") as tar:
            for pdf_filename in chunk_files:
@@ -41,10 +39,11 @@ def process_chunk(args):
        logging.error(error_msg)
        return chunk_index, False, error_msg
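
A self-contained sketch of the chunk-to-tarball step; the tar.add loop body falls outside the hunk, so the arcname choice here is an assumption:

import os
import tarfile

pdf_dir = "pdfs"
tarball_dir = "tarballs"
chunk_index, chunk_files = 0, ["a.pdf", "b.pdf"]  # hypothetical chunk
tarball_path = os.path.join(tarball_dir, f"pdf_chunk_{chunk_index:04d}.tar.gz")
with tarfile.open(tarball_path, "w:gz") as tar:
    for pdf_filename in chunk_files:
        # arcname keeps the archive flat rather than embedding the pdfs/ prefix
        tar.add(os.path.join(pdf_dir, pdf_filename), arcname=pdf_filename)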


def main():
    # List all PDF files (assuming a flat directory)
    try:
        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
        pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
    except Exception as e:
        logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
        return
@@ -61,7 +60,7 @@ def main():
    # end = start + chunk_size
    # chunk_files = pdf_files[start:end]
    # chunks.append((idx, chunk_files))


    # # Create tarballs in parallel
    # results = []
    # with ProcessPoolExecutor() as executor:
@@ -90,10 +89,11 @@ def main():
    api.upload_large_folder(
        folder_path=tarball_dir,
        repo_id=repo_id,
        #path_in_repo="pdf_tarballs",
        repo_type="dataset"
        # path_in_repo="pdf_tarballs",
        repo_type="dataset",
    )
    logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")


if __name__ == "__main__":
    main()

@@ -1,11 +1,13 @@
#!/usr/bin/env python3
import argparse
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def parse_s3_path(s3_path):
    """
@@ -19,23 +21,25 @@ def parse_s3_path(s3_path):
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix

def list_s3_warc_objects(s3_path, suffix='.warc.gz'):

def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
    """
    Lists all objects under the given S3 path that end with the provided suffix.
    Uses a paginator to handle large result sets.
    """
    bucket, prefix = parse_s3_path(s3_path)
    s3_client = boto3.client('s3')
    paginator = s3_client.get_paginator('list_objects_v2')
    s3_client = boto3.client("s3")
    paginator = s3_client.get_paginator("list_objects_v2")
    warc_keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
        if "Contents" in page:
            for obj in page["Contents"]:
                key = obj["Key"]
                if key.endswith(suffix):
                    warc_keys.append(key)
    return bucket, warc_keys, s3_client


def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
    """
    Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
@@ -43,8 +47,8 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
    """
    target_uri = None
    try:
        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes=0-{head_bytes-1}')
        stream = response['Body']
        response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
        stream = response["Body"]
        for record in ArchiveIterator(stream):
            for name, value in record.rec_headers.headers:
                if name == "WARC-Target-URI":
@@ -56,6 +60,7 @@ def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
        tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
    return target_uri


def create_db(db_path):
    """
    Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
@@ -63,18 +68,23 @@ def create_db(db_path):
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute('''
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS pdf_mapping (
            pdf_hash TEXT PRIMARY KEY,
            uri TEXT
        )
    ''')
    cursor.execute('''
        """
    )
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
    ''')
        """
    )
    conn.commit()
    return conn
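
With the pdf_mapping schema above, a small round trip shows how the bulk insert below and the get_uri_from_db lookup in the Parquet script fit together (the SELECT is an assumption; the path and values are made up):

conn = create_db("pdf_mapping.db")  # hypothetical path
cursor = conn.cursor()
cursor.execute(
    "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
    ("0abc123.pdf", "https://example.com/report.pdf"),
)
conn.commit()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("0abc123.pdf",))
print(cursor.fetchone()[0])  # -> https://example.com/report.pdf
conn.close()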


def process_warc_file(key, bucket, s3_client):
    """
    Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
@@ -83,19 +93,20 @@ def process_warc_file(key, bucket, s3_client):
    uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
    if uri:
        # Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
        pdf_hash = key.split('/')[-1].replace('.warc.gz', '.pdf')
        pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
        return (pdf_hash, uri)
    else:
        tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
        return None
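
The basename substitution above, on a hypothetical key:

key = "crawl-data/batch_01/0abc123.warc.gz"  # hypothetical
pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
print(pdf_hash)  # -> "0abc123.pdf", the key later looked up by get_uri_from_db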


def process_s3_folder(s3_path, db_path):
    """
    Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
    to extract the target URI from the HTTP headers. The resulting mapping (derived from the file's
    basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
    """
    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix='.warc.gz')
    bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
    conn = create_db(db_path)
    cursor = conn.cursor()

@@ -110,21 +121,18 @@ def process_s3_folder(s3_path, db_path):
    # Bulk insert into the database.
    conn.execute("BEGIN")
    for pdf_hash, uri in results:
        cursor.execute(
            "INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)",
            (pdf_hash, uri)
        )
        cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
    conn.commit()
    conn.close()


def main():
    parser = argparse.ArgumentParser(
        description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files."
    )
    parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
    parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
    parser.add_argument("db_file", help="Path for the output SQLite database file")
    args = parser.parse_args()
    process_s3_folder(args.s3_path, args.db_file)

if __name__ == '__main__':

if __name__ == "__main__":
    main()