mirror of
https://github.com/allenai/olmocr.git
synced 2025-07-31 04:46:33 +00:00
Fixing dataloader hopefully
This commit is contained in:
parent
6d53683001
commit
fc8fcfaeba
@ -5,6 +5,7 @@ import re
|
|||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
import glob
|
import glob
|
||||||
|
import pypdf, pypdf.errors
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@ -190,9 +191,12 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||||
s3_client.download_fileobj(bucket_name, s3_key, tf)
|
s3_client.download_fileobj(bucket_name, s3_key, tf)
|
||||||
|
|
||||||
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
|
try:
|
||||||
|
raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")
|
||||||
|
except pypdf.errors.PdfReadError:
|
||||||
|
raw_page_text = None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"custom_id": custom_id,
|
"custom_id": custom_id,
|
||||||
"input_prompt_text": input_prompt_text,
|
"input_prompt_text": input_prompt_text,
|
||||||
@ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
|
|||||||
logger.info("Mapping query data")
|
logger.info("Mapping query data")
|
||||||
query_data = query_data["train"]
|
query_data = query_data["train"]
|
||||||
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
|
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc)
|
||||||
|
query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc)
|
||||||
|
|
||||||
logger.info("Mapping response data")
|
logger.info("Mapping response data")
|
||||||
response_data = response_data["train"]
|
response_data = response_data["train"]
|
||||||
@ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Don't include data where the model cut off due to a length issue, or moderation issue
|
# Don't include data where the model cut off due to a length issue, or moderation issue
|
||||||
|
logger.info("Filtering on finish_reason == stop")
|
||||||
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
|
final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
|
||||||
|
|
||||||
# Pick things that have a reasonable image size only
|
# Pick things that have a reasonable image size only
|
||||||
@ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo
|
|||||||
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
|
width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"])
|
||||||
return 1800 <= max(width, height) <= 2200
|
return 1800 <= max(width, height) <= 2200
|
||||||
|
|
||||||
|
logger.info("Filtering on image size")
|
||||||
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
|
final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc)
|
||||||
|
|
||||||
return final_dataset
|
return final_dataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user