Mirror of https://github.com/allenai/olmocr.git, synced 2025-06-27 04:00:02 +00:00

Commit 230c8a9f9a (parent 97291b3f6a): Trying new run that will rewrite the prompts as it goes
@@ -1,8 +1,8 @@
 import json
 import logging
-import multiprocessing
+import tempfile
 import re
-import random
+import os
 import base64
 import glob
 
@@ -14,6 +14,8 @@ import boto3
 from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 from .core.config import DataConfig, SourceConfig
 
+from pdelfin.prompts.anchor import get_anchor_text
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -124,6 +126,8 @@ def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
     return width, height
 
 
+
+
 def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     """
     Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
@@ -153,19 +157,42 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
     except IndexError:
         input_prompt_image_base64 = ""
 
-    # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
-    # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
-    # so we're going to extract out just the raw extracted prompt text
-    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
-
-    # Use re.DOTALL to ensure that the dot matches newline characters
-    match = re.search(pattern, input_prompt_text, re.DOTALL)
-
-    if match:
-        raw_page_text = match.group(1).strip()
-    else:
-        raw_page_text = ""
-
+    # This code builds the finetuning prompt from the original openai prompt by extracting the "pdf_report hint anchor text" from that
+    # and reusing it
+    # # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
+    # # to generate our silver data. But, we want to have a simplfied prompt for this here fine tune,
+    # # so we're going to extract out just the raw extracted prompt text
+    # pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
+
+    # # Use re.DOTALL to ensure that the dot matches newline characters
+    # match = re.search(pattern, input_prompt_text, re.DOTALL)
+
+    # if match:
+    #     raw_page_text = match.group(1).strip()
+    # else:
+    #     raw_page_text = ""
+
+
+    # This code builds the finetuning prompt by redownloading the PDF and extracting it's report one more time
+    s3_path = custom_id[:custom_id.rindex("-")]
+    page_num = int(custom_id[custom_id.rindex("-") + 1:])
+
+    s3_client = boto3.client(
+        's3',
+        aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
+        aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
+    )
+
+    # Split the s3_path into bucket and key
+    bucket_name = s3_path.split('s3://')[1].split('/')[0]
+    s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])
+
+
+    with tempfile.NamedTemporaryFile(delete=False) as tf:
+        s3_client.download_fileobj(bucket_name, s3_key, tf)
+
+        raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
+
     return {
         "custom_id": custom_id,
         "input_prompt_text": input_prompt_text,
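For reference, the new code path in this hunk splits each batch custom_id of the form "<s3 pdf path>-<page number>" at its last hyphen, downloads that PDF with boto3, and regenerates the pdf_report anchor text via get_anchor_text. A minimal standalone sketch of the parsing step follows; the ID, bucket, and key are hypothetical placeholders, not real dataset paths:

# Sketch of the custom_id parsing used above; the ID below is made up for illustration.
custom_id = "s3://example-bucket/pdfs/document.pdf-3"

s3_path = custom_id[:custom_id.rindex("-")]             # "s3://example-bucket/pdfs/document.pdf"
page_num = int(custom_id[custom_id.rindex("-") + 1:])   # 3

# Split the S3 URI into the bucket name and object key expected by boto3.
bucket_name = s3_path.split("s3://")[1].split("/")[0]         # "example-bucket"
s3_key = "/".join(s3_path.split("s3://")[1].split("/")[1:])   # "pdfs/document.pdf"

print(bucket_name, s3_key, page_num)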
@@ -41,6 +41,8 @@ gantry run \
   --env BEAKER_USER_ID=$(beaker account whoami --format json | jq '.[0].name' -cr) \
   --env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
   --env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
+  --env-secret DS_AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
+  --env-secret DS_AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
   --env-secret WANDB_API_KEY=JAKE_WANDB_API_KEY \
   --shared-memory 10GiB \
   --yes \
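The two added --env-secret lines expose the dataset-bucket credentials to the job under the DS_AWS_* names that the dataloader reads at runtime. A short sketch of how those variables are consumed, mirroring the boto3 client construction in the dataloader hunk above:

import os
import boto3

# The DS_AWS_* variables injected by gantry are read with os.getenv
# when building the S3 client used to re-download source PDFs.
s3_client = boto3.client(
    's3',
    aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
    aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY'),
)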
@@ -9,7 +9,8 @@ from pdelfin.train.dataloader import (
     build_batch_query_response_vision_dataset,
     extract_openai_batch_query,
     extract_openai_batch_response,
-    load_jsonl_into_ds
+    load_jsonl_into_ds,
+    list_dataset_files
 )
 
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
@@ -25,8 +26,8 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
 
     def testCombinedQueryResponse(self):
         ds = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json",
+            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
+            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
         )
 
         print(ds)
@@ -115,16 +116,25 @@ class TestBatchQueryResponseDataset(unittest.TestCase):
         print(response_data)
         print(response_data[0])
 
-    def testIterableDataset(self):
-        dataset = build_batch_query_response_vision_dataset(
-            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
-            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
-        )
-        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-
-        formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
-        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
-
-        for entry in formatted_dataset:
-            print(entry)
-            break
+    def testPyArrowDirectJson(self):
+        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl"
+        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json"
+
+        all_files = list_dataset_files(query_glob_path)
+
+        import pyarrow as pa
+        import pyarrow.json as paj
+        import pyarrow.compute as pc
+        import pyarrow.fs as fs
+
+        s3 = fs.S3FileSystem()
+
+        block_size = 200 * 1024**2
+
+        for file in all_files:
+            with s3.open_input_stream(file.replace("s3://", "")) as f:
+                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))
+
+                print(f"file {file} rows {table.num_rows}")
+                print(table.schema)
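The rewritten test reads the batch JSONL files directly with PyArrow instead of building a Hugging Face dataset first. A condensed standalone version of that pattern is sketched below; the S3 path is a placeholder, and the large block_size presumably keeps very long JSON lines (rows that embed base64 page images) from being split across read blocks:

import pyarrow.json as paj
import pyarrow.fs as fs

# Read one newline-delimited JSON file from S3 into an Arrow table.
# PyArrow's S3 filesystem takes "bucket/key" paths without the s3:// scheme.
s3 = fs.S3FileSystem()
block_size = 200 * 1024**2  # generous block size for very long JSON lines

with s3.open_input_stream("example-bucket/batch_data/example.jsonl") as f:  # placeholder path
    table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))

print(table.num_rows)
print(table.schema)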