Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-09 06:59:03 +00:00
Runs to the end now
This commit is contained in:
parent 879b974af2
commit 1ed9e4c947
@@ -8,13 +8,14 @@ import glob
 import tempfile
 import datetime
 import posixpath
-import smart_open
+import threading
 import logging

 from dataclasses import dataclass
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Optional, List, Tuple, Dict
+from functools import partial
+from typing import Optional, List, Tuple, Dict, Callable, Any
 from urllib.parse import urlparse
 from concurrent.futures import ProcessPoolExecutor, as_completed

@@ -120,6 +121,14 @@ class DatabaseManager:
         self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
         result = self.cursor.fetchone()
         return result is not None and result[0] == etag

+    def update_processed_file(self, s3_path, etag):
+        self.cursor.execute("""
+            INSERT INTO processed_files (s3_path, etag)
+            VALUES (?, ?)
+            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
+        """, (s3_path, etag))
+        self.conn.commit()
+
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
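Note: update_processed_file (moved up in this hunk) leans on SQLite's upsert clause, which is available since SQLite 3.24. A self-contained sketch of the same pattern against an illustrative in-memory table, not the pipeline's actual schema:

import sqlite3

# Illustrative table mirroring the processed_files upsert pattern.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE processed_files (s3_path TEXT PRIMARY KEY, etag TEXT)")

def upsert(s3_path: str, etag: str) -> None:
    # Insert a new row, or overwrite the etag if the s3_path already exists.
    conn.execute("""
        INSERT INTO processed_files (s3_path, etag)
        VALUES (?, ?)
        ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
    """, (s3_path, etag))
    conn.commit()

upsert("s3://bucket/a.pdf", "etag-1")
upsert("s3://bucket/a.pdf", "etag-2")  # same key: etag updated, no IntegrityError
print(conn.execute("SELECT * FROM processed_files").fetchall())  # [('s3://bucket/a.pdf', 'etag-2')]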
@@ -162,14 +171,6 @@ class DatabaseManager:
         result = self.cursor.fetchone()
         return -1 if result[0] is None else result[0]

-    def update_processed_file(self, s3_path, etag):
-        self.cursor.execute("""
-            INSERT INTO processed_files (s3_path, etag)
-            VALUES (?, ?)
-            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-        """, (s3_path, etag))
-        self.conn.commit()
-
     def pdf_exists(self, s3_path: str) -> bool:
         self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
         return self.cursor.fetchone() is not None
@@ -184,6 +185,14 @@ class DatabaseManager:
         except sqlite3.IntegrityError:
             print(f"PDF with s3_path '{s3_path}' already exists.")

+    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
+        self.cursor.execute("""
+            UPDATE pdfs
+            SET status = ?
+            WHERE s3_path = ?
+        """, (new_status, s3_path))
+        self.conn.commit()
+
     def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
         self.cursor.execute("""
             SELECT s3_path, num_pages, status
@@ -226,11 +235,13 @@ class DatabaseManager:

 # Writes batches of lines out to a set of files, keeping each file below some maximum size
 class BatchWriter:
-    def __init__(self, output_prefix: str, max_size_mb: int = 250):
+    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
         self.output_prefix = output_prefix
         self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
         self.batch = []
         self.batch_size = 0
+        self.after_flush = after_flush
+        self.threads = []

         parsed = urlparse(output_prefix)
         self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')
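Note: the new after_flush hook lets a caller run bookkeeping once a batch has actually been written out; it receives exactly the list of lines that went into the flushed file. A minimal hypothetical wiring, assuming BatchWriter from this file is in scope and the output prefix is a fresh local directory:

import tempfile

def record_flushed(lines):
    # Runs on the writer's flush thread with the lines that were just written.
    print(f"flushed {len(lines)} lines")

writer = BatchWriter(tempfile.mkdtemp(), max_size_mb=250, after_flush=record_flushed)
writer.write_line('{"id": 1}')
writer.close()  # flushes the remaining batch and joins the writer threads

Because after_flush fires on the background write thread (see the next hunk), callbacks that touch SQLite should open their own connection, which is what mark_pdfs_done below does by constructing a fresh DatabaseManager.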
@@ -259,14 +270,14 @@ class BatchWriter:
         return os.path.join(self.output_prefix, filename)

     def write_line(self, line: Optional[str]):
-        if line is None or len(line.strip()) == 0:
+        if line is None or not line.strip():
             return

         line_size = len(line.encode('utf-8')) + 1  # +1 for newline

         if self.batch_size + line_size > self.max_size:
             self._write_batch()

         self.batch.append(line)
         self.batch_size += line_size

@@ -274,19 +285,53 @@ class BatchWriter:
         if not self.batch:
             return

-        batch_content = "\n".join(self.batch) + "\n"
+        batch_lines = self.batch.copy()
+        batch_content = "\n".join(batch_lines) + "\n"
         hash_str = self._compute_hash(batch_content)
         output_path = self._get_output_path(hash_str)

-        with smart_open.open(output_path, 'w') as f_out:
-            f_out.write(batch_content)
-        print(f"Wrote batch to {output_path}")
+        # Start a new thread to write the batch
+        thread = threading.Thread(
+            target=self._write_batch_to_file,
+            args=(batch_content, output_path, batch_lines)
+        )
+        thread.start()
+        self.threads.append(thread)

+        # Clear the batch and batch_size
         self.batch = []
         self.batch_size = 0

+    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
+        if self.is_s3:
+            self._write_to_s3(batch_content, output_path)
+        else:
+            with open(output_path, 'w', encoding='utf-8') as f_out:
+                f_out.write(batch_content)
+
+        # After writing, call the after_flush callback if it is set
+        if self.after_flush:
+            self.after_flush(batch_lines)
+
+    def _write_to_s3(self, content: str, s3_path: str):
+        # Parse the s3_path to get bucket and key
+        parsed = urlparse(s3_path)
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+
+        s3 = boto3.client('s3')
+        s3.put_object(
+            Bucket=bucket,
+            Key=key,
+            Body=content.encode('utf-8'),
+            ContentType='text/plain; charset=utf-8'
+        )
+
     def close(self):
         self._write_batch()
+        # Wait for all threads to finish
+        for thread in self.threads:
+            thread.join()


 def parse_s3_path(s3_path):
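Note: _write_batch now returns as soon as the serialized batch has been handed to a background thread, and close() joins every thread it spawned, so file and S3 writes overlap with further batching. A stripped-down sketch of that spawn-then-join pattern; the class and method names here are illustrative, not part of the pipeline:

import threading

class ThreadedFlusher:
    def __init__(self):
        self.threads = []

    def flush(self, payload: str, path: str):
        # Hand the write to a worker thread so the caller can keep batching.
        thread = threading.Thread(target=self._write, args=(payload, path))
        thread.start()
        self.threads.append(thread)

    def _write(self, payload: str, path: str):
        with open(path, "w", encoding="utf-8") as f_out:
            f_out.write(payload)

    def close(self):
        # Block until every in-flight write has finished, as BatchWriter.close() does.
        for thread in self.threads:
            thread.join()

Each flush gets its own thread and nothing bounds how many are in flight at once, so smaller max_size_mb values mean more concurrent writes.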
@@ -493,6 +538,12 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option

     return dolma_doc

+def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
+    db = DatabaseManager(s3_workspace)
+
+    for line in dolma_doc_lines:
+        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")
+

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
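Note: BatchWriter calls after_flush with a single argument (the flushed lines), while mark_pdfs_done takes the workspace first; the wiring in the final hunk bridges that with functools.partial. A minimal sketch of the call flow, using an illustrative workspace path and a stand-in body:

from functools import partial

def mark_pdfs_done(s3_workspace, dolma_doc_lines):
    # Stand-in body; the real function opens a DatabaseManager and marks each Source-File completed.
    print(s3_workspace, len(dolma_doc_lines))

# partial pins the workspace, yielding the one-argument callable BatchWriter expects.
after_flush = partial(mark_pdfs_done, "s3://example-bucket/workspace")
after_flush(['{"metadata": {"Source-File": "s3://example-bucket/doc.pdf"}}'])
# equivalent to: mark_pdfs_done("s3://example-bucket/workspace", [...])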
@@ -569,7 +620,9 @@ if __name__ == '__main__':

         for line in inference_lines:
             lines_written += 1
-            new_inference_writer.write_line(json.dumps(line))
+
+            if line is not None:
+                new_inference_writer.write_line(json.dumps(line))

     new_inference_writer.close()

@@ -579,20 +632,24 @@ if __name__ == '__main__':
     # Now, finally, assemble any potentially done docs into dolma documents
     print(f"\nAssembling {len(potentially_done_pdfs):,} potentially finished PDFs into Dolma documents at {args.workspace}/output")
     future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
-    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)
+    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         pdf = future_to_path[future]
         dolma_doc = future.result()

-        new_output_writer.write_line(json.dumps(dolma_doc))
+        if dolma_doc is not None:
+            new_output_writer.write_line(json.dumps(dolma_doc))

     new_output_writer.close()

+    print("\nFinal statistics:")
+    # Output the number of documents in each status "pending" and "completed"


     print("\nWork finished, waiting for all workers to finish cleaning up")
     executor.shutdown(wait=True)
+    db.close()

     # TODO
     # 2. Have a way to apply basic spam + language filter if you can during add pdfs step