Pipeline working hopefully soon

Jake Poznanski 2024-10-14 18:19:17 +00:00
parent f2f578cca9
commit cd8e28e459


@@ -27,7 +27,8 @@ s3 = boto3.client('s3')
 class DatabaseManager:
     @dataclass(frozen=True)
     class BatchInferenceRecord:
-        s3_path: str
+        inference_s3_path: str
+        pdf_s3_path: str
         page_num: int  # 1 indexed!
         start_index: int
         length: int
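Taken together with the INSERT/SELECT changes below, the record presumably now reads as follows. This is a reference sketch, not the file's code; the types of finish_reason and error are not shown in this diff, so Optional[str] for error is an assumption:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class BatchInferenceRecord:
    inference_s3_path: str  # JSONL file holding the batch-inference output
    pdf_s3_path: str        # source PDF this page came from
    page_num: int           # 1 indexed!
    start_index: int        # offset of this page's row within the JSONL file
    length: int             # length of that row
    finish_reason: str      # assumption: finish reason reported by the model
    error: Optional[str]    # assumption: error message, None on success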
@@ -56,7 +57,8 @@ class DatabaseManager:
     def _initialize_tables(self):
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS page_results (
-                s3_path TEXT,
+                inference_s3_path TEXT,
+                pdf_s3_path TEXT,
                 page_num INTEGER,
                 start_index BIGINT,
                 length BIGINT,
@@ -115,29 +117,30 @@ class DatabaseManager:
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
             self.cursor.executemany("""
-                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
+                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
 
             self.conn.commit()
 
     def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
-            SELECT s3_path, page_num, start_index, length, finish_reason, error
+            SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY page_num ASC
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
         """, (s3_path,))
 
         rows = self.cursor.fetchall()
 
         return [
             self.BatchInferenceRecord(
-                s3_path=row[0],
-                page_num=row[1],
-                start_index=row[2],
-                length=row[3],
-                finish_reason=row[4],
-                error=row[5]
+                inference_s3_path=row[0],
+                pdf_s3_path=row[1],
+                page_num=row[2],
+                start_index=row[3],
+                length=row[4],
+                finish_reason=row[5],
+                error=row[6]
             )
             for row in rows
         ]
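The new three-key sort is what makes the downstream reads cheap: entries for the same inference file come back grouped, with byte offsets ascending, so a consumer can fetch them with sequential ranged reads. A minimal sketch of that access pattern, assuming the get_s3_bytes(path, start_index=..., end_index=...) helper used later in this file:

import json

def read_inference_row(entry):  # hypothetical consumer of one BatchInferenceRecord
    # Fetch exactly the bytes of this page's JSONL row, then parse it
    raw = get_s3_bytes(entry.inference_s3_path,
                       start_index=entry.start_index,
                       end_index=entry.start_index + entry.length)
    return json.loads(raw.decode("utf-8"))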
@@ -186,6 +189,7 @@ class DatabaseManager:
             SELECT s3_path, num_pages, status
             FROM pdfs
             WHERE status == ?
+            ORDER BY s3_path DESC
         """, (status, ))
 
         rows = self.cursor.fetchall()
@@ -334,8 +338,8 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     page_num = int(custom_id[custom_id.rindex("-") + 1:])
     return s3_path, page_num
 
-def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]:
-    content = get_s3_bytes(s3_path).decode("utf-8")
+def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
+    content = get_s3_bytes(inference_s3_path).decode("utf-8")
 
     start_index = 0
     index_entries = []
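From the rindex("-") slicing above, a custom_id is evidently the PDF's S3 path and a 1-indexed page number joined by a final dash, e.g. "s3://bucket/doc.pdf-3" (the example path is illustrative). A minimal sketch of the full parser under that assumption:

from typing import Tuple

def parse_custom_id(custom_id: str) -> Tuple[str, int]:
    # Split on the LAST dash, since the S3 path itself may contain dashes
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num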
@@ -345,12 +349,13 @@ def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]
         try:
             data = json.loads(line)
-            s3_path, page_num = parse_custom_id(data["custom_id"])
+            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])
 
             assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
 
             index_entries.append(DatabaseManager.BatchInferenceRecord(
-                s3_path=s3_path,
+                inference_s3_path=inference_s3_path,
+                pdf_s3_path=pdf_s3_path,
                 page_num=page_num,
                 start_index=start_index,
                 length=line_length,
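The start_index/length bookkeeping feeding these records has to line up exactly with the ranged reads in build_dolma_doc. A minimal sketch of that bookkeeping, assuming each record covers one JSONL line plus its trailing newline; note that len() counts decoded characters, so this only matches byte offsets for ASCII content:

# content comes from get_s3_bytes(inference_s3_path).decode("utf-8"), as above
start_index = 0
for line in content.split("\n"):
    line_length = len(line) + 1  # +1 for the "\n" that split() consumed
    # ... parse `line`, emit a record covering the half-open range
    # [start_index, start_index + line_length) ...
    start_index += line_length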
@@ -410,17 +415,22 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     db = DatabaseManager(s3_workspace)
     existing_pages = db.get_index_entries(pdf.s3_path)
 
     document_text = ""
+    last_page_start_index = 0
+    pdf_page_spans = []
 
     for target_page_num in range(1, pdf.num_pages + 1):
         target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)
 
-        target_row = get_s3_bytes(target_page.s3_path,
-                                  start_index=target_page.start_index,
-                                  end_index=target_page.start_index+target_page.length)
+        target_row = get_s3_bytes(target_page.inference_s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length)
 
         target_data = json.loads(target_row.decode("utf-8"))
-        document_text += target_data + "\n"
+        document_text += target_data["natural_text"] + "\n"
+        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
+        last_page_start_index = len(document_text)
 
     metadata = {
         "Source-File": pdf.s3_path,
@@ -431,10 +441,13 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     dolma_doc = {
         "id": id_,
         "text": document_text,
-        "source": "s2pdf",
+        "source": "pdelfin",
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans
+        }
     }
 
     return dolma_doc
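Each entry of pdf_page_spans is a [start, end, page_num] triple of offsets into document_text, so a page's text can be recovered by slicing. A small usage sketch under that reading:

# Recover per-page text from a finished dolma_doc
for start, end, page_num in dolma_doc["attributes"]["pdf_page_numbers"]:
    page_text = dolma_doc["text"][start:end]
    print(f"page {page_num}: {len(page_text)} characters")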
@@ -500,7 +513,7 @@ if __name__ == '__main__':
     future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
 
     potentially_done_pdfs = []
     lines_written = 0
-    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
+    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{db.get_current_round()}", args.max_size_mb)
 
     for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
         pdf = future_to_path[future]
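The replaced f-string is worth flagging: unless a local variable named round happened to be in scope, {round} interpolated Python's builtin function, so every round would have written to the same oddly named prefix, which is presumably why it now calls db.get_current_round():

>>> f"inference/round_{round}"
'inference/round_<built-in function round>'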