mirror of https://github.com/allenai/olmocr.git
synced 2025-10-31 10:04:26 +00:00

188 lines · 5.8 KiB · Python

"""Review a silver dataset: map each batch item's custom_id back to its source
URI via an Athena CSV export, then write summary statistics (the full mapping,
the top domains, and a small random sample) to CSV files."""

import argparse
import csv
import json
import os
import random
import re
import sqlite3
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional
from urllib.parse import urlparse

from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """Recover the pdf hash embedded in an s3://ai2-s2-pdfs/... batch path."""
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None
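
# Illustration (hash value made up): the two capture groups are concatenated, so
# parse_pdf_hash("s3://ai2-s2-pdfs/0a1b/2c3d4e5f.pdf-1") returns "0a1b2c3d4e5f",
# and any path that does not match the pattern returns None.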

def cache_athena_csv_to_db(athena_csv_path: str) -> str:
    """Build (once) and reuse a SQLite cache of the Athena pdf_hash -> uri CSV."""
    db_path = athena_csv_path + ".db"

    # Only build the cache if it does not exist yet; later runs reuse the .db file.
    if not os.path.exists(db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Favor bulk-load speed over durability while populating the cache.
        cursor.execute("PRAGMA synchronous = OFF;")
        cursor.execute("PRAGMA journal_mode = MEMORY;")

        cursor.execute(
            """
            CREATE TABLE pdf_mapping (
                pdf_hash TEXT PRIMARY KEY,
                uri TEXT
            )
            """
        )

        # Insert rows in batches of 1000 to keep commits cheap.
        with open(athena_csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            batch = []
            for row in tqdm(reader):
                batch.append((row["distinct_pdf_hash"], row["uri"]))
                if len(batch) == 1000:
                    cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                    conn.commit()
                    batch = []

            if batch:
                cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                conn.commit()

        conn.close()

    return db_path
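
# The Athena export is expected to provide at least these two columns
# (rows below are illustrative, not real data):
#   distinct_pdf_hash,uri
#   0a1b2c3d4e5f,https://example.com/paper.pdf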

def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """Look up the source uri for a pdf hash in the SQLite cache."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None
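
# A fresh connection is opened per lookup: slower than a shared connection, but
# it means only the db_path string ever crosses process boundaries, so this
# function is safe to call from ProcessPoolExecutor workers.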

def process_file(filepath, db_path):
    """Map each JSONL record in one batch file to (custom_id, uri, domain)."""
    results = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Skip lines that are not valid JSON.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue

            custom_id = data.get("custom_id")
            if not custom_id:
                continue

            pdf_hash = parse_pdf_hash(custom_id)
            if not pdf_hash:
                continue

            uri = get_uri_from_db(db_path, pdf_hash)

            # Derive the source domain from the uri when one was found.
            domain = None
            if uri:
                parsed = urlparse(uri)
                domain = parsed.netloc

            results.append((custom_id, uri, domain))
    return results
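
# Each input .jsonl line should be a JSON object carrying at least a
# "custom_id" field, e.g. (illustrative record):
#   {"custom_id": "s3://ai2-s2-pdfs/0a1b/2c3d4e5f.pdf-1", ...}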

def main():
    parser = argparse.ArgumentParser(
        description="Review the silver dataset: provide summary statistics based on source URL, plus a few data samples for review."
    )
    parser.add_argument(
        "--input",
        type=str,
        default="openai_batch_data",
        help="Input folder, which is the output of the buildsilver.py script",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="openai_batch_data_summary",
        help="Output destination (folder)",
    )
    parser.add_argument(
        "--athena-csv",
        type=str,
        default="/home/ubuntu/s2pdf_url_data/c974870d-3b06-4793-9a62-d46d38e2c8b2.csv",
        help="CSV file that maps pdf_hash to uri",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=20,
        help="How many sample rows to include in the sample CSV",
    )

    args = parser.parse_args()

    db_path = cache_athena_csv_to_db(args.athena_csv)

    # Fan the per-file work out across processes; each worker returns its rows.
    all_rows = []
    filepaths = [os.path.join(args.input, filename) for filename in os.listdir(args.input) if filename.endswith(".jsonl")]

    with ProcessPoolExecutor() as executor:
        future_to_file = {executor.submit(process_file, filepath, db_path): filepath for filepath in filepaths}

        for future in tqdm(as_completed(future_to_file), total=len(filepaths)):
            try:
                results = future.result()
                all_rows.extend(results)
            except Exception as e:
                print(f"Error processing file: {future_to_file[future]}\n{e}")
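
    # Everything below runs in the parent process: write the full mapping, the
    # domain frequency table, and a random sample under args.output.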
    os.makedirs(args.output, exist_ok=True)

    # Full custom_id -> uri -> domain mapping.
    output_csv_path = os.path.join(args.output, "custom_id_to_url.csv")
    with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["custom_id", "uri", "domain"])
        for cid, uri, domain in all_rows:
            writer.writerow([cid, uri if uri else "", domain if domain else ""])

    # Frequency count of source domains.
    domain_counter: Counter[str] = Counter()
    for _, _, domain in all_rows:
        if domain:
            domain_counter[domain] += 1

    most_common_domains = domain_counter.most_common(1000)
    domain_csv_path = os.path.join(args.output, "top_1000_domains.csv")
    with open(domain_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["domain", "count"])
        for domain, count in most_common_domains:
            writer.writerow([domain, count])

    # Random sample of rows for manual review.
    sample_size = min(args.sample_size, len(all_rows))
    sample_rows = random.sample(all_rows, sample_size) if all_rows else []
    sample_csv_path = os.path.join(args.output, "data_samples.csv")
    with open(sample_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["custom_id", "uri", "domain"])
        for cid, uri, domain in sample_rows:
            writer.writerow([cid, uri if uri else "", domain if domain else ""])

    print(f"Summary files written to: {args.output}")
    print(f" - Full mapping: {output_csv_path}")
    print(f" - Top domains: {domain_csv_path}")
    print(f" - Samples: {sample_csv_path}")


if __name__ == "__main__":
    main()
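
# Typical invocation (the script filename below is a placeholder; folder names
# follow the argparse defaults above):
#   python review_silver.py --input openai_batch_data --output openai_batch_data_summary \
#       --athena-csv /path/to/athena_export.csv --sample-size 20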