Fixing dataloader hopefully

This commit is contained in:
Jake Poznanski 2024-10-15 15:13:25 +00:00
parent 6d53683001
commit fc8fcfaeba

View File

@ -5,6 +5,7 @@ import re
import os import os
import base64 import base64
import glob import glob
import pypdf, pypdf.errors
from functools import partial from functools import partial
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -190,9 +191,12 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
with tempfile.NamedTemporaryFile(delete=False) as tf: with tempfile.NamedTemporaryFile(delete=False) as tf:
s3_client.download_fileobj(bucket_name, s3_key, tf) s3_client.download_fileobj(bucket_name, s3_key, tf)
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport") try:
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
except pypdf.errors.PdfReadError:
raw_page_text = None
return { return {
"custom_id": custom_id, "custom_id": custom_id,
"input_prompt_text": input_prompt_text, "input_prompt_text": input_prompt_text,
@ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
logger.info("Mapping query data") logger.info("Mapping query data")
query_data = query_data["train"] query_data = query_data["train"]
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc) query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc)
logger.info("Mapping response data") logger.info("Mapping response data")
response_data = response_data["train"] response_data = response_data["train"]
@ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
) )
# Don't include data where the model cut off due to a length issue, or moderation issue # Don't include data where the model cut off due to a length issue, or moderation issue
logger.info("Filtering on finish_reason == stop")
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
# Pick things that have a reasonable image size only # Pick things that have a reasonable image size only
@ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"]) width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
return 1800 <= max(width, height) <= 2200 return 1800 <= max(width, height) <= 2200
logger.info("Filtering on image size")
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc) final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
return final_dataset return final_dataset