Fixing dataloader hopefully

This commit is contained in:
Jake Poznanski 2024-10-15 15:13:25 +00:00
parent 6d53683001
commit fc8fcfaeba

View File

@ -5,6 +5,7 @@ import re
import os
import base64
import glob
import pypdf, pypdf.errors
from functools import partial
from typing import Any, Dict, Optional
@ -190,9 +191,12 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
with tempfile.NamedTemporaryFile(delete=False) as tf:
s3_client.download_fileobj(bucket_name, s3_key, tf)
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
try:
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
except pypdf.errors.PdfReadError:
raw_page_text = None
return {
"custom_id": custom_id,
"input_prompt_text": input_prompt_text,
@ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
logger.info("Mapping query data")
query_data = query_data["train"]
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc)
logger.info("Mapping response data")
response_data = response_data["train"]
@ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
)
# Don't include data where the model cut off due to a length issue, or moderation issue
logger.info("Filtering on finish_reason == stop")
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
# Pick things that have a reasonable image size only
@ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
return 1800 <= max(width, height) <= 2200
logger.info("Filtering on image size")
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
return final_dataset