"""Build Hugging Face datasets from OpenAI batch-API query/response JSONL files."""

import json
import logging
import tempfile
import re
import os
import base64
import glob
from functools import partial
from typing import Any, Dict, Optional
from logging import Logger

import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig
from pdelfin.prompts.anchor import get_anchor_text

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Splits "s3://<bucket>/<key>" into (bucket, key); the key part must be non-empty.
_S3_PATH_RE = re.compile(r"s3://([^/]+)/(.+)")

# The first 8 base64 characters encode the first 6 bytes of the 8-byte PNG signature,
# which is enough to recognise a PNG without decoding anything.
_PNG_SIGNATURE_B64_PREFIX = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii")[:8]

# Byte offsets of the big-endian width/height fields inside a PNG file.
_PNG_WIDTH_START = 16
_PNG_HEIGHT_END = 24


def list_dataset_files(s3_glob_path: str):
    """
    Lists files that match a glob pattern, either in S3 or on the local filesystem.

    For ``s3://`` paths only ``*`` is treated as a wildcard (it matches any run of
    characters, including ``/``); every other character of the pattern is matched
    literally. Any other path is expanded with ``glob.glob``.

    Parameters:
    - s3_glob_path (str): ``s3://bucket/prefix*pattern`` or a local glob pattern.

    Returns:
    - list[str]: Matching ``s3://bucket/key`` URIs in the order S3 lists them, or the
      sorted list of matching local paths.

    Raises:
    - ValueError: If the path starts with ``s3://`` but has no bucket or key part.
    """
    if not s3_glob_path.startswith("s3://"):
        # glob.glob returns entries in arbitrary order; sort so that callers taking
        # "the first N files" get a reproducible selection.
        return sorted(glob.glob(s3_glob_path))

    match = _S3_PATH_RE.match(s3_glob_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_glob_path}")
        raise ValueError(f"Invalid S3 path: {s3_glob_path}")

    bucket, key_pattern = match.groups()
    literal_parts = key_pattern.split("*")

    # Everything before the first wildcard can be used as a server-side prefix filter.
    prefix = literal_parts[0]

    # Escape the literal pieces so characters such as "." or "+" in the key are not
    # interpreted as regex syntax; only "*" becomes ".*".
    key_regex = re.compile(".*".join(re.escape(part) for part in literal_parts))

    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")

    files = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key_regex.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")

    return files


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads the JSONL files matching a glob pattern with the Hugging Face ``json`` loader.

    Parameters:
    - s3_glob_path (str): S3 or local glob pattern, resolved by ``list_dataset_files``.
    - first_n_files (int, optional): If truthy, only the first N matching files are
      loaded. ``None`` (or 0) loads every match.

    Returns:
    - DatasetDict: The result of ``load_dataset`` called without a split; callers in
      this module read its ``"train"`` entry.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use datasets library to load JSON files from S3
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset


def get_png_dimensions_from_base64(base64_data: str) -> tuple[int, int]:
    """
    Returns the (width, height) of a PNG image given its base64-encoded data,
    without base64-decoding the entire data or loading the PNG itself.

    Only the handful of base64 characters that cover the width and height fields are
    decoded, so this is fast enough to be used as a dataset filter.

    Parameters:
    - base64_data (str): Base64-encoded PNG image data.

    Returns:
    - tuple: (width, height) of the image.

    Raises:
    - ValueError: If the data is not a valid PNG image or the required bytes are not found.
    """
    if not base64_data.startswith(_PNG_SIGNATURE_B64_PREFIX):
        raise ValueError('Not a valid PNG file')

    # Each group of 3 bytes corresponds to 4 base64 characters, so round the start
    # down and the end up to whole groups.
    base64_start = (_PNG_WIDTH_START // 3) * 4
    base64_end = ((_PNG_HEIGHT_END + 2) // 3) * 4  # Add 2 to ensure we cover partial groups

    # Decode only the necessary bytes (binascii.Error from a truncated group is a ValueError)
    decoded_bytes = base64.b64decode(base64_data[base64_start:base64_end])

    # The decoded window starts on a 3-byte boundary, which may be before the width field
    offset = _PNG_WIDTH_START % 3

    width_bytes = decoded_bytes[offset:offset + 4]
    height_bytes = decoded_bytes[offset + 4:offset + 8]

    if len(width_bytes) < 4 or len(height_bytes) < 4:
        raise ValueError('Insufficient data to extract dimensions')

    width = int.from_bytes(width_bytes, 'big')
    height = int.from_bytes(height_bytes, 'big')

    return width, height


def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts necessary fields from a query entry passed to openai's batch API for vision LMs.

    The ``custom_id`` is expected to have the form ``<s3 path of the PDF>-<page number>``.
    The PDF is downloaded from S3 (credentials come from the ``DS_AWS_ACCESS_KEY_ID`` and
    ``DS_AWS_SECRET_ACCESS_KEY`` environment variables) and its anchor text for that page
    is regenerated with ``get_anchor_text``.

    Parameters:
    - query (dict): One row of the batch query file, with ``custom_id`` and ``body`` keys.

    Returns:
    - dict: ``custom_id``, ``input_prompt_text``, ``input_prompt_image_base64`` and
      ``raw_page_text``.

    Raises:
    - ValueError: If ``custom_id`` does not have the expected form or does not start with
      a valid ``s3://bucket/key`` path.
    """
    custom_id = query.get("custom_id", "")
    body = query.get("body", {})
    messages = body.get("messages", [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get("role") != "user":
            continue  # We are only interested in user messages

        contents = message.get("content", [])
        for content_item in contents:
            item_type = content_item.get("type")
            if item_type == "text":
                input_prompt_text = content_item.get("text", "")
            elif item_type == "image_url":
                # Tolerate a null "image_url"/"url" entry instead of failing on .get/.startswith
                image_url = (content_item.get("image_url") or {}).get("url") or ""
                if image_url.startswith("data:image"):
                    # Extract base64 part from data URL; a URL without a comma has no payload
                    _, comma, payload = image_url.partition(",")
                    input_prompt_image_base64 = payload if comma else ""

    # At this point, input_prompt_text is the raw text that was passed to the OpenAI model
    # to generate our silver data. An alternative way to build the simplified fine-tuning
    # prompt would be to pull the anchor text back out of that prompt (the section between
    # the RAW_TEXT_START and RAW_TEXT_END markers). Instead, this code builds the
    # fine-tuning prompt by re-downloading the PDF and extracting its report one more time.
    s3_path, dash, page_str = custom_id.rpartition("-")
    if not dash:
        raise ValueError(f"custom_id {custom_id!r} is not of the form '<s3 path>-<page number>'")
    page_num = int(page_str)

    # Split the s3_path into bucket and key
    path_match = _S3_PATH_RE.match(s3_path)
    if not path_match:
        raise ValueError(f"Invalid S3 path in custom_id: {s3_path}")
    bucket_name, s3_key = path_match.groups()

    s3_client = boto3.client(
        's3',
        aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
    )

    # delete=False lets the file be reopened by name on every platform; in exchange we
    # are responsible for removing it, which the finally clause guarantees.
    tf = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tf:
            s3_client.download_fileobj(bucket_name, s3_key, tf)
        # The handle is closed (and therefore fully flushed) before the PDF is read by name.
        raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
    finally:
        os.unlink(tf.name)

    return {
        "custom_id": custom_id,
        "input_prompt_text": input_prompt_text,
        "input_prompt_image_base64": input_prompt_image_base64,
        "raw_page_text": raw_page_text,
    }


def extract_openai_batch_response(example):
    """
    Reduces one row of an OpenAI batch response file to the fields used downstream.

    Parameters:
    - example (dict): A response row with ``custom_id`` and ``response.body.choices``.

    Returns:
    - dict: ``custom_id``, plus the ``response`` text and ``finish_reason`` of the first
      choice. Both are empty strings when the row has no choices.
    """
    custom_id = example.get("custom_id", None)

    # "or {}" / "or []" also cover keys that are present but null.
    response_body = (example.get("response") or {}).get("body") or {}
    choices = response_body.get("choices") or []

    response = ""
    finish_reason = ""

    if choices:
        first_choice = choices[0]
        message = first_choice.get("message") or {}
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    return {"custom_id": custom_id, "response": response, "finish_reason": finish_reason}


def merge_query_response(query_example, response_data: Dataset, response_map: dict[str, int]):
    """
    Looks up the response that belongs to a query row.

    Parameters:
    - query_example (dict): A query row containing ``custom_id``.
    - response_data (Dataset): Rows produced by ``extract_openai_batch_response``.
    - response_map (dict): Maps ``custom_id`` to a row index in ``response_data``.

    Returns:
    - dict: ``response`` and ``finish_reason`` of the matching row, or ``None`` for both
      when the query has no response.
    """
    custom_id = query_example["custom_id"]

    if custom_id not in response_map:
        return {
            "response": None,
            "finish_reason": None,
        }

    response_row = response_data[response_map[custom_id]]

    return {"response": response_row["response"], "finish_reason": response_row["finish_reason"]}


def _has_reasonable_image_size(example, min_dim: int, max_dim: int) -> bool:
    """Return True when the longest side of the row's PNG lies in [min_dim, max_dim]."""
    try:
        width, height = get_png_dimensions_from_base64(example["input_prompt_image_base64"])
    except ValueError:
        # A missing or non-PNG image cannot satisfy the size constraint; drop the row
        # rather than aborting the whole dataset build.
        logger.warning("Dropping %s: no readable PNG image in the query", example.get("custom_id"))
        return False

    return min_dim <= max(width, height) <= max_dim


def build_batch_query_response_vision_dataset(
    query_glob_path: str,
    response_glob_path: str,
    num_proc: int = 32,
    min_image_dim: int = 1800,
    max_image_dim: int = 2200,
) -> Dataset:
    """
    Joins OpenAI batch queries with their responses into one filtered dataset.

    Parameters:
    - query_glob_path (str): Glob pattern (S3 or local) of the batch query JSONL files.
    - response_glob_path (str): Glob pattern (S3 or local) of the batch response JSONL files.
    - num_proc (int): Number of worker processes used for the map/filter steps.
    - min_image_dim (int): Smallest accepted value for the longest side of the page image.
    - max_image_dim (int): Largest accepted value for the longest side of the page image.

    Returns:
    - Dataset: Query rows merged with ``response`` and ``finish_reason``, restricted to
      rows whose ``finish_reason`` is ``"stop"`` and whose image size is in range.
    """
    logger.info("Loading query and response datasets")
    query_data = load_jsonl_into_ds(query_glob_path)
    response_data = load_jsonl_into_ds(response_glob_path)

    # Map the datasets down to the core fields that we're going to need to make them easier to process
    logger.info("Mapping query data")
    query_data = query_data["train"]
    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)

    logger.info("Mapping response data")
    response_data = response_data["train"]
    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # What we're going to do, is build an in-memory map for the response data from custom_id to row
    # This will let us do quick lookups when we do a merge step, but it will not scale past a certain point
    # Only the custom_id column is read, so the response texts are not decoded here.
    logger.info("Building custom_id to row map")
    custom_id_to_response_row = {
        custom_id: row_id for row_id, custom_id in enumerate(response_data["custom_id"])
    }

    logger.info("Running merge map")
    final_dataset = query_data.map(
        partial(merge_query_response, response_data=response_data, response_map=custom_id_to_response_row),
        num_proc=num_proc
    )

    # Don't include data where the model cut off due to a length issue, or moderation issue
    final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Pick things that have a reasonable image size only
    final_dataset = final_dataset.filter(
        partial(_has_reasonable_image_size, min_dim=min_image_dim, max_dim=max_image_dim),
        num_proc=num_proc,
    )

    return final_dataset