Runs to the end now

Jake Poznanski 2024-10-14 20:28:54 +00:00
parent 879b974af2
commit 1ed9e4c947


@@ -8,13 +8,14 @@ import glob
 import tempfile
 import datetime
 import posixpath
 import smart_open
 import threading
 import logging

 from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Optional, List, Tuple, Dict
+from functools import partial
+from typing import Optional, List, Tuple, Dict, Callable, Any
 from urllib.parse import urlparse
 from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -120,6 +121,14 @@ class DatabaseManager:
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
         result = self.cursor.fetchone()
         return result is not None and result[0] == etag

+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
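The new update_processed_file uses SQLite's upsert form, so re-processing a file overwrites its stored etag instead of raising an IntegrityError on the duplicate key. A self-contained sketch of that semantics, with an illustrative schema:

import sqlite3

conn = sqlite3.connect(":memory:")
# Illustrative schema: s3_path must be a PRIMARY KEY (or UNIQUE) for ON CONFLICT to fire
conn.execute("CREATE TABLE processed_files (s3_path TEXT PRIMARY KEY, etag TEXT)")

upsert = """
    INSERT INTO processed_files (s3_path, etag)
    VALUES (?, ?)
    ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
"""
conn.execute(upsert, ("s3://bucket/a.pdf", "etag-v1"))
conn.execute(upsert, ("s3://bucket/a.pdf", "etag-v2"))  # same path: updates in place, no error
conn.commit()

print(conn.execute("SELECT etag FROM processed_files").fetchall())  # [('etag-v2',)]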
@@ -162,14 +171,6 @@ class DatabaseManager:
         result = self.cursor.fetchone()
         return -1 if result[0] is None else result[0]

-    def update_processed_file(self, s3_path, etag):
-        self.cursor.execute("""
-            INSERT INTO processed_files (s3_path, etag)
-            VALUES (?, ?)
-            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-        """, (s3_path, etag))
-        self.conn.commit()
-
     def pdf_exists(self, s3_path: str) -> bool:
         self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
         return self.cursor.fetchone() is not None
@@ -184,6 +185,14 @@ class DatabaseManager:
         except sqlite3.IntegrityError:
             print(f"PDF with s3_path '{s3_path}' already exists.")

+    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
+        self.cursor.execute("""
+            UPDATE pdfs
+            SET status = ?
+            WHERE s3_path = ?
+        """, (new_status, s3_path))
+        self.conn.commit()
+
     def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
         self.cursor.execute("""
             SELECT s3_path, num_pages, status
@@ -226,11 +235,13 @@ class DatabaseManager:
 # Writes batches of lines out to a set of files, keeping each file below some maximum size
 class BatchWriter:
-    def __init__(self, output_prefix: str, max_size_mb: int = 250):
+    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
         self.output_prefix = output_prefix
         self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
         self.batch = []
         self.batch_size = 0
+        self.after_flush = after_flush
+        self.threads = []

         parsed = urlparse(output_prefix)
         self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
@@ -259,14 +270,14 @@
         return os.path.join(self.output_prefix, filename)

     def write_line(self, line: Optional[str]):
-        if line is None or len(line.strip()) == 0:
+        if line is None or not line.strip():
             return

         line_size = len(line.encode('utf-8')) + 1  # +1 for newline

         if self.batch_size + line_size > self.max_size:
             self._write_batch()

         self.batch.append(line)
         self.batch_size += line_size
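write_line skips None and whitespace-only lines, and flushes the current batch before appending whenever the incoming line would push it past max_size, so no output file exceeds the cap as long as each individual line fits under it. A rough illustration with a deliberately tiny cap; the local prefix is hypothetical and assumed to already exist as a directory:

# Illustrative only: a 1 MB cap makes the flush behavior easy to observe.
writer = BatchWriter("/tmp/demo", max_size_mb=1)
big_line = "x" * 600_000      # ~0.6 MB per line
writer.write_line(big_line)   # fits: batch holds one line
writer.write_line(big_line)   # would exceed 1 MB: first line is flushed,
                              # then this line starts a fresh batch
writer.write_line("")         # blank: ignored entirely
writer.close()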
@@ -274,19 +285,53 @@
         if not self.batch:
             return

-        batch_content = "\n".join(self.batch) + "\n"
+        batch_lines = self.batch.copy()
+        batch_content = "\n".join(batch_lines) + "\n"
         hash_str = self._compute_hash(batch_content)
         output_path = self._get_output_path(hash_str)

-        with smart_open.open(output_path, 'w') as f_out:
-            f_out.write(batch_content)
-
-        print(f"Wrote batch to {output_path}")
+        # Start a new thread to write the batch
+        thread = threading.Thread(
+            target=self._write_batch_to_file,
+            args=(batch_content, output_path, batch_lines)
+        )
+        thread.start()
+        self.threads.append(thread)

+        # Clear the batch and batch_size
         self.batch = []
         self.batch_size = 0

+    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
+        if self.is_s3:
+            self._write_to_s3(batch_content, output_path)
+        else:
+            with open(output_path, 'w', encoding='utf-8') as f_out:
+                f_out.write(batch_content)
+
+        # After writing, call the after_flush callback if it is set
+        if self.after_flush:
+            self.after_flush(batch_lines)
+
+    def _write_to_s3(self, content: str, s3_path: str):
+        # Parse the s3_path to get bucket and key
+        parsed = urlparse(s3_path)
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+
+        s3 = boto3.client('s3')
+        s3.put_object(
+            Bucket=bucket,
+            Key=key,
+            Body=content.encode('utf-8'),
+            ContentType='text/plain; charset=utf-8'
+        )
+
     def close(self):
         self._write_batch()

+        # Wait for all threads to finish
+        for thread in self.threads:
+            thread.join()

 def parse_s3_path(s3_path):
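Flushes now happen on background threads, and close() both flushes the tail batch and joins every thread, so after_flush side effects are only guaranteed to have run once close() returns. A hypothetical usage sketch; the prefix and callback are illustrative, not part of the pipeline:

def log_flushed(lines):
    # Called on the writer thread after each batch hits disk or S3
    print(f"flushed {len(lines)} lines")

writer = BatchWriter("/tmp/batches", max_size_mb=250, after_flush=log_flushed)
writer.write_line('{"id": 1}')
writer.write_line('{"id": 2}')
writer.close()  # flushes the remaining batch, then joins all writer threads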
@@ -493,6 +538,12 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:

     return dolma_doc

+def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
+    db = DatabaseManager(s3_workspace)
+
+    for line in dolma_doc_lines:
+        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
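Because after_flush receives only the flushed lines, mark_pdfs_done takes the workspace as its first argument so functools.partial can pin it, leaving a single-argument callback. A minimal sketch of the wiring, with a hypothetical workspace path (the pipeline itself passes args.workspace):

from functools import partial

workspace = "s3://my-bucket/workspace"  # hypothetical

writer = BatchWriter(
    f"{workspace}/output",
    max_size_mb=250,
    after_flush=partial(mark_pdfs_done, workspace),  # partial(...) -> f(lines)
)
# Every flushed batch of Dolma doc lines now gets its source PDFs
# marked "completed" in the workspace database.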
@@ -569,7 +620,9 @@ if __name__ == '__main__':
         for line in inference_lines:
             lines_written += 1

-            new_inference_writer.write_line(json.dumps(line))
+
+            if line is not None:
+                new_inference_writer.write_line(json.dumps(line))

     new_inference_writer.close()
@@ -579,20 +632,24 @@
     # Now, finally, assemble any potentially done docs into dolma documents
     print(f"\nAssembling {len(potentially_done_pdfs):,} potentially finished PDFs into Dolma documents at {args.workspace}/output")
     future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
-    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)
+    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         pdf = future_to_path[future]
         dolma_doc = future.result()
-        new_output_writer.write_line(json.dumps(dolma_doc))
+
+        if dolma_doc is not None:
+            new_output_writer.write_line(json.dumps(dolma_doc))

     new_output_writer.close()

     print("\nFinal statistics:")
     # Output the number of documents in each status "pending" and "completed"

+    print("\nWork finished, waiting for all workers to finish cleaning up")
     executor.shutdown(wait=True)

     db.close()

     # TODO
     # 2. Have a way to apply basic spam + language filter if you can during add pdfs step