Mirror of https://github.com/allenai/olmocr.git (synced 2025-07-30 12:32:16 +00:00)
Fixing dataloader hopefully
parent 6d53683001
commit fc8fcfaeba
@@ -5,6 +5,7 @@ import re
 import os
 import base64
 import glob
+import pypdf, pypdf.errors

 from functools import partial
 from typing import Any, Dict, Optional
@@ -190,9 +191,12 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:

     with tempfile.NamedTemporaryFile(delete=False) as tf:
         s3_client.download_fileobj(bucket_name, s3_key, tf)

-        raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
+        try:
+            raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
+        except pypdf.errors.PdfReadError:
+            raw_page_text = None

     return {
         "custom_id": custom_id,
         "input_prompt_text": input_prompt_text,
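The pattern here is to let anchor-text extraction fail softly on unreadable PDFs and record None instead of raising. A minimal sketch of the same idea using pypdf directly (safe_page_text is a hypothetical helper for illustration, not the repo's get_anchor_text):

import pypdf
import pypdf.errors

def safe_page_text(pdf_path: str, page_num: int):
    """Hypothetical stand-in for the guarded get_anchor_text call above:
    return the page's text, or None when pypdf cannot parse the PDF."""
    try:
        reader = pypdf.PdfReader(pdf_path)
        return reader.pages[page_num - 1].extract_text()  # assumes 1-based page_num
    except pypdf.errors.PdfReadError:
        # Corrupt or truncated PDF: signal "no text" so the record can be
        # filtered out later instead of aborting the whole dataset build.
        return None

Records carrying None are then dropped in the dataset-building step below.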
@@ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     logger.info("Mapping query data")
     query_data = query_data["train"]
     query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
+    query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc)

     logger.info("Mapping response data")
     response_data = response_data["train"]
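The added line is the standard Hugging Face datasets filter idiom for dropping rows. A self-contained toy sketch (column names borrowed from the diff, data made up):

from datasets import Dataset

# Toy stand-in for the mapped query data; "raw_page_text" is None wherever
# extract_openai_batch_query hit a PdfReadError.
query_data = Dataset.from_dict({
    "custom_id": ["a", "b", "c"],
    "raw_page_text": ["page one text", None, "page three text"],
})

# Drop the rows whose anchor text could not be extracted, mirroring the added
# line (the real call also passes num_proc for multiprocessing).
query_data = query_data.filter(lambda x: x["raw_page_text"] is not None)
print(query_data.num_rows)  # 2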
@@ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
     )

     # Don't include data where the model cut off due to a length issue, or moderation issue
+    logger.info("Filtering on finish_reason == stop")
     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

     # Pick things that have a reasonable image size only
@@ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
         width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
         return 1800 <= max(width, height) <= 2200

+    logger.info("Filtering on image size")
     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)

     return final_dataset
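get_png_dimensions_from_base64 is a repo helper; as a rough sketch (an assumption, not the actual implementation), the dimensions can be read straight from the PNG IHDR header without decoding the whole image:

import base64
import struct

def png_dimensions_from_base64(b64: str) -> tuple[int, int]:
    # PNG layout: 8-byte signature, 4-byte chunk length, 4-byte "IHDR" tag,
    # then width and height as big-endian 32-bit integers (bytes 16..24).
    header = base64.b64decode(b64[:64])  # a 64-char prefix decodes cleanly to 48 bytes
    if header[:8] != b"\x89PNG\r\n\x1a\n" or header[12:16] != b"IHDR":
        raise ValueError("not a PNG image")
    width, height = struct.unpack(">II", header[16:24])
    return width, height

# Mirrors the filter above: keep pages whose longest side is 1800-2200 px,
# i.e. roughly the expected ~2048 px render size.
def pick_image_sizes(x):
    width, height = png_dimensions_from_base64(x["input_prompt_image_base64"])
    return 1800 <= max(width, height) <= 2200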