More pipeline code
commit f2f578cca9 (parent 39333f2c96)
mirror of https://github.com/allenai/olmocr.git
@@ -6,6 +6,7 @@ import json
 import argparse
 import glob
 import tempfile
+import datetime
 import posixpath
 import smart_open
 
@@ -93,6 +94,14 @@ class DatabaseManager:
         self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
         result = self.cursor.fetchone()
         return result[0] if result else None
 
+    def set_metadata(self, key: str, value: str) -> None:
+        self.cursor.execute("""
+            INSERT INTO metadata (key, value)
+            VALUES (?, ?)
+            ON CONFLICT(key) DO UPDATE SET value=excluded.value
+        """, (key, value))
+        self.conn.commit()
+
     def get_current_round(self):
         round_value = self.get_metadata("round")
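The new set_metadata relies on SQLite's UPSERT clause (ON CONFLICT(key) DO UPDATE), so rewriting an existing key replaces its value instead of raising a uniqueness error. A minimal standalone sketch of the same pattern, assuming a metadata table whose key column is the primary key; the in-memory database and the sample "round" values are only for illustration:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE metadata (key TEXT PRIMARY KEY, value TEXT)")

    def set_metadata(key: str, value: str) -> None:
        # Same UPSERT as in the diff: insert, or overwrite the existing row for this key
        cur.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        conn.commit()

    set_metadata("round", "0")
    set_metadata("round", "1")   # second write overwrites, does not raise
    cur.execute("SELECT value FROM metadata WHERE key=?", ("round",))
    print(cur.fetchone()[0])     # prints "1"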
@@ -385,18 +394,52 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list:
             tf.write(get_s3_bytes(pdf.s3_path))
             tf.flush()
 
-            for page in range(1, pdf.num_pages + 1):
+            for target_page_num in range(1, pdf.num_pages + 1):
                 # Is there an existing page that has no error
-                if any(page.is_usable() for page in existing_pages):
+                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                     continue
 
                 # TODO: Later, you may want to retry with different sampling parameters or do something else
-                new_queries.append(build_page_query(tf.name, pdf.s3_path, page))
+                new_queries.append(build_page_query(tf.name, pdf.s3_path, target_page_num))
     except Exception as ex:
         print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")
 
     return new_queries
 
+def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
+    db = DatabaseManager(s3_workspace)
+    existing_pages = db.get_index_entries(pdf.s3_path)
+    document_text = ""
+
+    for target_page_num in range(1, pdf.num_pages + 1):
+        target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)
+
+        target_row = get_s3_bytes(target_page.s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length)
+
+        target_data = json.loads(target_row.decode("utf-8"))
+
+        document_text += target_data + "\n"
+
+    metadata = {
+        "Source-File": pdf.s3_path,
+        "pdf-total-pages": pdf.num_pages,
+    }
+    id_ = hashlib.sha1(document_text.encode()).hexdigest()
+
+    dolma_doc = {
+        "id": id_,
+        "text": document_text,
+        "source": "s2pdf",
+        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "metadata": metadata,
+    }
+
+    return dolma_doc
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
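The new build_dolma_doc concatenates the usable per-page responses for a PDF in page order and wraps them in a Dolma-format record whose id is the sha1 of the text. Each record is later written as one JSONL line; an illustrative record with made-up values, showing only the fields constructed in the hunk above:

    {
      "id": "2fd4e1c67a2d28fced849ee1bb76e7391b93eb12",
      "text": "Page 1 text...\nPage 2 text...\n",
      "source": "s2pdf",
      "added": "2024-09-23",
      "created": "2024-09-23",
      "metadata": {
        "Source-File": "s3://bucket/prefix/pdfs/example.pdf",
        "pdf-total-pages": 2
      }
    }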
@@ -456,6 +499,7 @@ if __name__ == '__main__':
         # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
         future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
         potentially_done_pdfs = []
+        lines_written = 0
         new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
 
         for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
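With the per-round prefix above, each pass of the manager writes its new batch-inference queries under a fresh round_{N} directory inside the workspace, while the assembled Dolma documents (added further down in this diff) go to a single output prefix. An illustrative workspace layout, assuming the s3://bucket/prefix/ form from the argparse help; the bucket, prefix, and batch file names are made up:

    s3://bucket/prefix/
        inference/
            round_0/batch_0000.jsonl   # queries written on the first pass
            round_1/batch_0000.jsonl   # retries for pages still missing or errored
        output/batch_0000.jsonl        # finished dolma documents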
@@ -466,15 +510,26 @@ if __name__ == '__main__':
             potentially_done_pdfs.append(pdf)
 
             for line in inference_lines:
+                lines_written += 1
                 new_inference_writer.write_line(json.dumps(line))
 
         new_inference_writer.close()
 
+        if lines_written > 0:
+            db.set_metadata("round", str(db.get_current_round() + 1))
 
         # Now, finally, assemble any potentially done docs into dolma documents
+        future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
+        new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)
+
+        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
+            pdf = future_to_path[future]
+            dolma_doc = future.result()
+
+            new_output_writer.write_line(json.dumps(dolma_doc))
+
+        new_output_writer.close()
 
         # TODO
-        # 1. build a class that will manage taking in dicts and outputting them as jsonls of up to the max size to the bucket
-        #    you'll need one for new batch inference lines, and one for finished dolma docs
         # 2. Have a way to apply basic spam + language filter if you can during add pdfs step
         # 3. For retrying, make it so you retry several times with different sampling parameters
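BatchWriter itself is not part of this diff; the TODO item removed above described it (take in lines and emit JSONL files of up to the configured max size). A rough sketch of what such a writer could look like, assuming smart_open for S3 output and only the write_line/close interface the pipeline calls; the class name, file naming, and rollover rule here are guesses, not the actual implementation:

    import posixpath
    import smart_open

    class SketchBatchWriter:
        """Illustrative stand-in for BatchWriter: buffers JSONL lines and starts
        a new object under output_prefix once max_size_mb would be exceeded."""

        def __init__(self, output_prefix: str, max_size_mb: int):
            self.output_prefix = output_prefix
            self.max_bytes = max_size_mb * 1024 * 1024
            self.batch_index = 0
            self.bytes_written = 0
            self.fh = None

        def _open_next(self) -> None:
            path = posixpath.join(self.output_prefix, f"batch_{self.batch_index:04d}.jsonl")
            self.fh = smart_open.open(path, "w")   # handles s3:// and local paths
            self.batch_index += 1
            self.bytes_written = 0

        def write_line(self, line: str) -> None:
            size = len(line.encode("utf-8")) + 1   # +1 for the trailing newline
            if self.fh is None or self.bytes_written + size > self.max_bytes:
                self.close()
                self._open_next()
            self.fh.write(line + "\n")
            self.bytes_written += size

        def close(self) -> None:
            if self.fh is not None:
                self.fh.close()
                self.fh = None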