From 230c8a9f9a39c96aa9cd33c73b17349bf13eabc0 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 8 Oct 2024 22:10:18 +0000
Subject: [PATCH] Trying new run that will rewrite the prompts as it goes

---
 pdelfin/train/dataloader.py       | 51 +++++++++++++++++++++++--------
 scripts/qwen2vl-7b-lora-gantry.sh |  2 ++
 tests/test_dataloader.py          | 38 ++++++++++++++---------
 3 files changed, 65 insertions(+), 26 deletions(-)

diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index d38c02d..adba62b 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -1,8 +1,8 @@
 import json
 import logging
-import multiprocessing
+import tempfile
 import re
-import random
+import os
 import base64
 import glob
 
@@ -14,6 +14,8 @@
 import boto3
 from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 from .core.config import DataConfig, SourceConfig
+from pdelfin.prompts.anchor import get_anchor_text
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -124,6 +126,8 @@ def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
     return width, height
 
 
+
+
 def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     """
     Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
@@ -153,19 +157,42 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     except IndexError:
         input_prompt_image_base64 = ""
 
-    # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
-    # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
-    # so we're going to extract out just the raw extracted prompt text
-    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
+    # This code builds the finetuning prompt from the original OpenAI prompt by extracting the "pdf_report" hint anchor text
+    # from it and reusing it
+    # # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
+    # # to generate our silver data. But, we want to have a simplified prompt for this fine-tune,
+    # # so we're going to extract out just the raw extracted prompt text
+    # pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
 
-    # Use re.DOTALL to ensure that the dot matches newline characters
-    match = re.search(pattern, input_prompt_text, re.DOTALL)
+    # # Use re.DOTALL to ensure that the dot matches newline characters
+    # match = re.search(pattern, input_prompt_text, re.DOTALL)
 
-    if match:
-        raw_page_text = match.group(1).strip()
-    else:
-        raw_page_text = ""
+    # if match:
+    #     raw_page_text = match.group(1).strip()
+    # else:
+    #     raw_page_text = ""
+
+    # This code builds the finetuning prompt by redownloading the PDF and extracting its report one more time
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
+        aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
+    )
+
+    # Split the s3_path into bucket and key
+    bucket_name = s3_path.split('s3://')[1].split('/')[0]
+    s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])
+
+
+    with tempfile.NamedTemporaryFile(delete=False) as tf:
+        s3_client.download_fileobj(bucket_name, s3_key, tf)
+
+        raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
+
 
     return {
         "custom_id": custom_id,
         "input_prompt_text": input_prompt_text,
diff --git a/scripts/qwen2vl-7b-lora-gantry.sh b/scripts/qwen2vl-7b-lora-gantry.sh
index b8974ea..93b0a6d 100755
--- a/scripts/qwen2vl-7b-lora-gantry.sh
+++ b/scripts/qwen2vl-7b-lora-gantry.sh
@@ -41,6 +41,8 @@ gantry run \
   --env BEAKER_USER_ID=$(beaker account whoami --format json | jq '.[0].name' -cr) \
   --env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
   --env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
+  --env-secret DS_AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
+  --env-secret DS_AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
   --env-secret WANDB_API_KEY=JAKE_WANDB_API_KEY \
   --shared-memory 10GiB \
   --yes \
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 610ec70..6a15823 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -9,7 +9,8 @@ from pdelfin.train.dataloader import (
     build_batch_query_response_vision_dataset,
     extract_openai_batch_query,
     extract_openai_batch_response,
-    load_jsonl_into_ds
+    load_jsonl_into_ds,
+    list_dataset_files
 )
 
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
@@ -25,8 +26,8 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
 
     def testCombinedQueryResponse(self):
         ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json",
+            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )
 
         print(ds)
@@ -115,16 +116,25 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(response_data)
         print(response_data[0])
 
-    def testIterableDataset(self):
-        dataset = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
-        )
-        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+    def testPyArrowDirectJson(self):
+        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl"
+        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json"
+
+        all_files = list_dataset_files(query_glob_path)
 
-        formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
-        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
+        import pyarrow as pa
+        import pyarrow.json as paj
+        import pyarrow.compute as pc
+        import pyarrow.fs as fs
+
+        s3 = fs.S3FileSystem()
+
+        block_size = 200 * 1024**2
+
+        for file in all_files:
+            with s3.open_input_stream(file.replace("s3://", "")) as f:
+                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
+
+            print(f"file {file} rows {table.num_rows}")
+            print(table.schema)
 
-        for entry in formatted_dataset:
-            print(entry)
-            break
\ No newline at end of file
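For reference, a minimal sketch of what the new prompt-building path in extract_openai_batch_query does: it splits the batch custom_id into an S3 path and a page number, downloads that PDF with the DS_AWS_* credentials wired up in the gantry script, and regenerates the page's "pdfreport" anchor text with get_anchor_text. The helper name rebuild_page_anchor_text and the explicit flush() are illustrative additions, not part of the patch.

# Sketch, assuming custom_id has the form "s3://bucket/path/to/file.pdf-<page_num>"
# and that DS_AWS_ACCESS_KEY_ID / DS_AWS_SECRET_ACCESS_KEY are set in the environment.
import os
import tempfile

import boto3

from pdelfin.prompts.anchor import get_anchor_text


def rebuild_page_anchor_text(custom_id: str) -> str:
    # custom_id ends with "-<page_num>"; everything before that is the s3:// path to the PDF
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])

    s3_client = boto3.client(
        "s3",
        aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"),
    )

    # Split "s3://bucket/key" into bucket and key
    bucket_name = s3_path.split("s3://")[1].split("/")[0]
    s3_key = "/".join(s3_path.split("s3://")[1].split("/")[1:])

    # Download to a named temp file so get_anchor_text can reopen it by path;
    # flush() makes sure the downloaded bytes are on disk before the reread.
    with tempfile.NamedTemporaryFile(suffix=".pdf") as tf:
        s3_client.download_fileobj(bucket_name, s3_key, tf)
        tf.flush()
        return get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")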