mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-10-31 01:55:06 +00:00 
			
		
		
		
	Fixing dataloader hopefully
This commit is contained in:
		
							parent
							
								
									6d53683001
								
							
						
					
					
						commit
						fc8fcfaeba
					
				| @ -5,6 +5,7 @@ import re | ||||
| import os | ||||
| import base64 | ||||
| import glob | ||||
| import pypdf, pypdf.errors | ||||
| 
 | ||||
| from functools import partial | ||||
| from typing import Any, Dict, Optional | ||||
| @ -191,7 +192,10 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]: | ||||
|     with tempfile.NamedTemporaryFile(delete=False) as tf: | ||||
|         s3_client.download_fileobj(bucket_name, s3_key, tf) | ||||
|      | ||||
|     try: | ||||
|         raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport") | ||||
|     except pypdf.errors.PdfReadError: | ||||
|         raw_page_text = None | ||||
|    | ||||
|     return { | ||||
|         "custom_id": custom_id, | ||||
| @ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | ||||
|     logger.info("Mapping query data") | ||||
|     query_data = query_data["train"] | ||||
|     query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc) | ||||
|     query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc) | ||||
| 
 | ||||
|     logger.info("Mapping response data") | ||||
|     response_data = response_data["train"] | ||||
| @ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | ||||
|     ) | ||||
| 
 | ||||
|     # Don't include data where the model cut off due to a length issue, or moderation issue | ||||
|     logger.info("Filtering on finish_reason == stop") | ||||
|     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) | ||||
| 
 | ||||
|     # Pick things that have a reasonable image size only | ||||
| @ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | ||||
|         width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"]) | ||||
|         return 1800 <= max(width, height) <= 2200 | ||||
| 
 | ||||
|     logger.info("Filtering on image size") | ||||
|     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc) | ||||
| 
 | ||||
|     return final_dataset | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jake Poznanski
						Jake Poznanski