mirror of https://github.com/allenai/olmocr.git
synced 2025-09-26 17:04:02 +00:00

More cleanup

This commit is contained in:
parent 53fdb6108c
commit a45f86e4a4
@@ -8,6 +8,7 @@ import glob
 import tempfile
 import posixpath
 
+from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
 from typing import Optional
@@ -30,15 +31,17 @@ class DatabaseManager:
 
     def _initialize_tables(self):
         self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS index_table (
-                custom_id TEXT,
+            CREATE TABLE IF NOT EXISTS page_results (
                 s3_path TEXT,
+                page_num INTEGER,
                 start_index BIGINT,
-                end_index BIGINT
+                length BIGINT,
+                finish_reason STRING
+                error STRING
             )
         """)
         self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_custom_id ON index_table(custom_id)
+            CREATE INDEX IF NOT EXISTS idx_path ON index_table(s3_path)
         """)
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS pdfs (
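Note: the new page_results table keys each batch-inference result by (s3_path, page_num) and stores the byte offset and length of the corresponding line in the output .jsonl, so a single page's raw result can later be re-read without parsing the whole file. A minimal sketch of that idea with an in-memory SQLite database (commas are added between the last columns so the DDL executes; the sample row and byte string are invented for the illustration):

# Minimal sketch of the byte-offset index behind page_results, using in-memory SQLite.
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("""
    CREATE TABLE IF NOT EXISTS page_results (
        s3_path TEXT,
        page_num INTEGER,
        start_index BIGINT,
        length BIGINT,
        finish_reason STRING,
        error STRING
    )
""")

# One row per JSONL output line: which page it belongs to and where that line
# sits inside the output file.
cur.execute("INSERT INTO page_results VALUES (?, ?, ?, ?, ?, ?)",
            ("s3://bucket/doc.pdf", 1, 0, 12, "stop", None))

# Given the raw bytes of an output file, (start_index, length) slices out
# exactly one line without touching the rest.
raw = b'{"page": 1}\n{"page": 2}\n'
assert raw[12:12 + 12] == b'{"page": 2}\n'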
@@ -60,15 +63,16 @@ class DatabaseManager:
                 value TEXT
             )
         """)
-        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
-        if self.cursor.fetchone() is None:
-            self.cursor.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")
         self.conn.commit()
 
-    def get_current_round(self):
-        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
+    def get_metadata(self, key: str) -> str:
+        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
         result = self.cursor.fetchone()
-        return int(result[0])
+        return result[0]
 
+    def get_current_round(self):
+        return int(self.get_metadata("round"))
+
     def is_file_processed(self, s3_path, etag):
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
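Note: get_current_round is now a thin wrapper over the generic get_metadata accessor. A standalone sketch of the pattern, assuming the metadata table is keyed by a TEXT key column and seeded with a 'round' row (only value TEXT is visible in this hunk):

# Standalone sketch of the metadata accessor pattern; the key column and the
# seeded 'round' row are assumptions for the example.
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)")
cur.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")

def get_metadata(key: str) -> str:
    cur.execute("SELECT value FROM metadata WHERE key=?", (key,))
    return cur.fetchone()[0]

def get_current_round() -> int:
    # Typed convenience accessors stay thin wrappers over the generic lookup.
    return int(get_metadata("round"))

assert get_current_round() == 0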
@@ -76,6 +80,7 @@ class DatabaseManager:
         return result is not None and result[0] == etag
 
     def add_index_entries(self, index_entries):
+        # TODO MAke it take batchInferenceLines
         if index_entries:
             self.cursor.executemany("""
                 INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
@@ -113,45 +118,6 @@ class DatabaseManager:
     def close(self):
         self.conn.close()
 
-def build_index(s3_path):
-    db_manager = DatabaseManager(s3_path)
-
-    bucket, prefix = parse_s3_path(s3_path)
-
-    # List all .json and .jsonl files under s3_path with their ETags
-    files = expand_s3_glob(s3_path)
-
-    if not files:
-        print("No .json or .jsonl files found in the specified S3 path.")
-        db_manager.close()
-        return
-
-    # Prepare a list of files that need processing
-    files_to_process = [
-        (key, etag) for key, etag in files.items()
-        if not db_manager.is_file_processed(key, etag)
-    ]
-
-    if not files_to_process:
-        print("All files are up to date. No processing needed.")
-        db_manager.close()
-        return
-
-    # Use ProcessPoolExecutor to process files with tqdm progress bar
-    with ProcessPoolExecutor() as executor:
-        futures = [
-            executor.submit(process_file, bucket, key, etag)
-            for key, etag in files_to_process
-        ]
-        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
-            s3_path, key, etag, index_entries = future.result()
-            if index_entries:
-                db_manager.add_index_entries(index_entries)
-            # Update the processed_files table
-            db_manager.update_processed_file(key, etag)
-
-    db_manager.close()
-
 def parse_s3_path(s3_path):
     if not s3_path.startswith('s3://'):
         raise ValueError('s3_path must start with s3://')
@@ -159,13 +125,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
 
 def expand_s3_glob(s3_glob: str) -> dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
 
     paginator = s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
@@ -178,37 +144,43 @@ def expand_s3_glob(s3_glob: str) -> dict[str, str]:
 
     return matched_files
 
-def process_file(bucket, key, etag):
-    s3 = boto3.client('s3')  # Initialize s3 client in the worker process
-    s3_path = f's3://{bucket}/{key}'
-    try:
-        # Get the object
-        obj = s3.get_object(Bucket=bucket, Key=key)
-        # Read the content as bytes
-        content = obj['Body'].read()
-        # Process the file as JSONL
-        index_entries = process_jsonl_content(content, s3_path)
-        # Return the necessary data to the main process
-        return s3_path, key, etag, index_entries
-    except Exception as e:
-        print(f"Error processing file {s3_path}: {e}")
-        return s3_path, key, etag, []
+@dataclass(frozen=True)
+class BatchInferenceLine:
+    s3_path: str
+    page_num: int  # 1 indexed!
+    start_index: int
+    length: int
+    finish_reason: str
+    error: Optional[str]
 
-def process_jsonl_content(content, s3_path):
+def parse_custom_id(custom_id: str) -> tuple[str, int]:
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    return s3_path, page_num
+
+def process_jsonl_content(s3_path) -> list[BatchInferenceLine]:
+    content = get_s3_bytes(s3_path).decode("utf-8")
+
     start_index = 0
     index_entries = []
     lines = content.splitlines(keepends=True)
     for line in lines:
         line_length = len(line)
-        end_index = start_index + line_length
         try:
             data = json.loads(line)
-            custom_id = data.get('custom_id')
-            if custom_id:
-                index_entries.append((custom_id, s3_path, start_index, end_index))
+            s3_path, page_num = parse_custom_id(data["custom_id"])
+
+            assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
+
+            index_entries.append(BatchInferenceLine(s3_path, page_num, start_index, line_length,
+                                                    finish_reason=data["outputs"][0]["finish_reason"], error=data.get("completion_error", None)))
         except json.JSONDecodeError:
             pass  # Handle JSON decode errors if necessary
-        start_index = end_index
+
+        start_index = start_index + line_length
 
     return index_entries
 
 def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
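Note: process_jsonl_content now records, for every line it can parse, the byte offset and length of that line plus the page it came from, recovered from custom_id, which is expected to end in "-<page_num>". A small round-trip illustration; the sample custom_id and JSONL record below are invented for the example and are not part of the commit:

# Illustration only: the sample custom_id and record layout are assumptions
# inferred from this diff, not guaranteed by the repository.
from dataclasses import dataclass
from typing import Optional
import json

@dataclass(frozen=True)
class BatchInferenceLine:
    s3_path: str
    page_num: int  # 1 indexed!
    start_index: int
    length: int
    finish_reason: str
    error: Optional[str]

def parse_custom_id(custom_id: str) -> tuple[str, int]:
    # Everything before the last "-" is the PDF path, everything after is the page.
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num

assert parse_custom_id("s3://bucket/doc.pdf-12") == ("s3://bucket/doc.pdf", 12)

# Walking a JSONL buffer the same way process_jsonl_content does: keepends=True
# keeps the trailing newline in line_length, so start_index advances by whole lines.
content = '{"custom_id": "s3://bucket/doc.pdf-1", "outputs": [{"finish_reason": "stop"}]}\n'
start_index = 0
entries = []
for line in content.splitlines(keepends=True):
    data = json.loads(line)
    s3_path, page_num = parse_custom_id(data["custom_id"])
    entries.append(BatchInferenceLine(s3_path, page_num, start_index, len(line),
                                      finish_reason=data["outputs"][0]["finish_reason"],
                                      error=data.get("completion_error", None)))
    start_index += len(line)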
@@ -246,7 +218,7 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--pdfs', help='Glob path to PDFs (local or s3)', default=None)
+    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
     parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
@@ -258,12 +230,12 @@ if __name__ == '__main__':
     executor = ProcessPoolExecutor()
 
     # If you have new PDFs, add them to the list
-    if args.pdfs:
-        assert args.pdfs.startswith("s3://"), "PDFs must live on s3"
+    if args.add_pdfs:
+        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
 
-        print(f"Querying all PDFs at {args.pdfs}")
+        print(f"Querying all PDFs at {args.add_pdfs}")
 
-        all_pdfs = expand_s3_glob(args.pdfs)
+        all_pdfs = expand_s3_glob(args.add_pdfs)
         print(f"Found {len(all_pdfs)} total pdf paths")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
@@ -279,8 +251,21 @@ if __name__ == '__main__':
 
 
     # Now build an index of all the pages that were processed within the workspace so far
-    build_index(f"{args.workspace}/*.jsonl")
+    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
 
-    # Now, for each pending book, find all pages which still need to be processed
-    # and add them to the next round's batch inference jobs
+    inference_output_paths = [
+        (key, etag) for key, etag in inference_output_paths.items()
+        if not db.is_file_processed(key, etag)
+    ]
+
+    future_to_path = {executor.submit(process_jsonl_content, s3_path): s3_path for s3_path, etag in inference_output_paths}
+
+    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
+        s3_path = future_to_path[future]
+
+        inference_lines = future.result()
+
+        db.add_index_entries(inference_lines)
+
+        db.update_processed_file(s3_path, etag=TODO)
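Note: the driver now expands the inference_outputs glob, skips files whose etag is already recorded, and fans the rest out to process_jsonl_content via a ProcessPoolExecutor, indexing results as futures complete. A minimal sketch of that submit/as_completed pattern, with a hypothetical stand-in function and invented paths in place of the real S3-backed work:

# Generic sketch of the submit/as_completed pattern; index_one_file is a stand-in
# for process_jsonl_content and does no real S3 or database work.
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

def index_one_file(path: str) -> list[str]:
    return [f"entry from {path}"]

if __name__ == "__main__":
    paths = ["s3://bucket/inference_outputs/a.jsonl", "s3://bucket/inference_outputs/b.jsonl"]
    with ProcessPoolExecutor() as executor:
        # Map futures back to their input path so results can be attributed on completion.
        future_to_path = {executor.submit(index_one_file, p): p for p in paths}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            path = future_to_path[future]
            entries = future.result()
            # ...add entries to the index, then mark `path` as processed...
            print(path, entries)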