Mirror of https://github.com/allenai/olmocr.git
Pipeline working hopefully soon

commit cd8e28e459 (parent f2f578cca9)
@@ -27,7 +27,8 @@ s3 = boto3.client('s3')
 class DatabaseManager:
     @dataclass(frozen=True)
     class BatchInferenceRecord:
-        s3_path: str
+        inference_s3_path: str
+        pdf_s3_path: str
         page_num: int  # 1 indexed!
         start_index: int
         length: int
@@ -56,7 +57,8 @@ class DatabaseManager:
     def _initialize_tables(self):
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS page_results (
-                s3_path TEXT,
+                inference_s3_path TEXT,
+                pdf_s3_path TEXT,
                 page_num INTEGER,
                 start_index BIGINT,
                 length BIGINT,
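For orientation: each index row now records two locations, the batch-inference JSONL that holds the model output line (inference_s3_path) and the source PDF that line describes (pdf_s3_path). A minimal sketch of the full record, with the finish_reason and error fields inferred from the INSERT statement in the next hunk:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str   # JSONL file with one model output per line
        pdf_s3_path: str         # source PDF this page belongs to
        page_num: int            # 1 indexed!
        start_index: int         # byte offset of the line within the JSONL
        length: int              # byte length of that line
        finish_reason: Optional[str] = None
        error: Optional[str] = None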
@@ -115,29 +117,30 @@ class DatabaseManager:
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
             self.cursor.executemany("""
-                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
+                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
             self.conn.commit()
 
-    def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
+    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
-            SELECT s3_path, page_num, start_index, length, finish_reason, error
+            SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY page_num ASC
-        """, (s3_path,))
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
+        """, (pdf_s3_path,))
 
         rows = self.cursor.fetchall()
 
         return [
             self.BatchInferenceRecord(
-                s3_path=row[0],
-                page_num=row[1],
-                start_index=row[2],
-                length=row[3],
-                finish_reason=row[4],
-                error=row[5]
+                inference_s3_path=row[0],
+                pdf_s3_path=row[1],
+                page_num=row[2],
+                start_index=row[3],
+                length=row[4],
+                finish_reason=row[5],
+                error=row[6]
             )
             for row in rows
         ]
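A hypothetical round trip through these two methods (all bucket and key names invented for illustration):

    db = DatabaseManager("s3://my-bucket/workspace")   # hypothetical workspace
    entries = process_jsonl_content("s3://my-bucket/workspace/inference/round_0/out.jsonl")
    db.add_index_entries(entries)

    # every indexed page of one source PDF, newest inference file first,
    # then in byte order within each inference file
    pages = db.get_index_entries("s3://my-bucket/pdfs/paper.pdf")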
@@ -186,6 +189,7 @@ class DatabaseManager:
             SELECT s3_path, num_pages, status
             FROM pdfs
             WHERE status == ?
+            ORDER BY s3_path DESC
         """, (status, ))
 
         rows = self.cursor.fetchall()
@@ -334,8 +338,8 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     page_num = int(custom_id[custom_id.rindex("-") + 1:])
     return s3_path, page_num
 
-def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]:
-    content = get_s3_bytes(s3_path).decode("utf-8")
+def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
+    content = get_s3_bytes(inference_s3_path).decode("utf-8")
 
     start_index = 0
     index_entries = []
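The rename makes the two roles explicit: the argument is the inference JSONL being scanned, while parse_custom_id recovers the PDF path. A worked example of the custom_id split, with a made-up path and assuming the path part is everything before the last dash (only the page_num line is visible above):

    custom_id = "s3://my-bucket/pdfs/paper.pdf-3"          # hypothetical
    s3_path = custom_id[:custom_id.rindex("-")]            # "s3://my-bucket/pdfs/paper.pdf"
    page_num = int(custom_id[custom_id.rindex("-") + 1:])  # 3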
@@ -345,12 +349,13 @@ def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]
 
         try:
             data = json.loads(line)
-            s3_path, page_num = parse_custom_id(data["custom_id"])
+            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])
 
             assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
 
             index_entries.append(DatabaseManager.BatchInferenceRecord(
-                s3_path=s3_path,
+                inference_s3_path=inference_s3_path,
+                pdf_s3_path=pdf_s3_path,
                 page_num=page_num,
                 start_index=start_index,
                 length=line_length,
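The start_index and line_length bookkeeping that feeds these records is only partly visible in this hunk. A minimal sketch of what the surrounding loop presumably does (whether line_length counts the trailing newline depends on how content is split; json.loads tolerates surrounding whitespace either way):

    start_index = 0
    index_entries = []
    for line in content.splitlines(keepends=True):
        line_length = len(line)
        # ... parse the line and append a BatchInferenceRecord as above ...
        start_index += line_length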
@@ -410,17 +415,22 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     db = DatabaseManager(s3_workspace)
     existing_pages = db.get_index_entries(pdf.s3_path)
     document_text = ""
+    last_page_start_index = 0
+    pdf_page_spans = []
 
     for target_page_num in range(1, pdf.num_pages + 1):
         target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)
 
-        target_row = get_s3_bytes(target_page.s3_path,
+        target_row = get_s3_bytes(target_page.inference_s3_path,
                                   start_index=target_page.start_index,
                                   end_index=target_page.start_index+target_page.length)
 
         target_data = json.loads(target_row.decode("utf-8"))
 
-        document_text += target_data + "\n"
+        document_text += target_data["natural_text"] + "\n"
+
+        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
+        last_page_start_index = len(document_text)
 
     metadata = {
         "Source-File": pdf.s3_path,
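The start_index/end_index keywords on get_s3_bytes imply a ranged GET, so each page re-reads only its own slice of the batch JSONL rather than the whole file. A sketch of such a helper, assuming end_index is exclusive and that s3_path is an s3://bucket/key URI (the real helper's signature may differ):

    import boto3

    s3 = boto3.client("s3")

    def get_s3_bytes(s3_path: str, start_index=None, end_index=None) -> bytes:
        bucket, _, key = s3_path[len("s3://"):].partition("/")
        kwargs = {}
        if start_index is not None:
            # an HTTP Range is inclusive on both ends, hence the -1
            kwargs["Range"] = f"bytes={start_index}-{end_index - 1}"
        return s3.get_object(Bucket=bucket, Key=key, **kwargs)["Body"].read()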
@@ -431,10 +441,13 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     dolma_doc = {
         "id": id_,
         "text": document_text,
-        "source": "s2pdf",
+        "source": "pdelfin",
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans
+        }
     }
 
     return dolma_doc
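With the new attributes block, page boundaries stay recoverable from the flat text: each span is [start_char, end_char, page_num] into the document text, so a consumer can slice per page. For example:

    # recover per-page text from a finished Dolma doc
    for start, end, page_num in dolma_doc["attributes"]["pdf_page_numbers"]:
        page_text = dolma_doc["text"][start:end]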
@@ -500,7 +513,7 @@ if __name__ == '__main__':
        future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
        potentially_done_pdfs = []
        lines_written = 0
-       new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
+       new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{db.get_current_round()}", args.max_size_mb)
 
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            pdf = future_to_path[future]
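A note on that last fix: with no local variable named round, the old f-string interpolated Python's builtin round function into the S3 prefix:

    f"round_{round}"                   # -> "round_<built-in function round>"
    f"round_{db.get_current_round()}"  # -> e.g. "round_1"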