import json
import logging
import tempfile
import re
import os
import base64
import glob

import pypdf
import pypdf.errors

from functools import partial
from typing import Any, Dict, Optional
from logging import Logger

from filelock import FileLock
import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet noisy logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files matching the glob pattern, supporting both s3:// URIs and local paths.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract the prefix before the wildcard
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files from the specified S3 (or local) glob path into a Hugging Face DatasetDict.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use the datasets library to load the JSON files (local paths or S3 URIs)
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an S3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO: Maybe in the future we can parse the response (which is a structured JSON document itself)
    # into its own columns
    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}
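# Illustrative sketch of the record shape that extract_openai_batch_response expects from an
# OpenAI-style batch response JSONL line. The custom_id value and content shown are made-up
# placeholders; the exact custom_id encoding is whatever parse_custom_id in pdelfin.s3_utils defines.
#
# {
#     "custom_id": "<encodes the source S3 path and 1-indexed page number>",
#     "response": {
#         "body": {
#             "choices": [
#                 {
#                     "message": {"content": "<structured JSON response for the page>"},
#                     "finish_reason": "stop"
#                 }
#             ]
#         }
#     }
# }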
""" bucket, key = parse_s3_path(s3_path) # Define the local file path local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_")) os.makedirs(os.path.dirname(local_file_path), exist_ok=True) lock_file = f"{local_file_path}.lock" # Use a file lock to prevent concurrent writes with FileLock(lock_file): if not os.path.exists(local_file_path): logger.info(f"Downloading {s3_path} to {local_file_path}") s3_client = boto3.client( 's3', aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'), aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY') ) s3_client.download_file(bucket, key, local_file_path) else: pass #logger.info(f"File {local_file_path} already exists, skipping download.") return local_file_path def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset: """ Caches all S3 paths in the dataset to the local cache directory. """ # Define the download function to use in parallel processing def cache_file(example): s3_path = example["s3_path"] if s3_path: # Download the file and cache it locally local_path = _cache_s3_file(s3_path, pdf_cache_location) return {"local_pdf_path": local_path} return {"local_pdf_path": None} # Map the caching function to the dataset (with parallelism if needed) dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False) return dataset def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str]=None, num_proc: int=32) -> Dataset: if pdf_cache_location is None: pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs') logger.info("Loading fine tuning dataset from OpenAI style batch responses") response_data = load_jsonl_into_ds(response_glob_path) response_data = response_data["train"] response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc) # Don't include data where the model cut off due to a length issue, or moderation issue logger.info("Filtering on finish_reason == stop") final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) # Cache all the s3_paths that were accessed to a local storage location, final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc) # Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training def _can_create_anchor_text(example): try: anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000) return anchor_text is not None except: return False final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc) return final_dataset