Refactored to have a more efficient batchwriter, and also not allow too many running futures

This commit is contained in:
Jake Poznanski 2024-10-23 16:28:46 +00:00
parent d99096e9a2
commit 38dc5a2a0f

View File

@ -4,7 +4,7 @@ import boto3
import sqlite3
import json
import argparse
import glob
import uuid
import tempfile
import datetime
import posixpath
@ -19,6 +19,7 @@ from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, as_completed
from pdelfin.data.renderpdf import render_pdf_to_base64png
@ -241,26 +242,104 @@ class DatabaseManager:
self.conn.close()
# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
def __init__(
self,
output_prefix: str,
max_size_mb: int = 250,
after_flush: Optional[Callable[[List[Any]], Any]] = None,
):
self.output_prefix = output_prefix
self.max_size = max_size_mb * 1024 * 1024 # Convert MB to bytes
self.batch = []
self.batch_objects = []
self.batch_size = 0
self.after_flush = after_flush
self.threads = []
self.temp_file = None # The temporary file object
self.temp_file_path = None # Path to the temporary file
parsed = urlparse(output_prefix)
self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")
if not self.is_s3:
os.makedirs(output_prefix, exist_ok=True)
def _compute_hash(self, content: str) -> str:
"""Compute a 20-character SHA1 hash of the given content."""
def write_line(self, obj: Optional[Any]):
if obj is None:
return
line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
line_size = len(line_bytes) + 1 # +1 for newline
if self.batch_size + line_size > self.max_size:
self._write_batch()
if self.batch_size == 0:
# Open a new temporary file
self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
self.temp_file_path = self.temp_file.name
self.temp_file.write(line_bytes + b"\n")
self.batch_objects.append(obj)
self.batch_size += line_size
def _write_batch(self):
if self.batch_size == 0:
return
# Close the temp file
self.temp_file.flush()
self.temp_file.close()
# Start a new thread to upload the temp file
thread = threading.Thread(
target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
)
thread.start()
self.threads.append(thread)
# Reset batch_objects and batch_size
self.batch_objects = []
self.batch_size = 0
self.temp_file = None
self.temp_file_path = None
def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
# Compute hash based on file content
hash_str = self._compute_hash(temp_file_path)
output_path = self._get_output_path(hash_str)
if self.is_s3:
# Use s3 upload_file
parsed = urlparse(output_path)
bucket = parsed.netloc
key = parsed.path.lstrip("/")
# Use the s3 client directly
try:
workspace_s3.upload_file(temp_file_path, bucket, key)
except Exception as e:
print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
else:
# Move the temp file to the output path
os.rename(temp_file_path, output_path)
# After writing, call the after_flush callback if it is set
if self.after_flush:
self.after_flush(batch_objects)
# Delete the temporary file
os.remove(temp_file_path)
def _compute_hash(self, temp_file_path: str) -> str:
"""Compute a 20-character SHA1 hash of the file content."""
sha1 = hashlib.sha1()
sha1.update(content.encode('utf-8'))
with open(temp_file_path, "rb") as f:
while True:
data = f.read(1024*1024)
if not data:
break
sha1.update(data)
return sha1.hexdigest()[:20]
def _get_output_path(self, hash_str: str) -> str:
@ -268,59 +347,15 @@ class BatchWriter:
parsed = urlparse(self.output_prefix)
if self.is_s3:
bucket = parsed.netloc
key = parsed.path.lstrip('/')
if key and not key.endswith('/'):
key += '/'
key = parsed.path.lstrip("/")
if key and not key.endswith("/"):
key += "/"
full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
return f"s3://{bucket}/{full_key}"
else:
filename = f"output_{hash_str}.jsonl"
return os.path.join(self.output_prefix, filename)
def write_line(self, line: Optional[str]):
if line is None or not line.strip():
return
line_size = len(line.encode('utf-8')) + 1 # +1 for newline
if self.batch_size + line_size > self.max_size:
self._write_batch()
self.batch.append(line)
self.batch_size += line_size
def _write_batch(self):
if not self.batch:
return
batch_lines = self.batch.copy()
batch_content = "\n".join(batch_lines) + "\n"
hash_str = self._compute_hash(batch_content)
output_path = self._get_output_path(hash_str)
# Start a new thread to write the batch
thread = threading.Thread(
target=self._write_batch_to_file,
args=(batch_content, output_path, batch_lines)
)
thread.start()
self.threads.append(thread)
# Clear the batch and batch_size
self.batch = []
self.batch_size = 0
def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
if self.is_s3:
put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
else:
with open(output_path, 'w', encoding='utf-8') as f_out:
f_out.write(batch_content)
# After writing, call the after_flush callback if it is set
if self.after_flush:
self.after_flush(batch_lines)
def close(self):
self._write_batch()
# Wait for all threads to finish
@ -520,11 +555,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
return dolma_doc
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
db = DatabaseManager(s3_workspace)
for line in dolma_doc_lines:
db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
for doc in dolma_docs:
db.update_pdf_status(doc["metadata"]["Source-File"], "completed")
def get_current_round(s3_workspace: str) -> int:
path = s3_workspace[5:]
@ -610,13 +645,13 @@ if __name__ == '__main__':
print("Indexing all batch inference sent to this workspace")
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
inference_output_paths = [
(s3_path, etag) for s3_path, etag in inference_output_paths.items()
inference_output_paths = {
s3_path: etag for s3_path, etag in inference_output_paths.items()
if not db.is_file_processed(s3_path, etag)
]
}
print(f"Found {len(inference_output_paths):,} new batch inference results to index")
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path, etag = future_to_path[future]
@ -638,29 +673,72 @@ if __name__ == '__main__':
potentially_done_pdfs = db.get_pdfs_by_status("pending")
else:
print(f"\nCreating batch inference files for new PDFs")
future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
pdf_list = list(db.get_pdfs_by_status("pending"))
pdf_iter = iter(pdf_list)
pending_futures = {}
potentially_done_pdfs = []
lines_written = 0
new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
total_pdfs = len(pdf_list)
max_pending = 5000
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
pdf = future_to_path[future]
inference_lines = future.result()
with tqdm(total=total_pdfs) as pbar:
# Submit initial batch of futures
for _ in range(min(max_pending, total_pdfs)):
pdf = next(pdf_iter)
future = executor.submit(
build_pdf_queries,
args.workspace,
pdf,
current_round,
args.target_longest_image_dim,
args.target_anchor_text_len,
)
pending_futures[future] = pdf
if len(inference_lines) == 0:
potentially_done_pdfs.append(pdf)
while pending_futures:
# Wait for the next future to complete
done, _ = concurrent.futures.wait(
pending_futures.keys(),
return_when=concurrent.futures.FIRST_COMPLETED,
)
for line in inference_lines:
lines_written += 1
for future in done:
pdf = pending_futures.pop(future)
inference_lines = future.result()
if line is not None:
new_inference_writer.write_line(json.dumps(line))
if len(inference_lines) == 0:
potentially_done_pdfs.append(pdf)
for line in inference_lines:
lines_written += 1
if line is not None:
new_inference_writer.write_line(line)
pbar.update(1)
# Submit a new future if there are more PDFs
try:
pdf = next(pdf_iter)
new_future = executor.submit(
build_pdf_queries,
args.workspace,
pdf,
current_round,
args.target_longest_image_dim,
args.target_anchor_text_len,
)
pending_futures[new_future] = pdf
except StopIteration:
pass # No more PDFs to process
new_inference_writer.close()
if lines_written > 0:
print(f"Added {lines_written:,} new batch inference requests")
# Now, finally, assemble any potentially done docs into dolma documents
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
@ -671,7 +749,7 @@ if __name__ == '__main__':
dolma_doc = future.result()
if dolma_doc is not None:
new_output_writer.write_line(json.dumps(dolma_doc))
new_output_writer.write_line(dolma_doc)
new_output_writer.close()