Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-26 17:04:02 +00:00)
commit f477a68621
parent 2dccc4be3b

    dbmanager
@@ -7,6 +7,54 @@ import argparse
 from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
+class DatabaseManager:
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self._initialize_tables()
+
+    def _initialize_tables(self):
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS index_table (
+                custom_id TEXT,
+                s3_path TEXT,
+                start_index BIGINT,
+                end_index BIGINT
+            )
+        """)
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS processed_files (
+                s3_path TEXT PRIMARY KEY,
+                etag TEXT
+            )
+        """)
+        self.conn.commit()
+
+    def is_file_processed(self, s3_path, etag):
+        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
+        result = self.cursor.fetchone()
+        return result is not None and result[0] == etag
+
+    def add_index_entries(self, index_entries):
+        if index_entries:
+            self.cursor.executemany("""
+                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
+                VALUES (?, ?, ?, ?)
+            """, index_entries)
+            self.conn.commit()
+
+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
+    def close(self):
+        self.conn.close()
+
 def build_index(s3_path):
     # Hash the s3_path to get a cache key
     cache_key = hashlib.sha256(s3_path.encode('utf-8')).hexdigest()
@@ -14,25 +62,9 @@ def build_index(s3_path):
     os.makedirs(home_cache_dir, exist_ok=True)
     db_path = os.path.join(home_cache_dir, 'index.db')
 
-    # Connect to sqlite and create tables if not exist
+    # Initialize the database manager
     print("Building page index at", db_path)
-    conn = sqlite3.connect(db_path)
-    cursor = conn.cursor()
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS index_table (
-            custom_id TEXT,
-            s3_path TEXT,
-            start_index BIGINT,
-            end_index BIGINT
-        )
-    """)
-    cursor.execute("""
-        CREATE TABLE IF NOT EXISTS processed_files (
-            s3_path TEXT PRIMARY KEY,
-            etag TEXT
-        )
-    """)
-    conn.commit()
+    db_manager = DatabaseManager(db_path)
 
     s3 = boto3.client('s3')
     bucket, prefix = parse_s3_path(s3_path)
@@ -42,42 +74,34 @@ def build_index(s3_path):
 
     if not files:
         print("No .json or .jsonl files found in the specified S3 path.")
+        db_manager.close()
         return
 
     # Prepare a list of files that need processing
-    files_to_process = []
-    for key, etag in files.items():
-        cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (key,))
-        db_result = cursor.fetchone()
-        if db_result and db_result[0] == etag:
-            # File has already been processed with the same ETag
-            pass # Skip
-        else:
-            files_to_process.append((key, etag))
+    files_to_process = [
+        (key, etag) for key, etag in files.items()
+        if not db_manager.is_file_processed(key, etag)
+    ]
 
     if not files_to_process:
         print("All files are up to date. No processing needed.")
+        db_manager.close()
         return
 
     # Use ProcessPoolExecutor to process files with tqdm progress bar
     with ProcessPoolExecutor() as executor:
-        futures = [executor.submit(process_file, bucket, key, etag) for key, etag in files_to_process]
+        futures = [
+            executor.submit(process_file, bucket, key, etag)
+            for key, etag in files_to_process
+        ]
         for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
             s3_path, key, etag, index_entries = future.result()
             if index_entries:
-                cursor.executemany("""
-                    INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
-                    VALUES (?, ?, ?, ?)
-                """, index_entries)
+                db_manager.add_index_entries(index_entries)
             # Update the processed_files table
-            cursor.execute("""
-                INSERT INTO processed_files (s3_path, etag)
-                VALUES (?, ?)
-                ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-            """, (key, etag))
-            conn.commit()
+            db_manager.update_processed_file(key, etag)
 
-    conn.close()
+    db_manager.close()
 
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
@@ -139,5 +163,3 @@ if __name__ == '__main__':
 
     # Step one, build an index of all the pages that were processed
     build_index(args.s3_path)
-
-
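
For context, a minimal usage sketch of the DatabaseManager introduced by this commit. It is not part of the commit itself: the local db path, S3 keys, and ETag strings below are made-up placeholders, and the class is assumed to be in scope.

    # Illustrative sketch only; assumes DatabaseManager from the diff above is in scope.
    db = DatabaseManager("/tmp/index.db")  # opens sqlite and creates both tables if missing

    # Record a processed file, then check the ETag comparison: a matching ETag
    # means is_file_processed() returns True and the file is skipped on the next run.
    db.update_processed_file("s3://bucket/run1.jsonl", "etag-abc")
    assert db.is_file_processed("s3://bucket/run1.jsonl", "etag-abc")
    assert not db.is_file_processed("s3://bucket/run1.jsonl", "etag-xyz")

    # Batch-insert index rows as (custom_id, s3_path, start_index, end_index) tuples.
    db.add_index_entries([
        ("doc-1", "s3://bucket/run1.jsonl", 0, 1024),
        ("doc-2", "s3://bucket/run1.jsonl", 1024, 2048),
    ])
    db.close()

Note that in build_index all sqlite access stays in the parent process: the ProcessPoolExecutor workers only return index entries from future.result(), so the connection held by DatabaseManager is never shared across processes.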