mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-29 08:41:00 +00:00
Refactored to have a more efficient batchwriter, and also not allow too many running futures
This commit is contained in:
parent
d99096e9a2
commit
38dc5a2a0f
@ -4,7 +4,7 @@ import boto3
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import uuid
|
||||||
import tempfile
|
import tempfile
|
||||||
import datetime
|
import datetime
|
||||||
import posixpath
|
import posixpath
|
||||||
@ -19,6 +19,7 @@ from tqdm import tqdm
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, List, Tuple, Dict, Callable, Any
|
from typing import Optional, List, Tuple, Dict, Callable, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
import concurrent.futures
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
|
||||||
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
from pdelfin.data.renderpdf import render_pdf_to_base64png
|
||||||
@ -241,26 +242,104 @@ class DatabaseManager:
|
|||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
|
|
||||||
# Writes batches of lines out to a set of files, keeping each file below some maximum size
|
|
||||||
class BatchWriter:
|
class BatchWriter:
|
||||||
def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
output_prefix: str,
|
||||||
|
max_size_mb: int = 250,
|
||||||
|
after_flush: Optional[Callable[[List[Any]], Any]] = None,
|
||||||
|
):
|
||||||
self.output_prefix = output_prefix
|
self.output_prefix = output_prefix
|
||||||
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
|
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
|
||||||
self.batch = []
|
self.batch_objects = []
|
||||||
self.batch_size = 0
|
self.batch_size = 0
|
||||||
self.after_flush = after_flush
|
self.after_flush = after_flush
|
||||||
self.threads = []
|
self.threads = []
|
||||||
|
self.temp_file = None # The temporary file object
|
||||||
|
self.temp_file_path = None # Path to the temporary file
|
||||||
|
|
||||||
parsed = urlparse(output_prefix)
|
parsed = urlparse(output_prefix)
|
||||||
self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
|
self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")
|
||||||
|
|
||||||
if not self.is_s3:
|
if not self.is_s3:
|
||||||
os.makedirs(output_prefix, exist_ok=True)
|
os.makedirs(output_prefix, exist_ok=True)
|
||||||
|
|
||||||
def _compute_hash(self, content: str) -> str:
|
def write_line(self, obj: Optional[Any]):
|
||||||
"""Compute a 20-character SHA1 hash of the given content."""
|
if obj is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
|
||||||
|
line_size = len(line_bytes) + 1 # +1 for newline
|
||||||
|
|
||||||
|
if self.batch_size + line_size > self.max_size:
|
||||||
|
self._write_batch()
|
||||||
|
|
||||||
|
if self.batch_size == 0:
|
||||||
|
# Open a new temporary file
|
||||||
|
self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
|
||||||
|
self.temp_file_path = self.temp_file.name
|
||||||
|
|
||||||
|
self.temp_file.write(line_bytes + b"\n")
|
||||||
|
self.batch_objects.append(obj)
|
||||||
|
self.batch_size += line_size
|
||||||
|
|
||||||
|
def _write_batch(self):
|
||||||
|
if self.batch_size == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Close the temp file
|
||||||
|
self.temp_file.flush()
|
||||||
|
self.temp_file.close()
|
||||||
|
|
||||||
|
# Start a new thread to upload the temp file
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
self.threads.append(thread)
|
||||||
|
|
||||||
|
# Reset batch_objects and batch_size
|
||||||
|
self.batch_objects = []
|
||||||
|
self.batch_size = 0
|
||||||
|
self.temp_file = None
|
||||||
|
self.temp_file_path = None
|
||||||
|
|
||||||
|
def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
|
||||||
|
# Compute hash based on file content
|
||||||
|
hash_str = self._compute_hash(temp_file_path)
|
||||||
|
output_path = self._get_output_path(hash_str)
|
||||||
|
|
||||||
|
if self.is_s3:
|
||||||
|
# Use s3 upload_file
|
||||||
|
parsed = urlparse(output_path)
|
||||||
|
bucket = parsed.netloc
|
||||||
|
key = parsed.path.lstrip("/")
|
||||||
|
|
||||||
|
# Use the s3 client directly
|
||||||
|
try:
|
||||||
|
workspace_s3.upload_file(temp_file_path, bucket, key)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
|
||||||
|
else:
|
||||||
|
# Move the temp file to the output path
|
||||||
|
os.rename(temp_file_path, output_path)
|
||||||
|
|
||||||
|
# After writing, call the after_flush callback if it is set
|
||||||
|
if self.after_flush:
|
||||||
|
self.after_flush(batch_objects)
|
||||||
|
|
||||||
|
# Delete the temporary file
|
||||||
|
os.remove(temp_file_path)
|
||||||
|
|
||||||
|
def _compute_hash(self, temp_file_path: str) -> str:
|
||||||
|
"""Compute a 20-character SHA1 hash of the file content."""
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
sha1.update(content.encode('utf-8'))
|
with open(temp_file_path, "rb") as f:
|
||||||
|
while True:
|
||||||
|
data = f.read(1024*1024)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
sha1.update(data)
|
||||||
return sha1.hexdigest()[:20]
|
return sha1.hexdigest()[:20]
|
||||||
|
|
||||||
def _get_output_path(self, hash_str: str) -> str:
|
def _get_output_path(self, hash_str: str) -> str:
|
||||||
@ -268,59 +347,15 @@ class BatchWriter:
|
|||||||
parsed = urlparse(self.output_prefix)
|
parsed = urlparse(self.output_prefix)
|
||||||
if self.is_s3:
|
if self.is_s3:
|
||||||
bucket = parsed.netloc
|
bucket = parsed.netloc
|
||||||
key = parsed.path.lstrip('/')
|
key = parsed.path.lstrip("/")
|
||||||
if key and not key.endswith('/'):
|
if key and not key.endswith("/"):
|
||||||
key += '/'
|
key += "/"
|
||||||
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
|
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
|
||||||
return f"s3://{bucket}/{full_key}"
|
return f"s3://{bucket}/{full_key}"
|
||||||
else:
|
else:
|
||||||
filename = f"output_{hash_str}.jsonl"
|
filename = f"output_{hash_str}.jsonl"
|
||||||
return os.path.join(self.output_prefix, filename)
|
return os.path.join(self.output_prefix, filename)
|
||||||
|
|
||||||
def write_line(self, line: Optional[str]):
|
|
||||||
if line is None or not line.strip():
|
|
||||||
return
|
|
||||||
|
|
||||||
line_size = len(line.encode('utf-8')) + 1 # +1 for newline
|
|
||||||
|
|
||||||
if self.batch_size + line_size > self.max_size:
|
|
||||||
self._write_batch()
|
|
||||||
|
|
||||||
self.batch.append(line)
|
|
||||||
self.batch_size += line_size
|
|
||||||
|
|
||||||
def _write_batch(self):
|
|
||||||
if not self.batch:
|
|
||||||
return
|
|
||||||
|
|
||||||
batch_lines = self.batch.copy()
|
|
||||||
batch_content = "\n".join(batch_lines) + "\n"
|
|
||||||
hash_str = self._compute_hash(batch_content)
|
|
||||||
output_path = self._get_output_path(hash_str)
|
|
||||||
|
|
||||||
# Start a new thread to write the batch
|
|
||||||
thread = threading.Thread(
|
|
||||||
target=self._write_batch_to_file,
|
|
||||||
args=(batch_content, output_path, batch_lines)
|
|
||||||
)
|
|
||||||
thread.start()
|
|
||||||
self.threads.append(thread)
|
|
||||||
|
|
||||||
# Clear the batch and batch_size
|
|
||||||
self.batch = []
|
|
||||||
self.batch_size = 0
|
|
||||||
|
|
||||||
def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
|
|
||||||
if self.is_s3:
|
|
||||||
put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
|
|
||||||
else:
|
|
||||||
with open(output_path, 'w', encoding='utf-8') as f_out:
|
|
||||||
f_out.write(batch_content)
|
|
||||||
|
|
||||||
# After writing, call the after_flush callback if it is set
|
|
||||||
if self.after_flush:
|
|
||||||
self.after_flush(batch_lines)
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._write_batch()
|
self._write_batch()
|
||||||
# Wait for all threads to finish
|
# Wait for all threads to finish
|
||||||
@ -520,11 +555,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
|
|||||||
|
|
||||||
return dolma_doc
|
return dolma_doc
|
||||||
|
|
||||||
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
|
def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
|
||||||
db = DatabaseManager(s3_workspace)
|
db = DatabaseManager(s3_workspace)
|
||||||
|
|
||||||
for line in dolma_doc_lines:
|
for doc in dolma_docs:
|
||||||
db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
|
db.update_pdf_status(doc["metadata"]["Source-File"], "completed")
|
||||||
|
|
||||||
def get_current_round(s3_workspace: str) -> int:
|
def get_current_round(s3_workspace: str) -> int:
|
||||||
path = s3_workspace[5:]
|
path = s3_workspace[5:]
|
||||||
@ -610,13 +645,13 @@ if __name__ == '__main__':
|
|||||||
print("Indexing all batch inference sent to this workspace")
|
print("Indexing all batch inference sent to this workspace")
|
||||||
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
|
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
|
||||||
|
|
||||||
inference_output_paths = [
|
inference_output_paths = {
|
||||||
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
|
s3_path: etag for s3_path, etag in inference_output_paths.items()
|
||||||
if not db.is_file_processed(s3_path, etag)
|
if not db.is_file_processed(s3_path, etag)
|
||||||
]
|
}
|
||||||
|
|
||||||
print(f"Found {len(inference_output_paths):,} new batch inference results to index")
|
print(f"Found {len(inference_output_paths):,} new batch inference results to index")
|
||||||
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
|
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}
|
||||||
|
|
||||||
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
|
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
|
||||||
s3_path, etag = future_to_path[future]
|
s3_path, etag = future_to_path[future]
|
||||||
@ -638,29 +673,72 @@ if __name__ == '__main__':
|
|||||||
potentially_done_pdfs = db.get_pdfs_by_status("pending")
|
potentially_done_pdfs = db.get_pdfs_by_status("pending")
|
||||||
else:
|
else:
|
||||||
print(f"\nCreating batch inference files for new PDFs")
|
print(f"\nCreating batch inference files for new PDFs")
|
||||||
future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
|
pdf_list = list(db.get_pdfs_by_status("pending"))
|
||||||
|
pdf_iter = iter(pdf_list)
|
||||||
|
pending_futures = {}
|
||||||
potentially_done_pdfs = []
|
potentially_done_pdfs = []
|
||||||
lines_written = 0
|
lines_written = 0
|
||||||
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
|
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
|
||||||
|
total_pdfs = len(pdf_list)
|
||||||
|
max_pending = 5000
|
||||||
|
|
||||||
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
|
with tqdm(total=total_pdfs) as pbar:
|
||||||
pdf = future_to_path[future]
|
# Submit initial batch of futures
|
||||||
inference_lines = future.result()
|
for _ in range(min(max_pending, total_pdfs)):
|
||||||
|
pdf = next(pdf_iter)
|
||||||
|
future = executor.submit(
|
||||||
|
build_pdf_queries,
|
||||||
|
args.workspace,
|
||||||
|
pdf,
|
||||||
|
current_round,
|
||||||
|
args.target_longest_image_dim,
|
||||||
|
args.target_anchor_text_len,
|
||||||
|
)
|
||||||
|
pending_futures[future] = pdf
|
||||||
|
|
||||||
if len(inference_lines) == 0:
|
while pending_futures:
|
||||||
potentially_done_pdfs.append(pdf)
|
# Wait for the next future to complete
|
||||||
|
done, _ = concurrent.futures.wait(
|
||||||
|
pending_futures.keys(),
|
||||||
|
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
|
||||||
for line in inference_lines:
|
for future in done:
|
||||||
lines_written += 1
|
pdf = pending_futures.pop(future)
|
||||||
|
inference_lines = future.result()
|
||||||
|
|
||||||
if line is not None:
|
if len(inference_lines) == 0:
|
||||||
new_inference_writer.write_line(json.dumps(line))
|
potentially_done_pdfs.append(pdf)
|
||||||
|
|
||||||
|
for line in inference_lines:
|
||||||
|
lines_written += 1
|
||||||
|
|
||||||
|
if line is not None:
|
||||||
|
new_inference_writer.write_line(line)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Submit a new future if there are more PDFs
|
||||||
|
try:
|
||||||
|
pdf = next(pdf_iter)
|
||||||
|
new_future = executor.submit(
|
||||||
|
build_pdf_queries,
|
||||||
|
args.workspace,
|
||||||
|
pdf,
|
||||||
|
current_round,
|
||||||
|
args.target_longest_image_dim,
|
||||||
|
args.target_anchor_text_len,
|
||||||
|
)
|
||||||
|
pending_futures[new_future] = pdf
|
||||||
|
except StopIteration:
|
||||||
|
pass # No more PDFs to process
|
||||||
|
|
||||||
new_inference_writer.close()
|
new_inference_writer.close()
|
||||||
|
|
||||||
if lines_written > 0:
|
if lines_written > 0:
|
||||||
print(f"Added {lines_written:,} new batch inference requests")
|
print(f"Added {lines_written:,} new batch inference requests")
|
||||||
|
|
||||||
|
|
||||||
# Now, finally, assemble any potentially done docs into dolma documents
|
# Now, finally, assemble any potentially done docs into dolma documents
|
||||||
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
|
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
|
||||||
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
|
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
|
||||||
@ -671,7 +749,7 @@ if __name__ == '__main__':
|
|||||||
dolma_doc = future.result()
|
dolma_doc = future.result()
|
||||||
|
|
||||||
if dolma_doc is not None:
|
if dolma_doc is not None:
|
||||||
new_output_writer.write_line(json.dumps(dolma_doc))
|
new_output_writer.write_line(dolma_doc)
|
||||||
|
|
||||||
new_output_writer.close()
|
new_output_writer.close()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user