# olmocr/olmocr/data/buildsilverdatasummary.py
#
# 189 lines
# 5.9 KiB
# Python

import argparse
import collections
import csv
import json
import os
import random
import re
import sqlite3
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional
from urllib.parse import urlparse
from tqdm import tqdm
_PRETTY_PDF_PATH_RE = re.compile(r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+")


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """Recover the full PDF hash from an ``s3://ai2-s2-pdfs/...`` path.

    The path must start with ``s3://ai2-s2-pdfs/<4 lowercase hex>/<lowercase hex>.pdf-<digits>``.
    The hash is the 4-character directory prefix followed by the file stem.
    Only the start of the string is matched, so any text after the digits is ignored.

    Returns:
        The concatenated hex hash, or ``None`` when the path does not match.
    """
    found = _PRETTY_PDF_PATH_RE.match(pretty_pdf_path)
    if found is None:
        return None
    prefix, remainder = found.groups()
    return f"{prefix}{remainder}"
def cache_athena_csv_to_db(athena_csv_path: str) -> str:
    """Build (once) an SQLite ``pdf_mapping`` table from the Athena pdf_hash -> uri CSV.

    The database is stored next to the CSV as ``<athena_csv_path>.db``. If that file
    already exists, it is reused and the CSV is not read again.

    Because existence alone marks the cache as done, the table is first built in
    ``<athena_csv_path>.db.tmp`` and moved into place only after every row has been
    inserted and committed. The temporary file is deleted if anything fails.
    An interrupted run therefore never leaves a partial database that a later run
    would mistake for a complete cache.

    Args:
        athena_csv_path: CSV file with ``distinct_pdf_hash`` and ``uri`` columns.

    Returns:
        Path of the SQLite database file.

    Raises:
        KeyError: a CSV row lacks the ``distinct_pdf_hash`` or ``uri`` column.
        sqlite3.IntegrityError: the same ``distinct_pdf_hash`` appears twice
            (it is the table's primary key).
    """
    db_path = athena_csv_path + ".db"
    if os.path.exists(db_path):
        return db_path

    tmp_path = db_path + ".tmp"
    # Discard leftovers from an earlier interrupted build; otherwise CREATE TABLE
    # below would fail because the table already exists.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)

    insert_sql = "INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)"
    conn = sqlite3.connect(tmp_path)
    try:
        cursor = conn.cursor()
        # Durability is not needed: a failed build is thrown away and redone.
        cursor.execute("PRAGMA synchronous = OFF;")
        cursor.execute("PRAGMA journal_mode = MEMORY;")
        cursor.execute(
            """
            CREATE TABLE pdf_mapping (
                pdf_hash TEXT PRIMARY KEY,
                uri TEXT
            )
            """
        )
        # newline="" is required by the csv module so quoted fields containing
        # line breaks are parsed correctly.
        with open(athena_csv_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            batch = []
            for row in tqdm(reader):
                batch.append((row["distinct_pdf_hash"], row["uri"]))
                if len(batch) == 1000:
                    cursor.executemany(insert_sql, batch)
                    conn.commit()
                    batch = []
            if batch:
                cursor.executemany(insert_sql, batch)
            conn.commit()
    except BaseException:
        # Never leave a half-built database behind (covers Ctrl-C as well).
        conn.close()
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    conn.close()
    # Publish the finished database under its final name in a single step.
    os.replace(tmp_path, db_path)
    return db_path
def _connect_readonly(db_path: str) -> sqlite3.Connection:
    """Open the SQLite file at ``db_path`` read-only.

    With ``mode=ro``, a wrong or missing path raises ``sqlite3.OperationalError``.
    The default read-write open would instead silently create an empty database
    file there.
    """
    from pathlib import Path

    # as_uri() percent-encodes the path, so spaces, '?' and '#' are safe in the URI.
    uri = Path(os.path.abspath(db_path)).as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def _lookup_uri(conn: sqlite3.Connection, pdf_hash: str) -> Optional[str]:
    """Return the ``uri`` stored for ``pdf_hash`` in ``pdf_mapping``, or ``None``."""
    row = conn.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,)).fetchone()
    return row[0] if row else None


def _domain_of(uri: Optional[str]) -> Optional[str]:
    """Return the ``netloc`` part of ``uri``; ``None`` if ``uri`` is empty or unparsable."""
    if not uri:
        return None
    try:
        return urlparse(uri).netloc
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. an invalid IPv6 literal).
        return None


def get_uri_from_db(
    db_path: str,
    pdf_hash: str,
    *,
    conn: Optional[sqlite3.Connection] = None,
) -> Optional[str]:
    """Look up the source URI recorded for ``pdf_hash``.

    Args:
        db_path: SQLite file containing the ``pdf_mapping`` table.
        pdf_hash: Hash to look up.
        conn: Optional open connection to ``db_path``. If given, it is used and left
            open, so callers doing many lookups avoid reconnecting each time.
            If omitted, a read-only connection is opened and closed here.

    Returns:
        The stored ``uri``, or ``None`` if the hash is not present (or its uri is NULL).
    """
    if conn is not None:
        return _lookup_uri(conn, pdf_hash)
    own_conn = _connect_readonly(db_path)
    try:
        return _lookup_uri(own_conn, pdf_hash)
    finally:
        own_conn.close()


def process_file(filepath, db_path):
    """Map every record of one JSONL file to its source URI and domain.

    Lines are skipped when they are blank or not valid JSON. They are also skipped
    when they are not JSON objects, lack a string ``custom_id``, or have a
    ``custom_id`` that ``parse_pdf_hash`` does not recognise.

    Args:
        filepath: Path to a JSONL file (main passes files ending in ``.jsonl``).
        db_path: SQLite file containing the ``pdf_mapping`` table.

    Returns:
        A list of ``(custom_id, uri, domain)`` tuples. ``uri`` and ``domain`` are
        ``None`` when no mapping is found.
    """
    results = []
    # One read-only connection for the whole file rather than one per record.
    conn = _connect_readonly(db_path)
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A valid JSON value that is not an object, or a non-string id, would
                # otherwise raise here, and main() would drop every row of this file.
                if not isinstance(data, dict):
                    continue
                custom_id = data.get("custom_id")
                if not custom_id or not isinstance(custom_id, str):
                    continue
                pdf_hash = parse_pdf_hash(custom_id)
                if not pdf_hash:
                    continue
                uri = get_uri_from_db(db_path, pdf_hash, conn=conn)
                results.append((custom_id, uri, _domain_of(uri)))
    finally:
        conn.close()
    return results
def _write_rows_csv(path: str, rows) -> None:
    """Write ``(custom_id, uri, domain)`` tuples to ``path`` as CSV with a header row.

    Falsy ``uri``/``domain`` values (``None`` or empty) are written as empty strings.
    """
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["custom_id", "uri", "domain"])
        for cid, uri, domain in rows:
            writer.writerow([cid, uri if uri else "", domain if domain else ""])


def _collect_rows(filepaths, db_path):
    """Run ``process_file`` on every path in a process pool and gather the rows.

    A file whose processing raises is reported on stderr and contributes no rows.
    Rows are concatenated in the order of ``filepaths``, not completion order, so
    the result does not depend on which worker finishes first.
    """
    import sys

    rows_by_file = {}
    with ProcessPoolExecutor() as executor:
        future_to_file = {executor.submit(process_file, filepath, db_path): filepath for filepath in filepaths}
        for future in tqdm(as_completed(future_to_file), total=len(filepaths)):
            filepath = future_to_file[future]
            try:
                rows_by_file[filepath] = future.result()
            except Exception as e:
                print(f"Error processing file: {filepath}\n{e}", file=sys.stderr)
    return [row for filepath in filepaths for row in rows_by_file.get(filepath, [])]


def main():
    """Summarise a silver dataset by the source URL of each record.

    Reads every ``*.jsonl`` file in ``--input`` and maps each record's ``custom_id``
    to a URI using the ``--athena-csv`` mapping (cached as SQLite). It then writes
    three CSVs to ``--output``:

    * ``custom_id_to_url.csv``: the full mapping;
    * ``top_1000_domains.csv``: the 1000 most frequent domains with counts;
    * ``data_samples.csv``: ``--sample-size`` randomly chosen rows. Pass ``--seed``
      to get the same sample on every run.
    """
    parser = argparse.ArgumentParser(
        description="Review silver dataset and provide summary statistics based on source URL and also provide a few data samples for review."
    )
    parser.add_argument(
        "--input",
        type=str,
        default="openai_batch_data",
        help="Input folder, which is the output of the buildsilver.py script",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="openai_batch_data_summary",
        help="Output destination (folder)",
    )
    parser.add_argument(
        "--athena-csv",
        type=str,
        default="/home/ubuntu/s2pdf_url_data/c974870d-3b06-4793-9a62-d46d38e2c8b2.csv",
        help="CSV file that maps pdf_hash to uri",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=20,
        help="How many sample rows to include in the sample CSV",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed used to pick the sample rows (default: unseeded, different each run)",
    )
    args = parser.parse_args()
    if args.sample_size < 0:
        # random.sample would otherwise fail with a less helpful ValueError at the very end.
        parser.error("--sample-size must be >= 0")

    db_path = cache_athena_csv_to_db(args.athena_csv)

    # Sorted so that row order, and therefore a seeded sample, does not depend on
    # directory-listing order.
    filepaths = sorted(
        os.path.join(args.input, filename) for filename in os.listdir(args.input) if filename.endswith(".jsonl")
    )
    all_rows = _collect_rows(filepaths, db_path)

    os.makedirs(args.output, exist_ok=True)

    output_csv_path = os.path.join(args.output, "custom_id_to_url.csv")
    _write_rows_csv(output_csv_path, all_rows)

    domain_counter: Counter[str] = Counter(domain for _, _, domain in all_rows if domain)
    most_common_domains = domain_counter.most_common(1000)
    domain_csv_path = os.path.join(args.output, "top_1000_domains.csv")
    with open(domain_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["domain", "count"])
        writer.writerows(most_common_domains)

    # random.Random(None) seeds from OS entropy, matching the old unseeded behaviour.
    rng = random.Random(args.seed)
    sample_rows = rng.sample(all_rows, min(args.sample_size, len(all_rows)))
    sample_csv_path = os.path.join(args.output, "data_samples.csv")
    _write_rows_csv(sample_csv_path, sample_rows)

    print(f"Summary files written to: {args.output}")
    print(f" - Full mapping: {output_csv_path}")
    print(f" - Top domains: {domain_csv_path}")
    print(f" - Samples: {sample_csv_path}")
# Run main() only when executed as a script, not when imported.
# main() uses ProcessPoolExecutor. Under the "spawn"/"forkserver" start methods,
# worker processes import this module, and this guard stops them from re-running main().
if __name__ == "__main__":
    main()