Pipeline hopefully working soon

Jake Poznanski 2024-10-14 18:19:17 +00:00
parent f2f578cca9
commit cd8e28e459


@@ -27,7 +27,8 @@ s3 = boto3.client('s3')
 class DatabaseManager:
     @dataclass(frozen=True)
     class BatchInferenceRecord:
-        s3_path: str
+        inference_s3_path: str
+        pdf_s3_path: str
         page_num: int  # 1 indexed!
         start_index: int
         length: int
@@ -56,7 +57,8 @@ class DatabaseManager:
     def _initialize_tables(self):
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS page_results (
-                s3_path TEXT,
+                inference_s3_path TEXT,
+                pdf_s3_path TEXT,
                 page_num INTEGER,
                 start_index BIGINT,
                 length BIGINT,
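Taken together, the two hunks above mean each index entry now records both where the model's output lives and which PDF it describes. A minimal sketch of one populated record; the field layout comes from this diff, but every value shown is hypothetical:

    entry = DatabaseManager.BatchInferenceRecord(
        inference_s3_path="s3://bucket/workspace/inference/round_0/out_0.jsonl",  # batch-inference JSONL (hypothetical path)
        pdf_s3_path="s3://bucket/pdfs/report.pdf",  # source PDF named in the custom_id (hypothetical path)
        page_num=1,        # 1 indexed!
        start_index=0,     # byte offset of this page's line inside the JSONL
        length=512,        # byte length of that line
        finish_reason="stop",
        error=None,
    )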
@@ -115,29 +117,30 @@ class DatabaseManager:
     def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
         if index_entries:
             self.cursor.executemany("""
-                INSERT INTO page_results (s3_path, page_num, start_index, length, finish_reason, error)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """, [(entry.s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
+                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
             self.conn.commit()

     def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
-            SELECT s3_path, page_num, start_index, length, finish_reason, error
+            SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY page_num ASC
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
         """, (s3_path,))

         rows = self.cursor.fetchall()

         return [
             self.BatchInferenceRecord(
-                s3_path=row[0],
-                page_num=row[1],
-                start_index=row[2],
-                length=row[3],
-                finish_reason=row[4],
-                error=row[5]
+                inference_s3_path=row[0],
+                pdf_s3_path=row[1],
+                page_num=row[2],
+                start_index=row[3],
+                length=row[4],
+                finish_reason=row[5],
+                error=row[6]
             )
             for row in rows
         ]
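Assuming inference files are named so that later rounds sort lexicographically higher (e.g. round_1 after round_0), the DESC ordering above surfaces the newest inference output for each page first. A sketch of a consumer that relies on that ordering; the helper name is hypothetical:

    def latest_usable_pages(entries):
        # Entries arrive newest-inference-file first (see the ORDER BY
        # above), so the first usable record per page wins.
        best = {}
        for e in entries:
            if e.is_usable() and e.page_num not in best:
                best[e.page_num] = e
        return best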
@@ -186,6 +189,7 @@ class DatabaseManager:
             SELECT s3_path, num_pages, status
             FROM pdfs
             WHERE status == ?
+            ORDER BY s3_path DESC
         """, (status, ))

         rows = self.cursor.fetchall()
@@ -334,8 +338,8 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     page_num = int(custom_id[custom_id.rindex("-") + 1:])
     return s3_path, page_num

-def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]:
-    content = get_s3_bytes(s3_path).decode("utf-8")
+def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
+    content = get_s3_bytes(inference_s3_path).decode("utf-8")

     start_index = 0
     index_entries = []
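parse_custom_id splits on the last hyphen, so a custom_id of the form "<pdf_s3_path>-<page_num>" round-trips as below. The id shown is illustrative, and this assumes the path portion is everything before that final hyphen:

    custom_id = "s3://bucket/pdfs/report.pdf-4"   # hypothetical id
    pdf_s3_path, page_num = parse_custom_id(custom_id)
    assert pdf_s3_path == "s3://bucket/pdfs/report.pdf"
    assert page_num == 4   # int(custom_id[custom_id.rindex("-") + 1:])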
@@ -345,12 +349,13 @@ def process_jsonl_content(s3_path) -> List[DatabaseManager.BatchInferenceRecord]:
         try:
             data = json.loads(line)
-            s3_path, page_num = parse_custom_id(data["custom_id"])
+            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

             assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

             index_entries.append(DatabaseManager.BatchInferenceRecord(
-                s3_path=s3_path,
+                inference_s3_path=inference_s3_path,
+                pdf_s3_path=pdf_s3_path,
                 page_num=page_num,
                 start_index=start_index,
                 length=line_length,
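The (start_index, length) pair recorded here is what later makes a single-page fetch possible: each record delimits one line's bytes inside the inference JSONL. A sketch of the invariant, assuming offsets and lengths are counted in encoded bytes:

    raw = content.encode("utf-8")
    for e in index_entries:
        # Each slice should decode back to exactly one JSONL record.
        line_bytes = raw[e.start_index : e.start_index + e.length]
        json.loads(line_bytes.decode("utf-8"))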
@@ -410,17 +415,22 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     db = DatabaseManager(s3_workspace)
     existing_pages = db.get_index_entries(pdf.s3_path)
     document_text = ""
+    last_page_start_index = 0
+    pdf_page_spans = []

     for target_page_num in range(1, pdf.num_pages + 1):
         target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)

-        target_row = get_s3_bytes(target_page.s3_path,
+        target_row = get_s3_bytes(target_page.inference_s3_path,
                                   start_index=target_page.start_index,
                                   end_index=target_page.start_index+target_page.length)

         target_data = json.loads(target_row.decode("utf-8"))

-        document_text += target_data + "\n"
+        document_text += target_data["natural_text"] + "\n"
+        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
+        last_page_start_index = len(document_text)

     metadata = {
         "Source-File": pdf.s3_path,
@@ -431,10 +441,13 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
     dolma_doc = {
         "id": id_,
         "text": document_text,
-        "source": "s2pdf",
+        "source": "pdelfin",
         "added": datetime.datetime.now().strftime("%Y-%m-%d"),
         "created": datetime.datetime.now().strftime("%Y-%m-%d"),
         "metadata": metadata,
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans
+        }
     }

     return dolma_doc
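With the source rename and the new attributes key, an emitted document looks roughly like this (all values illustrative; the two 14-character pages give the spans shown):

    {
        "id": "...",
        "text": "page one text\npage two text\n",
        "source": "pdelfin",
        "added": "2024-10-14",
        "created": "2024-10-14",
        "metadata": {"Source-File": "s3://bucket/pdfs/report.pdf"},
        "attributes": {"pdf_page_numbers": [[0, 14, 1], [14, 28, 2]]},
    }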
@@ -500,7 +513,7 @@ if __name__ == '__main__':
         future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
         potentially_done_pdfs = []
         lines_written = 0
-        new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{round}", args.max_size_mb)
+        new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{db.get_current_round()}", args.max_size_mb)

         for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
             pdf = future_to_path[future]
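The writer change fixes a subtle path bug: assuming no local variable named round was in scope, the old f-string interpolated Python's builtin round function rather than a round number, while the new code asks the database for the current round (the return value below is hypothetical):

    >>> f"{args.workspace}/inference/round_{round}"               # old behavior
    's3://bucket/workspace/inference/round_<built-in function round>'
    >>> f"{args.workspace}/inference/round_{db.get_current_round()}"  # new behavior
    's3://bucket/workspace/inference/round_0'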