More pipeline code

Jake Poznanski 2024-10-14 17:23:09 +00:00
parent 39333f2c96
commit f2f578cca9


@@ -6,6 +6,7 @@ import json
 import argparse
 import glob
 import tempfile
+import datetime
 import posixpath
 import smart_open
@@ -93,6 +94,14 @@ class DatabaseManager:
self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
result = self.cursor.fetchone()
return result[0] if result else None
def set_metadata(self, key: str, value: str) -> None:
self.cursor.execute("""
INSERT INTO metadata (key, value)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value=excluded.value
""", (key, value))
self.conn.commit()
def get_current_round(self):
round_value = self.get_metadata("round")
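The new set_metadata relies on SQLite's upsert syntax (ON CONFLICT ... DO UPDATE), which requires a unique or primary-key constraint on the conflict target. A minimal standalone sketch of the pattern, assuming the metadata table declares key as its primary key (the real schema lives elsewhere in this file):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (key TEXT PRIMARY KEY, value TEXT)")

def set_metadata(key: str, value: str) -> None:
    # Insert the row, or overwrite the value if the key already exists
    conn.execute("""
        INSERT INTO metadata (key, value)
        VALUES (?, ?)
        ON CONFLICT(key) DO UPDATE SET value=excluded.value
    """, (key, value))
    conn.commit()

set_metadata("round", "0")
set_metadata("round", "1")  # upsert: same key, value replaced
print(conn.execute("SELECT value FROM metadata WHERE key='round'").fetchone())  # ('1',)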
@@ -385,18 +394,52 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list
             tf.write(get_s3_bytes(pdf.s3_path))
             tf.flush()
 
-            for page in range(1, pdf.num_pages + 1):
+            for target_page_num in range(1, pdf.num_pages + 1):
                 # Is there an existing page that has no error
-                if any(page.is_usable() for page in existing_pages):
+                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                     continue
 
                 # TODO: Later, you may want to retry with different sampling parameters or do something else
-                new_queries.append(build_page_query(tf.name, pdf.s3_path, page))
+                new_queries.append(build_page_query(tf.name, pdf.s3_path, target_page_num))
     except Exception as ex:
         print(f"Warning, could not get batch inference lines for {pdf.s3_path} due to {ex}")
 
     return new_queries
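The changed lines above fix a real bug: the old any(page.is_usable() for page in existing_pages) ignored the page number entirely (and shadowed the loop variable page), so a single usable page caused every page of the PDF to be skipped. A small self-contained demonstration of the difference, using a hypothetical PageResult stand-in for the real index entries:

from dataclasses import dataclass

@dataclass
class PageResult:  # hypothetical stand-in for the index entries
    page_num: int
    error: str | None = None

    def is_usable(self) -> bool:
        return self.error is None

existing_pages = [PageResult(page_num=1), PageResult(page_num=2, error="parse failure")]

# Old check: page 1 being usable makes every page of the PDF get skipped
assert any(p.is_usable() for p in existing_pages)

# New check: only page 1 is skipped; page 2 is queried again
needs_requery = [n for n in range(1, 3)
                 if not any(p.is_usable() and p.page_num == n for p in existing_pages)]
assert needs_requery == [2]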
+def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
+    db = DatabaseManager(s3_workspace)
+    existing_pages = db.get_index_entries(pdf.s3_path)
+    document_text = ""
+
+    for target_page_num in range(1, pdf.num_pages + 1):
+        target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)
+        target_row = get_s3_bytes(target_page.s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index + target_page.length)
+        target_data = json.loads(target_row.decode("utf-8"))
+        document_text += target_data + "\n"
+
+    metadata = {
+        "Source-File": pdf.s3_path,
+        "pdf-total-pages": pdf.num_pages,
+    }
+
+    id_ = hashlib.sha1(document_text.encode()).hexdigest()
+
+    dolma_doc = {
+        "id": id_,
+        "text": document_text,
+        "source": "s2pdf",
+        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "metadata": metadata,
+    }
+    return dolma_doc
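Note that build_dolma_doc expects every page to have a usable index entry; the next(...) call raises StopIteration otherwise. That presumably holds because, further down, only PDFs that needed no new inference queries are passed in. For reference, a sketch of the JSON line it ultimately yields (all values illustrative; the field names and the s2pdf source tag come from the code above):

import json, hashlib, datetime

document_text = "Page one text\nPage two text\n"
dolma_doc = {
    "id": hashlib.sha1(document_text.encode()).hexdigest(),
    "text": document_text,
    "source": "s2pdf",
    "added": datetime.datetime.now().strftime("%Y-%m-%d"),
    "created": datetime.datetime.now().strftime("%Y-%m-%d"),
    "metadata": {"Source-File": "s3://bucket/prefix/example.pdf", "pdf-total-pages": 2},
}
print(json.dumps(dolma_doc))  # one JSONL line, ready for BatchWriter.write_line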
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
@@ -456,6 +499,7 @@ if __name__ == '__main__':
         # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
         future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
         potentially_done_pdfs = []
+        lines_written = 0
         new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
 
         for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
@@ -466,15 +510,26 @@
                 potentially_done_pdfs.append(pdf)
 
             for line in inference_lines:
+                lines_written += 1
                 new_inference_writer.write_line(json.dumps(line))
 
         new_inference_writer.close()
 
+        if lines_written > 0:
+            db.set_metadata("round", str(db.get_current_round() + 1))
+
+        # Now, finally, assemble any potentially done docs into dolma documents
+        future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
+        new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)
+
+        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
+            pdf = future_to_path[future]
+            dolma_doc = future.result()
+            new_output_writer.write_line(json.dumps(dolma_doc))
+
+        new_output_writer.close()
 
         # TODO
         # 1. build a class that will manage taking in dicts and outputting them as jsonls of up to the max size to the bucket
         #    you'll need one for new batch inference lines, and one for finished dolma docs
         # 2. Have a way to apply basic spam + language filter if you can during add pdfs step
         # 3. For retrying, make it so you retry several times with different sampling parameters
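The BatchWriter used above already plays the role TODO 1 describes: accept lines one at a time and emit JSONL objects of at most max_size_mb to a destination prefix. A minimal sketch of that contract, assuming smart_open handles the s3:// writes; the real class in this repo may buffer, name, and compress its output files differently:

import smart_open

class BatchWriter:
    """Accumulates JSONL lines and flushes them as numbered objects under a prefix.

    A sketch of the interface used above, not the repo's actual implementation.
    """
    def __init__(self, output_prefix: str, max_size_mb: int):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024
        self.buffer: list[str] = []
        self.buffer_bytes = 0
        self.batch_index = 0

    def write_line(self, line: str) -> None:
        encoded = len(line.encode("utf-8")) + 1  # +1 for the trailing newline
        if self.buffer and self.buffer_bytes + encoded > self.max_size:
            self._flush()
        self.buffer.append(line)
        self.buffer_bytes += encoded

    def _flush(self) -> None:
        if not self.buffer:
            return
        path = f"{self.output_prefix}/batch_{self.batch_index:06d}.jsonl"
        with smart_open.open(path, "w") as f:  # smart_open transparently opens s3:// URLs
            f.write("\n".join(self.buffer) + "\n")
        self.buffer, self.buffer_bytes = [], 0
        self.batch_index += 1

    def close(self) -> None:
        self._flush()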