Mirror of https://github.com/allenai/olmocr.git (synced 2025-06-27 04:00:02 +00:00)

Trying new run that will rewrite the prompts as it goes

This commit is contained in:
parent 97291b3f6a
commit 230c8a9f9a
@@ -1,8 +1,8 @@
import json
import logging
import multiprocessing
import tempfile
import re
import random
import os
import base64
import glob
@@ -14,6 +14,8 @@ import boto3
 from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 from .core.config import DataConfig, SourceConfig

+from pdelfin.prompts.anchor import get_anchor_text
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -124,6 +126,8 @@ def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
     return width, height


+
+
 def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     """
     Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
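Note: the hunk context above references get_png_dimensions_from_base64, whose body is not shown in this diff. Purely as an illustrative sketch (the repository's implementation may differ), a helper of that shape can read the dimensions straight out of the PNG header without decoding the whole image:

import base64
import struct

def png_dimensions_from_base64(base64_data: str) -> tuple[int, int]:
    # Illustrative sketch, not the repo's code: a PNG starts with an 8-byte
    # signature, then the IHDR chunk (4-byte length, 4-byte type, width and
    # height as big-endian uint32), so the first 24 bytes are enough.
    header = base64.b64decode(base64_data[:32])  # 32 base64 chars -> 24 bytes
    if header[:8] != b"\x89PNG\r\n\x1a\n":
        raise ValueError("not a PNG image")
    width, height = struct.unpack(">II", header[16:24])
    return width, height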
@@ -153,19 +157,42 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     except IndexError:
         input_prompt_image_base64 = ""

-    # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
-    # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
-    # so we're going to extract out just the raw extracted prompt text
-    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
+    # This code builds the finetuning prompt from the original openai prompt by extracting the "pdf_report hint anchor text" from that
+    # and reusing it
+    # # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
+    # # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
+    # # so we're going to extract out just the raw extracted prompt text
+    # pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

-    # Use re.DOTALL to ensure that the dot matches newline characters
-    match = re.search(pattern, input_prompt_text, re.DOTALL)
+    # # Use re.DOTALL to ensure that the dot matches newline characters
+    # match = re.search(pattern, input_prompt_text, re.DOTALL)

-    if match:
-        raw_page_text = match.group(1).strip()
-    else:
-        raw_page_text = ""
+    # if match:
+    #     raw_page_text = match.group(1).strip()
+    # else:
+    #     raw_page_text = ""
+
+
+    # This code builds the finetuning prompt by redownloading the PDF and extracting it's report one more time
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
+        aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
+    )
+
+    # Split the s3_path into bucket and key
+    bucket_name = s3_path.split('s3://')[1].split('/')[0]
+    s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])
+
+
+    with tempfile.NamedTemporaryFile(delete=False) as tf:
+        s3_client.download_fileobj(bucket_name, s3_key, tf)
+
+        raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")

     return {
         "custom_id": custom_id,
         "input_prompt_text": input_prompt_text,
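The replacement branch above rebuilds the fine-tuning prompt by recovering the PDF location from the batch custom_id (everything before the last "-" is the s3 path, the suffix is the page number), downloading the file, and re-running the pdf_report anchor extraction. A minimal standalone sketch of that flow, assuming the same custom_id convention and the get_anchor_text helper imported earlier (the function name here is hypothetical):

import os
import tempfile

import boto3

from pdelfin.prompts.anchor import get_anchor_text


def rebuild_anchor_text(custom_id: str) -> str:
    # custom_id is assumed to look like "s3://bucket/path/doc.pdf-<page>";
    # everything before the last "-" is the PDF path, the suffix is the page.
    s3_path, _, page_str = custom_id.rpartition("-")
    page_num = int(page_str)

    # Credentials come from the DS_* variables forwarded by the gantry script.
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"),
    )

    bucket_name, _, s3_key = s3_path.removeprefix("s3://").partition("/")

    # Download to a temp file kept on disk (delete=False) so get_anchor_text
    # can reopen it by name, then regenerate the "pdfreport" anchor text.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tf:
        s3_client.download_fileobj(bucket_name, s3_key, tf)
        tf.flush()

    return get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")

Keeping the temporary file on disk with delete=False is what lets the anchor extractor reopen it by name after the download completes.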
@@ -41,6 +41,8 @@ gantry run \
   --env BEAKER_USER_ID=$(beaker account whoami --format json | jq '.[0].name' -cr) \
   --env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
   --env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
+  --env-secret DS_AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
+  --env-secret DS_AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
   --env-secret WANDB_API_KEY=JAKE_WANDB_API_KEY \
   --shared-memory 10GiB \
   --yes \
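The two added flags expose the same S2 Beaker secrets inside the job under the DS_AWS_* names that the updated dataloader reads via os.getenv. A quick in-job sanity check (illustrative only, not part of the commit) could be:

import os

# Verify the secrets forwarded by the --env-secret flags above are visible
# inside the job before the dataloader tries to build its boto3 client.
for var in ("DS_AWS_ACCESS_KEY_ID", "DS_AWS_SECRET_ACCESS_KEY"):
    assert os.getenv(var), f"{var} is not set; check the gantry --env-secret flags"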
@@ -9,7 +9,8 @@ from pdelfin.train.dataloader import (
     build_batch_query_response_vision_dataset,
     extract_openai_batch_query,
     extract_openai_batch_response,
-    load_jsonl_into_ds
+    load_jsonl_into_ds,
+    list_dataset_files
 )

 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
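list_dataset_files is newly imported here; its implementation is not part of this diff. Purely as a sketch of the role it plays (expanding an s3:// glob into concrete file paths), something along these lines would work, though the real helper may behave differently:

import fnmatch

import boto3


def list_s3_glob(glob_path: str) -> list[str]:
    # Hypothetical stand-in for list_dataset_files: expand a pattern like
    # "s3://bucket/prefix/*.jsonl" into full s3:// object paths.
    bucket, _, key_glob = glob_path.removeprefix("s3://").partition("/")
    prefix = key_glob.split("*")[0]

    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")

    matches = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if fnmatch.fnmatch(obj["Key"], key_glob):
                matches.append(f"s3://{bucket}/{obj['Key']}")
    return matches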
@@ -25,8 +26,8 @@ class TestBatchQueryResponseDataset(unittest.TestCase):

     def testCombinedQueryResponse(self):
         ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json",
+            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )

         print(ds)
@@ -115,16 +116,25 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(response_data)
         print(response_data[0])

-    def testIterableDataset(self):
-        dataset = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
-        )
-        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+    def testPyArrowDirectJson(self):
+        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl"
+        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json"
+
+        all_files = list_dataset_files(query_glob_path)

-        formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
-        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
+        import pyarrow as pa
+        import pyarrow.json as paj
+        import pyarrow.compute as pc
+        import pyarrow.fs as fs
+
+        s3 = fs.S3FileSystem()
+
+        block_size = 200 * 1024**2
+
+        for file in all_files:
+            with s3.open_input_stream(file.replace("s3://", "")) as f:
+                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
+
+                print(f"file {file} rows {table.num_rows}")
+                print(table.schema)

-        for entry in formatted_dataset:
-            print(entry)
-            break
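The new testPyArrowDirectJson reads each JSONL shard straight into an Arrow table rather than going through the datasets/map pipeline, with a 200 MiB block_size presumably so that very long lines (base64-encoded page images) fit inside a single parse block. The same read pattern can be tried locally without S3; a self-contained sketch with made-up file contents:

import json
import tempfile

import pyarrow.json as paj

# Write a tiny JSONL file so the example runs without S3 access.
rows = [{"custom_id": f"doc-{i}", "text": "hello " * i} for i in range(3)]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
    path = f.name

# block_size bounds how many bytes pyarrow parses per chunk; a JSON line longer
# than one block triggers a "straddling object" error, hence the large value
# used in the test above.
table = paj.read_json(path, read_options=paj.ReadOptions(use_threads=False, block_size=1 << 20))
print(table.num_rows, table.schema)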