Refactored to have a more efficient batchwriter, and also not allow too many running futures

This commit is contained in:
Jake Poznanski 2024-10-23 16:28:46 +00:00
parent d99096e9a2
commit 38dc5a2a0f
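
The two changes are (1) a BatchWriter that spools rows to a local temporary file and uploads each finished batch on a background thread, and (2) a bounded "sliding window" of at most 5,000 in-flight futures instead of submitting every PDF to the executor up front. As a rough illustration of the second pattern (this is not code from the commit; process_item, items, and run_bounded are hypothetical stand-ins for build_pdf_queries and the pending-PDF list):

# Minimal sketch of the bounded sliding-window submission pattern, under the
# assumptions stated above.
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor


def process_item(item: int) -> int:
    return item * item  # placeholder for an expensive per-PDF task


def run_bounded(items: list, max_pending: int = 5000) -> list:
    results = []
    with ProcessPoolExecutor() as executor:
        item_iter = iter(items)
        pending = {}

        # Prime the window with up to max_pending submissions.
        for _ in range(min(max_pending, len(items))):
            item = next(item_iter)
            pending[executor.submit(process_item, item)] = item

        while pending:
            # Block until at least one future finishes, then top the window back up.
            done, _ = concurrent.futures.wait(
                pending.keys(), return_when=concurrent.futures.FIRST_COMPLETED
            )
            for future in done:
                pending.pop(future)
                results.append(future.result())
                try:
                    nxt = next(item_iter)
                except StopIteration:
                    continue  # nothing left to enqueue
                pending[executor.submit(process_item, nxt)] = nxt
    return results


if __name__ == "__main__":
    print(sum(run_bounded(list(range(20)), max_pending=4)))

The diff below applies the same idea, reusing the surrounding executor and a tqdm progress bar.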

@@ -4,7 +4,7 @@ import boto3
 import sqlite3
 import json
 import argparse
-import glob
+import uuid
 import tempfile
 import datetime
 import posixpath
@@ -19,6 +19,7 @@ from tqdm import tqdm
 from functools import partial
 from typing import Optional, List, Tuple, Dict, Callable, Any
 from urllib.parse import urlparse
+import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from pdelfin.data.renderpdf import render_pdf_to_base64png
@@ -241,26 +242,104 @@ class DatabaseManager:
         self.conn.close()

+# Writes batches of lines out to a set of files, keeping each file below some maximum size
 class BatchWriter:
-    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
+    def __init__(
+        self,
+        output_prefix: str,
+        max_size_mb: int = 250,
+        after_flush: Optional[Callable[[List[Any]], Any]] = None,
+    ):
         self.output_prefix = output_prefix
         self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
-        self.batch = []
+        self.batch_objects = []
         self.batch_size = 0
         self.after_flush = after_flush
         self.threads = []
+        self.temp_file = None  # The temporary file object
+        self.temp_file_path = None  # Path to the temporary file

         parsed = urlparse(output_prefix)
-        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
+        self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")

         if not self.is_s3:
             os.makedirs(output_prefix, exist_ok=True)

-    def _compute_hash(self, content: str) -> str:
-        """Compute a 20-character SHA1 hash of the given content."""
+    def write_line(self, obj: Optional[Any]):
+        if obj is None:
+            return
+
+        line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
+        line_size = len(line_bytes) + 1  # +1 for newline
+
+        if self.batch_size + line_size > self.max_size:
+            self._write_batch()
+
+        if self.batch_size == 0:
+            # Open a new temporary file
+            self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
+            self.temp_file_path = self.temp_file.name
+
+        self.temp_file.write(line_bytes + b"\n")
+        self.batch_objects.append(obj)
+        self.batch_size += line_size
+
+    def _write_batch(self):
+        if self.batch_size == 0:
+            return
+
+        # Close the temp file
+        self.temp_file.flush()
+        self.temp_file.close()
+
+        # Start a new thread to upload the temp file
+        thread = threading.Thread(
+            target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
+        )
+        thread.start()
+        self.threads.append(thread)
+
+        # Reset batch_objects and batch_size
+        self.batch_objects = []
+        self.batch_size = 0
+        self.temp_file = None
+        self.temp_file_path = None
+
+    def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
+        # Compute hash based on file content
+        hash_str = self._compute_hash(temp_file_path)
+        output_path = self._get_output_path(hash_str)
+
+        if self.is_s3:
+            # Use s3 upload_file
+            parsed = urlparse(output_path)
+            bucket = parsed.netloc
+            key = parsed.path.lstrip("/")
+            # Use the s3 client directly
+            try:
+                workspace_s3.upload_file(temp_file_path, bucket, key)
+            except Exception as e:
+                print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
+            # Delete the temporary file after uploading
+            os.remove(temp_file_path)
+        else:
+            # Move the temp file to the output path
+            os.rename(temp_file_path, output_path)
+
+        # After writing, call the after_flush callback if it is set
+        if self.after_flush:
+            self.after_flush(batch_objects)
+
+    def _compute_hash(self, temp_file_path: str) -> str:
+        """Compute a 20-character SHA1 hash of the file content."""
         sha1 = hashlib.sha1()
-        sha1.update(content.encode('utf-8'))
+        with open(temp_file_path, "rb") as f:
+            while True:
+                data = f.read(1024 * 1024)
+                if not data:
+                    break
+                sha1.update(data)
         return sha1.hexdigest()[:20]

     def _get_output_path(self, hash_str: str) -> str:
@@ -268,59 +347,15 @@ class BatchWriter:
         parsed = urlparse(self.output_prefix)
         if self.is_s3:
             bucket = parsed.netloc
-            key = parsed.path.lstrip('/')
-            if key and not key.endswith('/'):
-                key += '/'
+            key = parsed.path.lstrip("/")
+            if key and not key.endswith("/"):
+                key += "/"
             full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
             return f"s3://{bucket}/{full_key}"
         else:
             filename = f"output_{hash_str}.jsonl"
             return os.path.join(self.output_prefix, filename)

-    def write_line(self, line: Optional[str]):
-        if line is None or not line.strip():
-            return
-
-        line_size = len(line.encode('utf-8')) + 1  # +1 for newline
-
-        if self.batch_size + line_size > self.max_size:
-            self._write_batch()
-
-        self.batch.append(line)
-        self.batch_size += line_size
-
-    def _write_batch(self):
-        if not self.batch:
-            return
-
-        batch_lines = self.batch.copy()
-        batch_content = "\n".join(batch_lines) + "\n"
-        hash_str = self._compute_hash(batch_content)
-        output_path = self._get_output_path(hash_str)
-
-        # Start a new thread to write the batch
-        thread = threading.Thread(
-            target=self._write_batch_to_file,
-            args=(batch_content, output_path, batch_lines)
-        )
-        thread.start()
-        self.threads.append(thread)
-
-        # Clear the batch and batch_size
-        self.batch = []
-        self.batch_size = 0
-
-    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
-        if self.is_s3:
-            put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
-        else:
-            with open(output_path, 'w', encoding='utf-8') as f_out:
-                f_out.write(batch_content)
-
-        # After writing, call the after_flush callback if it is set
-        if self.after_flush:
-            self.after_flush(batch_lines)
-
     def close(self):
         self._write_batch()

         # Wait for all threads to finish
@@ -520,11 +555,11 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
     return dolma_doc

-def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
+def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
     db = DatabaseManager(s3_workspace)

-    for line in dolma_doc_lines:
-        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
+    for doc in dolma_docs:
+        db.update_pdf_status(doc["metadata"]["Source-File"], "completed")

 def get_current_round(s3_workspace: str) -> int:
     path = s3_workspace[5:]
@@ -610,13 +645,13 @@ if __name__ == '__main__':
        print("Indexing all batch inference sent to this workspace")
        inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

-       inference_output_paths = [
-           (s3_path, etag) for s3_path, etag in inference_output_paths.items()
+       inference_output_paths = {
+           s3_path: etag for s3_path, etag in inference_output_paths.items()
            if not db.is_file_processed(s3_path, etag)
-       ]
+       }

        print(f"Found {len(inference_output_paths):,} new batch inference results to index")

-       future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
+       future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path, etag = future_to_path[future]
@@ -638,29 +673,72 @@ if __name__ == '__main__':
            potentially_done_pdfs = db.get_pdfs_by_status("pending")
        else:
            print(f"\nCreating batch inference files for new PDFs")
-           future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len): pdf for pdf in db.get_pdfs_by_status("pending")}
+           pdf_list = list(db.get_pdfs_by_status("pending"))
+           pdf_iter = iter(pdf_list)
+           pending_futures = {}

            potentially_done_pdfs = []
            lines_written = 0
            new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

-           for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
-               pdf = future_to_path[future]
-               inference_lines = future.result()
-
-               if len(inference_lines) == 0:
-                   potentially_done_pdfs.append(pdf)
-
-               for line in inference_lines:
-                   lines_written += 1
-
-                   if line is not None:
-                       new_inference_writer.write_line(json.dumps(line))
+           total_pdfs = len(pdf_list)
+           max_pending = 5000
+
+           with tqdm(total=total_pdfs) as pbar:
+               # Submit initial batch of futures
+               for _ in range(min(max_pending, total_pdfs)):
+                   pdf = next(pdf_iter)
+                   future = executor.submit(
+                       build_pdf_queries,
+                       args.workspace,
+                       pdf,
+                       current_round,
+                       args.target_longest_image_dim,
+                       args.target_anchor_text_len,
+                   )
+                   pending_futures[future] = pdf
+
+               while pending_futures:
+                   # Wait for the next future to complete
+                   done, _ = concurrent.futures.wait(
+                       pending_futures.keys(),
+                       return_when=concurrent.futures.FIRST_COMPLETED,
+                   )
+
+                   for future in done:
+                       pdf = pending_futures.pop(future)
+                       inference_lines = future.result()
+
+                       if len(inference_lines) == 0:
+                           potentially_done_pdfs.append(pdf)
+
+                       for line in inference_lines:
+                           lines_written += 1
+
+                           if line is not None:
+                               new_inference_writer.write_line(line)
+
+                       pbar.update(1)
+
+                       # Submit a new future if there are more PDFs
+                       try:
+                           pdf = next(pdf_iter)
+                           new_future = executor.submit(
+                               build_pdf_queries,
+                               args.workspace,
+                               pdf,
+                               current_round,
+                               args.target_longest_image_dim,
+                               args.target_anchor_text_len,
+                           )
+                           pending_futures[new_future] = pdf
+                       except StopIteration:
+                           pass  # No more PDFs to process

            new_inference_writer.close()

            if lines_written > 0:
                print(f"Added {lines_written:,} new batch inference requests")

        # Now, finally, assemble any potentially done docs into dolma documents
        print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
        future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
@@ -671,7 +749,7 @@ if __name__ == '__main__':
            dolma_doc = future.result()
            if dolma_doc is not None:
-               new_output_writer.write_line(json.dumps(dolma_doc))
+               new_output_writer.write_line(dolma_doc)

        new_output_writer.close()
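
For reference, a hedged usage sketch of the refactored BatchWriter interface: write_line() now accepts a plain dict rather than a pre-serialized JSON string, rows are spooled to a local temp file and uploaded on a background thread once a batch exceeds max_size_mb, and close() flushes the final batch and joins the upload threads. The callback name log_flush and the local output path below are invented for illustration, and the class is assumed to be importable from the module shown in this diff.

def log_flush(objs):
    # Hypothetical after_flush callback, e.g. marking PDFs done in the database.
    print(f"flushed {len(objs)} documents")

writer = BatchWriter("/tmp/example_output", max_size_mb=250, after_flush=log_flush)
writer.write_line({"id": "doc-1", "metadata": {"Source-File": "s3://bucket/a.pdf"}})
writer.write_line({"id": "doc-2", "metadata": {"Source-File": "s3://bucket/b.pdf"}})
writer.close()  # writes output_<sha1-prefix>.jsonl under the output prefix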