mirror of https://github.com/allenai/olmocr.git, synced 2025-12-04 11:11:08 +00:00
Some small updates
This commit is contained in:
parent 6586744718
commit 0311b445fd
@@ -6,16 +6,11 @@ import re
 import collections
 import random
 import sqlite3
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse
 from tqdm import tqdm
 
 def parse_pdf_hash(pretty_pdf_path: str) -> str:
-    """
-    Given a string like "s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32",
-    extract the hash ("43426a12ffc2ffa73f5258eb66095659beae9522").
-    Returns None if not found.
-    """
     pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
     match = re.match(pattern, pretty_pdf_path)
     if match:
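Note: the docstring removed above documented the intended behavior: given "s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32", return "43426a12ffc2ffa73f5258eb66095659beae9522". A minimal standalone sketch of that extraction using the same pattern; rejoining the two capture groups is an assumption here, since the body after "if match:" falls outside this hunk.

import re

S3_PDF_PATTERN = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"

def extract_pdf_hash(path):
    # The S3 key shards the hash as <first 4 hex chars>/<remaining hex chars>.pdf-<page>.
    match = re.match(S3_PDF_PATTERN, path)
    if match:
        return match.group(1) + match.group(2)  # assumed: rejoin the shard prefix and the remainder
    return None

# extract_pdf_hash("s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32")
# -> "43426a12ffc2ffa73f5258eb66095659beae9522"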
@@ -23,29 +18,24 @@ def parse_pdf_hash(pretty_pdf_path: str) -> str:
     return None
 
 def cache_athena_csv_to_db(athena_csv_path: str) -> str:
-    """
-    Cache the Athena CSV file into an SQLite database.
-    Returns the path to the SQLite database.
-    """
     db_path = athena_csv_path + ".db"
 
     if not os.path.exists(db_path):
         conn = sqlite3.connect(db_path)
         cursor = conn.cursor()
 
         cursor.execute("PRAGMA synchronous = OFF;")
         cursor.execute("PRAGMA journal_mode = MEMORY;")
 
 
-        # Create the table
-        cursor.execute("""
+        cursor.execute(
+            """
             CREATE TABLE pdf_mapping (
                 pdf_hash TEXT PRIMARY KEY,
                 uri TEXT
             )
-        """)
+        """
+        )
 
-        # Insert data from CSV in batches of 1000 rows
         with open(athena_csv_path, "r", encoding="utf-8") as f:
             reader = csv.DictReader(f)
             batch = []
@@ -56,7 +46,6 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
                     conn.commit()
                     batch = []
 
-            # Insert remaining rows
             if batch:
                 cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                 conn.commit()
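Note: the PRAGMA settings above (synchronous = OFF, journal_mode = MEMORY) trade crash safety for import speed, which is reasonable for a cache that can be rebuilt from the CSV at any time. A consolidated sketch of the caching routine as it reads after this commit; the CSV column names and the append/flush step that sits between the hunks are assumptions.

import csv
import os
import sqlite3

def build_pdf_mapping_cache(csv_path, batch_size=1000):
    # Sketch only: mirrors cache_athena_csv_to_db, with assumed CSV headers "pdf_hash" and "uri".
    db_path = csv_path + ".db"
    if os.path.exists(db_path):
        return db_path
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("PRAGMA synchronous = OFF;")      # skip fsync per transaction
    cursor.execute("PRAGMA journal_mode = MEMORY;")  # keep the rollback journal in RAM
    cursor.execute("CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")
    batch = []
    with open(csv_path, "r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            batch.append((row["pdf_hash"], row["uri"]))  # assumed column names
            if len(batch) >= batch_size:
                cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                conn.commit()
                batch = []
    if batch:
        cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
        conn.commit()
    conn.close()
    return db_path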
@@ -66,9 +55,6 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
     return db_path
 
 def get_uri_from_db(db_path: str, pdf_hash: str) -> str:
-    """
-    Query the SQLite database to retrieve the URI for a given PDF hash.
-    """
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
     cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
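Note: get_uri_from_db opens and closes its own sqlite3 connection per lookup, so it stays safe to call from the worker processes introduced below (connections are not shared across processes). A short usage sketch; the CSV path is a placeholder.

db_path = cache_athena_csv_to_db("athena_export.csv")  # placeholder path
uri = get_uri_from_db(db_path, "43426a12ffc2ffa73f5258eb66095659beae9522")
print(uri)  # source URI recorded for that PDF hash, or None if it is not in the mapping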
@@ -76,6 +62,37 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> str:
     conn.close()
     return result[0] if result else None
 
+def process_file(filepath, db_path):
+    results = []
+    with open(filepath, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+
+            try:
+                data = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+
+            custom_id = data.get("custom_id")
+            if not custom_id:
+                continue
+
+            pdf_hash = parse_pdf_hash(custom_id)
+            if not pdf_hash:
+                continue
+
+            uri = get_uri_from_db(db_path, pdf_hash)
+
+            domain = None
+            if uri:
+                parsed = urlparse(uri)
+                domain = parsed.netloc
+
+            results.append((custom_id, uri, domain))
+    return results
+
 def main():
     parser = argparse.ArgumentParser(
         description="Review silver dataset and provide summary statistics based on source URL and also provide a few data samples for review."
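Note: process_file is self-contained (each lookup goes through get_uri_from_db against the on-disk cache), so it can also be exercised on a single JSONL shard outside the process pool. A quick sketch with placeholder paths.

db_path = cache_athena_csv_to_db("athena_export.csv")  # placeholder path
rows = process_file("sample_shard.jsonl", db_path)      # placeholder shard
for custom_id, uri, domain in rows[:5]:
    print(custom_id, domain)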
@@ -107,45 +124,21 @@ def main():
 
     args = parser.parse_args()
 
-    # Cache the Athena CSV into SQLite database
     db_path = cache_athena_csv_to_db(args.athena_csv)
 
-    # Process input JSONL files
     all_rows = []
+    filepaths = [os.path.join(args.input, filename) for filename in os.listdir(args.input) if filename.endswith(".jsonl")]
 
-    for filename in tqdm(os.listdir(args.input)):
-        if filename.endswith(".jsonl"):
-            filepath = os.path.join(args.input, filename)
-            with open(filepath, "r", encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    if not line:
-                        continue
-
-                    try:
-                        data = json.loads(line)
-                    except json.JSONDecodeError:
-                        print("Error parsing line")
-                        continue
-
-                    custom_id = data.get("custom_id")
-                    if not custom_id:
-                        print("No custom_id found")
-                        continue
-
-                    pdf_hash = parse_pdf_hash(custom_id)
-                    assert pdf_hash, f"Need to have a pdf_hash {custom_id}"
-
-                    uri = get_uri_from_db(db_path, pdf_hash)
-
-                    domain = None
-                    if uri:
-                        parsed = urlparse(uri)
-                        domain = parsed.netloc
-
-                    all_rows.append((custom_id, uri, domain))
-
-    # Write output CSVs
+    with ProcessPoolExecutor() as executor:
+        future_to_file = {executor.submit(process_file, filepath, db_path): filepath for filepath in filepaths}
+
+        for future in tqdm(as_completed(future_to_file), total=len(filepaths)):
+            try:
+                results = future.result()
+                all_rows.extend(results)
+            except Exception as e:
+                print(f"Error processing file: {future_to_file[future]}\n{e}")
+
     os.makedirs(args.output, exist_ok=True)
 
     output_csv_path = os.path.join(args.output, "custom_id_to_url.csv")
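Note: the submit/as_completed idiom adopted above yields each future as soon as its worker finishes, so the tqdm bar tracks completed files rather than submission order, and a failure in one shard only skips that shard. A minimal standalone sketch of the same pattern.

from concurrent.futures import ProcessPoolExecutor, as_completed

def work(n):
    return n * n

if __name__ == "__main__":
    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(work, n): n for n in range(8)}
        for future in as_completed(futures):
            try:
                print(futures[future], "->", future.result())  # completion order, not submission order
            except Exception as exc:
                print(f"worker for input {futures[future]} failed: {exc}")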