mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-01 10:33:57 +00:00
Refactored to have a more efficient batchwriter, and also not allow too many running futures
This commit is contained in:
parent
d99096e9a2
commit
38dc5a2a0f
@ -4,7 +4,7 @@ import boto3
|
||||
import sqlite3
|
||||
import json
|
||||
import argparse
|
||||
import glob
|
||||
import uuid
|
||||
import tempfile
|
||||
import datetime
|
||||
import posixpath
|
||||
@ -19,6 +19,7 @@ from tqdm import tqdm
|
||||
from functools import partial
|
||||
from typing import Optional, List, Tuple, Dict, Callable, Any
|
||||
from urllib.parse import urlparse
|
||||
import concurrent.futures
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||
@ -241,26 +242,104 @@ class DatabaseManager:
|
||||
self.conn.close()
|
||||
|
||||
|
||||
# Writes batches of lines out to a set of files, keeping each file below some maximum size
|
||||
class BatchWriter:
|
||||
def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
output_prefix: str,
|
||||
max_size_mb: int = 250,
|
||||
after_flush: Optional[Callable[[List[Any]], Any]] = None,
|
||||
):
|
||||
self.output_prefix = output_prefix
|
||||
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
|
||||
self.batch = []
|
||||
self.batch_objects = []
|
||||
self.batch_size = 0
|
||||
self.after_flush = after_flush
|
||||
self.threads = []
|
||||
self.temp_file = None # The temporary file object
|
||||
self.temp_file_path = None # Path to the temporary file
|
||||
|
||||
parsed = urlparse(output_prefix)
|
||||
self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
|
||||
self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")
|
||||
|
||||
if not self.is_s3:
|
||||
os.makedirs(output_prefix, exist_ok=True)
|
||||
|
||||
def _compute_hash(self, content: str) -> str:
|
||||
"""Compute a 20-character SHA1 hash of the given content."""
|
||||
def write_line(self, obj: Optional[Any]):
|
||||
if obj is None:
|
||||
return
|
||||
|
||||
line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
|
||||
line_size = len(line_bytes) + 1 # +1 for newline
|
||||
|
||||
if self.batch_size + line_size > self.max_size:
|
||||
self._write_batch()
|
||||
|
||||
if self.batch_size == 0:
|
||||
# Open a new temporary file
|
||||
self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
|
||||
self.temp_file_path = self.temp_file.name
|
||||
|
||||
self.temp_file.write(line_bytes + b"\n")
|
||||
self.batch_objects.append(obj)
|
||||
self.batch_size += line_size
|
||||
|
||||
def _write_batch(self):
|
||||
if self.batch_size == 0:
|
||||
return
|
||||
|
||||
# Close the temp file
|
||||
self.temp_file.flush()
|
||||
self.temp_file.close()
|
||||
|
||||
# Start a new thread to upload the temp file
|
||||
thread = threading.Thread(
|
||||
target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
|
||||
)
|
||||
thread.start()
|
||||
self.threads.append(thread)
|
||||
|
||||
# Reset batch_objects and batch_size
|
||||
self.batch_objects = []
|
||||
self.batch_size = 0
|
||||
self.temp_file = None
|
||||
self.temp_file_path = None
|
||||
|
||||
def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
|
||||
# Compute hash based on file content
|
||||
hash_str = self._compute_hash(temp_file_path)
|
||||
output_path = self._get_output_path(hash_str)
|
||||
|
||||
if self.is_s3:
|
||||
# Use s3 upload_file
|
||||
parsed = urlparse(output_path)
|
||||
bucket = parsed.netloc
|
||||
key = parsed.path.lstrip("/")
|
||||
|
||||
# Use the s3 client directly
|
||||
try:
|
||||
workspace_s3.upload_file(temp_file_path, bucket, key)
|
||||
except Exception as e:
|
||||
print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
|
||||
else:
|
||||
# Move the temp file to the output path
|
||||
os.rename(temp_file_path, output_path)
|
||||
|
||||
# After writing, call the after_flush callback if it is set
|
||||
if self.after_flush:
|
||||
self.after_flush(batch_objects)
|
||||
|
||||
# Delete the temporary file
|
||||
os.remove(temp_file_path)
|
||||
|
||||
def _compute_hash(self, temp_file_path: str) -> str:
|
||||
"""Compute a 20-character SHA1 hash of the file content."""
|
||||
sha1 = hashlib.sha1()
|
||||
sha1.update(content.encode('utf-8'))
|
||||
with open(temp_file_path, "rb") as f:
|
||||
while True:
|
||||
data = f.read(1024*1024)
|
||||
if not data:
|
||||
break
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()[:20]
|
||||
|
||||
def _get_output_path(self, hash_str: str) -> str:
|
||||
@ -268,59 +347,15 @@ class BatchWriter:
|
||||
parsed = urlparse(self.output_prefix)
|
||||
if self.is_s3:
|
||||
bucket = parsed.netloc
|
||||
key = parsed.path.lstrip('/')
|
||||
if key and not key.endswith('/'):
|
||||
key += '/'
|
||||
key = parsed.path.lstrip("/")
|
||||
if key and not key.endswith("/"):
|
||||
key += "/"
|
||||
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
|
||||
return f"s3://{bucket}/{full_key}"
|
||||
else:
|
||||
filename = f"output_{hash_str}.jsonl"
|
||||
return os.path.join(self.output_prefix, filename)
|
||||
|
||||
def write_line(self, line: Optional[str]):
|
||||
if line is None or not line.strip():
|
||||
return
|
||||
|
||||
line_size = len(line.encode('utf-8')) + 1 # +1 for newline
|
||||
|
||||
if self.batch_size + line_size > self.max_size:
|
||||
self._write_batch()
|
||||
|
||||
self.batch.append(line)
|
||||
self.batch_size += line_size
|
||||
|
||||
def _write_batch(self):
|
||||
if not self.batch:
|
||||
return
|
||||
|
||||
batch_lines = self.batch.copy()
|
||||
batch_content = "\n".join(batch_lines) + "\n"
|
||||
hash_str = self._compute_hash(batch_content)
|
||||
output_path = self._get_output_path(hash_str)
|
||||
|
||||
# Start a new thread to write the batch
|
||||
thread = threading.Thread(
|
||||
target=self._write_batch_to_file,
|
||||
args=(batch_content, output_path, batch_lines)
|
||||
)
|
||||
thread.start()
|
||||
self.threads.append(thread)
|
||||
|
||||
# Clear the batch and batch_size
|
||||
self.batch = []
|
||||
self.batch_size = 0
|
||||
|
||||
def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
|
||||
if self.is_s3:
|
||||
put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
|
||||
else:
|
||||
with open(output_path, 'w', encoding='utf-8') as f_out:
|
||||
f_out.write(batch_content)
|
||||
|
||||
# After writing, call the after_flush callback if it is set
|
||||
if self.after_flush:
|
||||
self.after_flush(batch_lines)
|
||||
|
||||
def close(self):
|
||||
self._write_batch()
|
||||
# Wait for all threads to finish
|
||||
@ -520,11 +555,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
|
||||
|
||||
return dolma_doc
|
||||
|
||||
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
|
||||
def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
|
||||
db = DatabaseManager(s3_workspace)
|
||||
|
||||
for line in dolma_doc_lines:
|
||||
db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
|
||||
for doc in dolma_docs:
|
||||
db.update_pdf_status(doc["metadata"]["Source-File"], "completed")
|
||||
|
||||
def get_current_round(s3_workspace: str) -> int:
|
||||
path = s3_workspace[5:]
|
||||
@ -610,13 +645,13 @@ if __name__ == '__main__':
|
||||
print("Indexing all batch inference sent to this workspace")
|
||||
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
|
||||
|
||||
inference_output_paths = [
|
||||
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
|
||||
inference_output_paths = {
|
||||
s3_path: etag for s3_path, etag in inference_output_paths.items()
|
||||
if not db.is_file_processed(s3_path, etag)
|
||||
]
|
||||
}
|
||||
|
||||
print(f"Found {len(inference_output_paths):,} new batch inference results to index")
|
||||
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
|
||||
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}
|
||||
|
||||
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
|
||||
s3_path, etag = future_to_path[future]
|
||||
@ -638,29 +673,72 @@ if __name__ == '__main__':
|
||||
potentially_done_pdfs = db.get_pdfs_by_status("pending")
|
||||
else:
|
||||
print(f"\nCreating batch inference files for new PDFs")
|
||||
future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
|
||||
pdf_list = list(db.get_pdfs_by_status("pending"))
|
||||
pdf_iter = iter(pdf_list)
|
||||
pending_futures = {}
|
||||
potentially_done_pdfs = []
|
||||
lines_written = 0
|
||||
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
|
||||
total_pdfs = len(pdf_list)
|
||||
max_pending = 5000
|
||||
|
||||
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
|
||||
pdf = future_to_path[future]
|
||||
inference_lines = future.result()
|
||||
with tqdm(total=total_pdfs) as pbar:
|
||||
# Submit initial batch of futures
|
||||
for _ in range(min(max_pending, total_pdfs)):
|
||||
pdf = next(pdf_iter)
|
||||
future = executor.submit(
|
||||
build_pdf_queries,
|
||||
args.workspace,
|
||||
pdf,
|
||||
current_round,
|
||||
args.target_longest_image_dim,
|
||||
args.target_anchor_text_len,
|
||||
)
|
||||
pending_futures[future] = pdf
|
||||
|
||||
if len(inference_lines) == 0:
|
||||
potentially_done_pdfs.append(pdf)
|
||||
while pending_futures:
|
||||
# Wait for the next future to complete
|
||||
done, _ = concurrent.futures.wait(
|
||||
pending_futures.keys(),
|
||||
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
for line in inference_lines:
|
||||
lines_written += 1
|
||||
for future in done:
|
||||
pdf = pending_futures.pop(future)
|
||||
inference_lines = future.result()
|
||||
|
||||
if line is not None:
|
||||
new_inference_writer.write_line(json.dumps(line))
|
||||
if len(inference_lines) == 0:
|
||||
potentially_done_pdfs.append(pdf)
|
||||
|
||||
for line in inference_lines:
|
||||
lines_written += 1
|
||||
|
||||
if line is not None:
|
||||
new_inference_writer.write_line(line)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Submit a new future if there are more PDFs
|
||||
try:
|
||||
pdf = next(pdf_iter)
|
||||
new_future = executor.submit(
|
||||
build_pdf_queries,
|
||||
args.workspace,
|
||||
pdf,
|
||||
current_round,
|
||||
args.target_longest_image_dim,
|
||||
args.target_anchor_text_len,
|
||||
)
|
||||
pending_futures[new_future] = pdf
|
||||
except StopIteration:
|
||||
pass # No more PDFs to process
|
||||
|
||||
new_inference_writer.close()
|
||||
|
||||
if lines_written > 0:
|
||||
print(f"Added {lines_written:,} new batch inference requests")
|
||||
|
||||
|
||||
# Now, finally, assemble any potentially done docs into dolma documents
|
||||
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
|
||||
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
|
||||
@ -671,7 +749,7 @@ if __name__ == '__main__':
|
||||
dolma_doc = future.result()
|
||||
|
||||
if dolma_doc is not None:
|
||||
new_output_writer.write_line(json.dumps(dolma_doc))
|
||||
new_output_writer.write_line(dolma_doc)
|
||||
|
||||
new_output_writer.close()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user