Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-25 16:30:28 +00:00)
Commit f477a68621 (parent 2dccc4be3b): dbmanager
@@ -7,6 +7,54 @@ import argparse
 from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor, as_completed

+class DatabaseManager:
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self._initialize_tables()
+
+    def _initialize_tables(self):
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS index_table (
+                custom_id TEXT,
+                s3_path TEXT,
+                start_index BIGINT,
+                end_index BIGINT
+            )
+        """)
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS processed_files (
+                s3_path TEXT PRIMARY KEY,
+                etag TEXT
+            )
+        """)
+        self.conn.commit()
+
+    def is_file_processed(self, s3_path, etag):
+        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
+        result = self.cursor.fetchone()
+        return result is not None and result[0] == etag
+
+    def add_index_entries(self, index_entries):
+        if index_entries:
+            self.cursor.executemany("""
+                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                VALUES (?, ?, ?, ?)
+            """, index_entries)
+            self.conn.commit()
+
+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
+    def close(self):
+        self.conn.close()
+
 def build_index(s3_path):
     # Hash the s3_path to get a cache key
     cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
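For reference, a minimal sketch of how the new class behaves on its own, assuming the DatabaseManager defined above is in scope; the ':memory:' path and the sample rows are illustrative, not part of the commit:

# Illustrative only: exercises DatabaseManager against an in-memory database.
import sqlite3  # required by DatabaseManager; the diffed module already imports it

db = DatabaseManager(':memory:')

# An unknown file is reported as unprocessed.
assert not db.is_file_processed('s3://bucket/a.jsonl', 'etag-1')

# Record index entries and mark the file as processed.
db.add_index_entries([('doc-0', 's3://bucket/a.jsonl', 0, 1024)])
db.update_processed_file('s3://bucket/a.jsonl', 'etag-1')

# Same path with the same ETag is now skipped; a changed ETag is not.
assert db.is_file_processed('s3://bucket/a.jsonl', 'etag-1')
assert not db.is_file_processed('s3://bucket/a.jsonl', 'etag-2')

db.close()

The ON CONFLICT(s3_path) DO UPDATE upsert in update_processed_file (available since SQLite 3.24) is what lets a re-run refresh the stored ETag without accumulating duplicate rows.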
@@ -14,25 +62,9 @@ def build_index(s3_path):
     os.makedirs(home_cache_dir, exist_ok=True)
     db_path = os.path.join(home_cache_dir, 'index.db')

-    # Connect to sqlite and create tables if not exist
+    # Initialize the database manager
     print("Building page index at", db_path)
-    conn = sqlite3.connect(db_path)
-    cursor = conn.cursor()
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS index_table (
-            custom_id TEXT,
-            s3_path TEXT,
-            start_index BIGINT,
-            end_index BIGINT
-        )
-    """)
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS processed_files (
-            s3_path TEXT PRIMARY KEY,
-            etag TEXT
-        )
-    """)
-    conn.commit()
+    db_manager = DatabaseManager(db_path)

     s3 = boto3.client('s3')
     bucket, prefix = parse_s3_path(s3_path)
@@ -42,42 +74,34 @@ def build_index(s3_path):

     if not files:
         print("No .json or .jsonl files found in the specified S3 path.")
+        db_manager.close()
         return

-    # Prepare a list of files that need processing
-    files_to_process = []
-    for key, etag in files.items():
-        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
-        db_result = cursor.fetchone()
-        if db_result and db_result[0] == etag:
-            # File has already been processed with the same ETag
-            pass # Skip
-        else:
-            files_to_process.append((key, etag))
+    files_to_process = [
+        (key, etag) for key, etag in files.items()
+        if not db_manager.is_file_processed(key, etag)
+    ]

     if not files_to_process:
         print("All files are up to date. No processing needed.")
+        db_manager.close()
         return

     # Use ProcessPoolExecutor to process files with tqdm progress bar
     with ProcessPoolExecutor() as executor:
-        futures = [executor.submit(process_file, bucket, key, etag) for key, etag in files_to_process]
+        futures = [
+            executor.submit(process_file, bucket, key, etag)
+            for key, etag in files_to_process
+        ]
         for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
             s3_path, key, etag, index_entries = future.result()
             if index_entries:
-                cursor.executemany("""
-                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
-                    VALUES (?, ?, ?, ?)
-                """, index_entries)
+                db_manager.add_index_entries(index_entries)
             # Update the processed_files table
-            cursor.execute("""
-                INSERT INTO processed_files (s3_path, etag)
-                VALUES (?, ?)
-                ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-            """, (key, etag))
-            conn.commit()
+            db_manager.update_processed_file(key, etag)

-    conn.close()
+    db_manager.close()

 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
@@ -139,5 +163,3 @@ if __name__ == '__main__':

     # Step one, build an index of all the pages that were processed
     build_index(args.s3_path)
-
-
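The rewritten loop keeps the existing submit/as_completed pattern around the new manager. A reduced, self-contained sketch of that pattern, with a stub process_file standing in for the real S3 worker (all names and values here are illustrative):

from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

def process_file(bucket, key, etag):
    # Stub worker: the real one fetches the object from S3 and builds
    # (custom_id, s3_path, start_index, end_index) entries.
    s3_path = f's3://{bucket}/{key}'
    return s3_path, key, etag, [('doc-0', s3_path, 0, 100)]

if __name__ == '__main__':
    files_to_process = [('a.jsonl', 'etag-1'), ('b.jsonl', 'etag-2')]
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(process_file, 'bucket', key, etag)
            for key, etag in files_to_process
        ]
        # as_completed yields futures as they finish, so the progress bar
        # advances per completed file rather than in submission order.
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            s3_path, key, etag, index_entries = future.result()
            print(key, len(index_entries))

Because worker results are consumed in the parent process, all database writes stay on the single sqlite3 connection owned by DatabaseManager, avoiding cross-process access to the SQLite file.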