import json
import logging
import tempfile
import re
import os
import base64
import glob
import pypdf, pypdf.errors

from functools import partial
from typing import Any, Dict, Optional
from logging import Logger
from filelock import FileLock

import boto3
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict

from .core.config import DataConfig, SourceConfig

from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
from pdelfin.data.renderpdf import get_pdf_media_box_width_height

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract the prefix before the wildcard
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        pattern = re.compile(prefix_pattern.replace("*", ".*"))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use the datasets library to load the JSON files (local paths or S3 URLs)
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an s3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO: Maybe in the future we can parse the response (which is a structured JSON document itself)
    # into its own columns
    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}


def _cache_s3_file(s3_path: str, local_cache_dir: str):
    """
    Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
    """
    bucket, key = parse_s3_path(s3_path)

    # Define the local file path
    local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

    lock_file = f"{local_file_path}.lock"

    # Use a file lock to prevent concurrent writes
    with FileLock(lock_file):
        if not os.path.exists(local_file_path):
            logger.info(f"Downloading {s3_path} to {local_file_path}")
            s3_client = boto3.client(
                's3',
                aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
            )
            s3_client.download_file(bucket, key, local_file_path)
        # else: the file is already cached locally, so skip the download

    return local_file_path


def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
    """
    Caches all S3 paths in the dataset to the local cache directory.
    """
    # Download function used for parallel processing
    def cache_file(example):
        s3_path = example["s3_path"]
        if s3_path:
            # Download the file and cache it locally
            local_path = _cache_s3_file(s3_path, pdf_cache_location)
            return {"local_pdf_path": local_path}
        return {"local_pdf_path": None}

    # Map the caching function over the dataset (in parallel if requested)
    dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)

    return dataset


def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
    if pdf_cache_location is None:
        pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')

    logger.info("Loading fine tuning dataset from OpenAI style batch responses")
    response_data = load_jsonl_into_ds(response_glob_path)
    response_data = response_data["train"]

    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # Don't include data where the model cut off due to a length issue or a moderation issue
    logger.info("Filtering on finish_reason == stop")
    final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Cache all the s3_paths that were accessed to a local storage location
    final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)

    # Filter out pages where no anchor text can be generated, to prevent errors during actual training
    def _can_create_anchor_text(example):
        try:
            anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
            _ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
            return anchor_text is not None
        except Exception:
            return False

    final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)

    return final_dataset
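

# A minimal usage sketch (not part of the original pipeline): it shows how
# build_finetuning_dataset might be invoked directly. The response glob below is a
# hypothetical placeholder, and because of the relative import above the file would
# need to be run as a module (python -m ...) rather than as a standalone script.
if __name__ == "__main__":
    example_response_glob = "s3://my-bucket/openai_batch_responses/*.jsonl"  # hypothetical path

    ds = build_finetuning_dataset(
        response_glob_path=example_response_glob,
        pdf_cache_location=None,  # defaults to ~/.cache/pdelfin_pdfs
        num_proc=8,               # lower parallelism for a quick local run
    )
    print(ds)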