mirror of https://github.com/allenai/olmocr.git
synced 2025-11-02 11:04:25 +00:00

More pipeline code

This commit is contained in:
parent 39333f2c96
commit f2f578cca9
@@ -6,6 +6,7 @@ import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import smart_open

@@ -93,6 +94,14 @@ class DatabaseManager:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        self.conn.commit()

    def get_current_round(self):
        round_value = self.get_metadata("round")
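
The set_metadata upsert above leans on SQLite's ON CONFLICT(key) DO UPDATE clause, which together with get_metadata gives the pipeline a tiny key/value store for state such as the current round number. A minimal self-contained sketch of the same pattern (the in-memory connection and table creation here are assumptions for illustration; only the upsert statement mirrors the diff):

import sqlite3

# Standalone illustration of the key/value upsert used for the round counter.
# The table name "metadata" matches the diff; everything else is assumed.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)")

def set_metadata(key: str, value: str) -> None:
    conn.execute(
        """
        INSERT INTO metadata (key, value)
        VALUES (?, ?)
        ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """,
        (key, value),
    )
    conn.commit()

def get_metadata(key: str):
    row = conn.execute("SELECT value FROM metadata WHERE key=?", (key,)).fetchone()
    return row[0] if row else None

# The round counter defaults to 0 when the key is missing, then increments.
current = int(get_metadata("round") or 0)
set_metadata("round", str(current + 1))
print(get_metadata("round"))  # -> "1"

Because key is declared as the primary key, the ON CONFLICT target is valid and repeated writes update the row in place rather than raising an IntegrityError.
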
@@ -385,18 +394,52 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list
        tf.write(get_s3_bytes(pdf.s3_path))
        tf.flush()

-        for page in range(1, pdf.num_pages + 1):
+        for target_page_num in range(1, pdf.num_pages + 1):
            # Is there an existing page that has no error
-            if any(page.is_usable() for page in existing_pages):
+            if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                continue

            # TODO: Later, you may want to retry with different sampling parameters or do something else
-            new_queries.append(build_page_query(tf.name, pdf.s3_path, page))
+            new_queries.append(build_page_query(tf.name, pdf.s3_path, target_page_num))
    except Exception as ex:
        print(f"Warning, could not get batch inference lines for {pdf.s3_path} due to {ex}")

    return new_queries

def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""

    for target_page_num in range(1, pdf.num_pages + 1):
        target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)

        target_row = get_s3_bytes(target_page.s3_path,
                                  start_index=target_page.start_index,
                                  end_index=target_page.start_index + target_page.length)

        target_data = json.loads(target_row.decode("utf-8"))

        document_text += target_data + "\n"

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }
    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "s2pdf",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
    }

    return dolma_doc

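build_dolma_doc reassembles a document by fetching each page's JSON record with a byte-range read (start_index and length come from the index entries), so only the needed slice of a large JSONL object is downloaded. The get_s3_bytes helper itself is not shown in this hunk; the sketch below is one way such a ranged read could look with boto3, where the bucket/key parsing and the exclusive end_index convention are assumptions rather than the repo's actual implementation:

import boto3
from urllib.parse import urlparse

s3 = boto3.client("s3")

def get_s3_bytes(s3_path, start_index=None, end_index=None):
    # Hypothetical helper: the real get_s3_bytes lives elsewhere in this file
    # and its argument conventions may differ.
    parsed = urlparse(s3_path)  # s3://bucket/key
    kwargs = {"Bucket": parsed.netloc, "Key": parsed.path.lstrip("/")}

    if start_index is not None and end_index is not None:
        # An HTTP Range header is inclusive on both ends, so an exclusive
        # end_index (start_index + length, as used above) maps to end_index - 1.
        kwargs["Range"] = f"bytes={start_index}-{end_index - 1}"

    return s3.get_object(**kwargs)["Body"].read()
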
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
@@ -456,6 +499,7 @@ if __name__ == '__main__':
    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
    future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
    potentially_done_pdfs = []
    lines_written = 0
    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
@@ -466,15 +510,26 @@
        potentially_done_pdfs.append(pdf)

        for line in inference_lines:
            lines_written += 1
            new_inference_writer.write_line(json.dumps(line))

    new_inference_writer.close()

    if lines_written > 0:
        db.set_metadata("round", str(db.get_current_round() + 1))

    # Now, finally, assemble any potentially done docs into dolma documents
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        new_output_writer.write_line(json.dumps(dolma_doc))

    new_output_writer.close()

    # TODO
    # 1. Build a class that will manage taking in dicts and outputting them as jsonls of up to the max size to the bucket;
    #    you'll need one for new batch inference lines, and one for finished dolma docs
    # 2. Have a way to apply a basic spam + language filter, if you can, during the add-pdfs step
    # 3. For retrying, make it so you retry several times with different sampling parameters
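
The BatchWriter used for both the inference queries and the finished dolma docs (and restated in TODO item 1) takes an output prefix plus a max size in MB and exposes write_line / close. Below is a rough sketch of one way such a size-capped JSONL batcher could be written, assuming smart_open handles the S3 writes and using a simple sequential naming scheme; the repo's real class may differ in both respects:

import posixpath
import smart_open

class BatchWriter:
    # Illustrative sketch only: the real BatchWriter in this repo may differ.
    def __init__(self, output_prefix: str, max_size_mb: int):
        self.output_prefix = output_prefix
        self.max_bytes = max_size_mb * 1024 * 1024
        self.batch = []          # buffered lines for the current output file
        self.batch_bytes = 0
        self.file_index = 0

    def write_line(self, line: str) -> None:
        encoded = len(line.encode("utf-8")) + 1  # +1 for the trailing newline
        if self.batch and self.batch_bytes + encoded > self.max_bytes:
            self._flush()
        self.batch.append(line)
        self.batch_bytes += encoded

    def _flush(self) -> None:
        if not self.batch:
            return
        path = posixpath.join(self.output_prefix, f"batch_{self.file_index:05d}.jsonl")
        with smart_open.open(path, "w") as f:
            f.write("\n".join(self.batch) + "\n")
        self.batch, self.batch_bytes = [], 0
        self.file_index += 1

    def close(self) -> None:
        self._flush()

A real implementation might name files by content hash or flush concurrently from multiple workers; the essential behavior is just cutting a new object whenever the buffered lines would exceed the size cap.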