Runs to the end now

Jake Poznanski 2024-10-14 20:28:54 +00:00
parent 879b974af2
commit 1ed9e4c947


@@ -8,13 +8,14 @@ import glob
 import tempfile
 import datetime
 import posixpath
-import smart_open
+import threading
 import logging
 from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Optional, List, Tuple, Dict
+from functools import partial
+from typing import Optional, List, Tuple, Dict, Callable, Any
 from urllib.parse import urlparse
 from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -120,6 +121,14 @@ class DatabaseManager:
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
         result = self.cursor.fetchone()
         return result is not None and result[0] == etag

+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
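For reference, the new update_processed_file leans on SQLite's ON CONFLICT upsert; below is a standalone sketch of that behavior, using an in-memory database and illustrative path/etag values (not part of the commit).

import sqlite3

# Standalone sketch of the upsert used above; the in-memory database and the
# sample path/etag values are illustrative, not part of the commit.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE processed_files (s3_path TEXT PRIMARY KEY, etag TEXT)")

for etag in ("etag-v1", "etag-v2"):
    # The second pass hits ON CONFLICT(s3_path) and updates the row in place
    # instead of raising sqlite3.IntegrityError.
    conn.execute("""
        INSERT INTO processed_files (s3_path, etag)
        VALUES (?, ?)
        ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
    """, ("s3://bucket/example.pdf", etag))
    conn.commit()

print(conn.execute("SELECT etag FROM processed_files").fetchone())  # ('etag-v2',)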
@@ -162,14 +171,6 @@ class DatabaseManager:
         result = self.cursor.fetchone()
         return -1 if result[0] is None else result[0]

-    def update_processed_file(self, s3_path, etag):
-        self.cursor.execute("""
-            INSERT INTO processed_files (s3_path, etag)
-            VALUES (?, ?)
-            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-        """, (s3_path, etag))
-        self.conn.commit()
-
     def pdf_exists(self, s3_path: str) -> bool:
         self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
         return self.cursor.fetchone() is not None
@@ -184,6 +185,14 @@ class DatabaseManager:
         except sqlite3.IntegrityError:
             print(f"PDF with s3_path '{s3_path}' already exists.")

+    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
+        self.cursor.execute("""
+            UPDATE pdfs
+            SET status = ?
+            WHERE s3_path = ?
+        """, (new_status, s3_path))
+        self.conn.commit()
+
     def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
         self.cursor.execute("""
             SELECT s3_path, num_pages, status
@@ -226,11 +235,13 @@ class DatabaseManager:
 # Writes batches of lines out to a set of files, keeping each file below some maximum size
 class BatchWriter:
-    def __init__(self, output_prefix: str, max_size_mb: int = 250):
+    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
         self.output_prefix = output_prefix
         self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
         self.batch = []
         self.batch_size = 0
+        self.after_flush = after_flush
+        self.threads = []

         parsed = urlparse(output_prefix)
         self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
@@ -259,14 +270,14 @@ class BatchWriter:
         return os.path.join(self.output_prefix, filename)

     def write_line(self, line: Optional[str]):
-        if line is None or len(line.strip()) == 0:
+        if line is None or not line.strip():
             return

         line_size = len(line.encode('utf-8')) + 1  # +1 for newline

         if self.batch_size + line_size > self.max_size:
             self._write_batch()

         self.batch.append(line)
         self.batch_size += line_size
@@ -274,19 +285,53 @@ class BatchWriter:
         if not self.batch:
             return

-        batch_content = "\n".join(self.batch) + "\n"
+        batch_lines = self.batch.copy()
+        batch_content = "\n".join(batch_lines) + "\n"
         hash_str = self._compute_hash(batch_content)
         output_path = self._get_output_path(hash_str)

-        with smart_open.open(output_path, 'w') as f_out:
-            f_out.write(batch_content)
-        print(f"Wrote batch to {output_path}")
+        # Start a new thread to write the batch
+        thread = threading.Thread(
+            target=self._write_batch_to_file,
+            args=(batch_content, output_path, batch_lines)
+        )
+        thread.start()
+        self.threads.append(thread)

+        # Clear the batch and batch_size
         self.batch = []
         self.batch_size = 0

+    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
+        if self.is_s3:
+            self._write_to_s3(batch_content, output_path)
+        else:
+            with open(output_path, 'w', encoding='utf-8') as f_out:
+                f_out.write(batch_content)
+
+        # After writing, call the after_flush callback if it is set
+        if self.after_flush:
+            self.after_flush(batch_lines)
+
+    def _write_to_s3(self, content: str, s3_path: str):
+        # Parse the s3_path to get bucket and key
+        parsed = urlparse(s3_path)
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+
+        s3 = boto3.client('s3')
+        s3.put_object(
+            Bucket=bucket,
+            Key=key,
+            Body=content.encode('utf-8'),
+            ContentType='text/plain; charset=utf-8'
+        )
+
     def close(self):
         self._write_batch()
+        # Wait for all threads to finish
+        for thread in self.threads:
+            thread.join()


 def parse_s3_path(s3_path):
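A minimal usage sketch of the reworked BatchWriter, assuming the class above is importable; the local prefix, callback, and records below are illustrative, not from the commit.

import json

# Illustrative stand-in for mark_pdfs_done; after_flush receives exactly the
# lines that went into the flushed batch.
def report_flushed(lines):
    print(f"flushed a batch of {len(lines)} lines")

# Assumes the /tmp/batch_output directory already exists for the local-path case.
writer = BatchWriter("/tmp/batch_output", max_size_mb=250, after_flush=report_flushed)
for record in ({"id": 1}, {"id": 2}):
    writer.write_line(json.dumps(record))
writer.close()  # flushes the remaining batch and joins all writer threads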
@@ -493,6 +538,12 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
     return dolma_doc

+def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
+    db = DatabaseManager(s3_workspace)
+
+    for line in dolma_doc_lines:
+        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
@@ -569,7 +620,9 @@ if __name__ == '__main__':
         for line in inference_lines:
             lines_written += 1
-            new_inference_writer.write_line(json.dumps(line))
+
+            if line is not None:
+                new_inference_writer.write_line(json.dumps(line))

     new_inference_writer.close()
@@ -579,20 +632,24 @@ if __name__ == '__main__':
     # Now, finally, assemble any potentially done docs into dolma documents
     print(f"\nAssembling {len(potentially_done_pdfs):,} potentially finished PDFs into Dolma documents at {args.workspace}/output")
     future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
-    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)
+    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         pdf = future_to_path[future]
         dolma_doc = future.result()
-        new_output_writer.write_line(json.dumps(dolma_doc))
+        if dolma_doc is not None:
+            new_output_writer.write_line(json.dumps(dolma_doc))

     new_output_writer.close()

+    print("\nFinal statistics:")
+    # Output the number of documents in each status "pending" and "completed"
+
     print("\nWork finished, waiting for all workers to finish cleaning up")
     executor.shutdown(wait=True)
+    db.close()

     # TODO
     # 2. Have a way to apply basic spam + language filter if you can during add pdfs step
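For reference, the after_flush wiring in the last hunk just curries the workspace into the callback; here is a self-contained sketch with a stand-in for mark_pdfs_done and a hypothetical workspace path.

from functools import partial

# Stand-in with the same signature as the mark_pdfs_done added in this commit.
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]) -> None:
    print(f"would mark {len(dolma_doc_lines)} docs completed in {s3_workspace}")

callback = partial(mark_pdfs_done, "s3://my-bucket/workspace")  # hypothetical workspace
# BatchWriter calls self.after_flush(batch_lines); with partial that expands to
# mark_pdfs_done("s3://my-bucket/workspace", batch_lines), as in the line below.
callback(['{"metadata": {"Source-File": "s3://my-bucket/example.pdf"}}'])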