Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-31 10:04:26 +00:00)
			
		
		
		
	Fixing dataloader hopefully
This commit is contained in:
		
parent 6d53683001
commit fc8fcfaeba
					
				| @ -5,6 +5,7 @@ import re | |||||||
| import os | import os | ||||||
| import base64 | import base64 | ||||||
| import glob | import glob | ||||||
|  | import pypdf, pypdf.errors | ||||||
| 
 | 
 | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||||
| @ -191,7 +192,10 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]: | |||||||
|     with tempfile.NamedTemporaryFile(delete=False) as tf: |     with tempfile.NamedTemporaryFile(delete=False) as tf: | ||||||
|         s3_client.download_fileobj(bucket_name, s3_key, tf) |         s3_client.download_fileobj(bucket_name, s3_key, tf) | ||||||
|      |      | ||||||
|     raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport") |     try: | ||||||
|  |         raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport") | ||||||
|  |     except pypdf.errors.PdfReadError: | ||||||
|  |         raw_page_text = None | ||||||
|    |    | ||||||
|     return { |     return { | ||||||
|         "custom_id": custom_id, |         "custom_id": custom_id, | ||||||
| @ -239,6 +243,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | |||||||
|     logger.info("Mapping query data") |     logger.info("Mapping query data") | ||||||
|     query_data = query_data["train"] |     query_data = query_data["train"] | ||||||
|     query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc) |     query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names, num_proc=num_proc) | ||||||
|  |     query_data = query_data.filter(lambda x: x["raw_page_text"] is not None, num_proc=num_proc) | ||||||
| 
 | 
 | ||||||
|     logger.info("Mapping response data") |     logger.info("Mapping response data") | ||||||
|     response_data = response_data["train"] |     response_data = response_data["train"] | ||||||
| @ -258,6 +263,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | |||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Don't include data where the model cut off due to a length issue, or moderation issue |     # Don't include data where the model cut off due to a length issue, or moderation issue | ||||||
|  |     logger.info("Filtering on finish_reason == stop") | ||||||
|     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) |     final_dataset = final_dataset.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc) | ||||||
| 
 | 
 | ||||||
|     # Pick things that have a reasonable image size only |     # Pick things that have a reasonable image size only | ||||||
| @ -265,6 +271,7 @@ def build_batch_query_response_vision_dataset(query_glob_path: str, response_glo | |||||||
|         width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"]) |         width, height = get_png_dimensions_from_base64(x["input_prompt_image_base64"]) | ||||||
|         return 1800 <= max(width, height) <= 2200 |         return 1800 <= max(width, height) <= 2200 | ||||||
| 
 | 
 | ||||||
|  |     logger.info("Filtering on image size") | ||||||
|     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc) |     final_dataset = final_dataset.filter(pick_image_sizes, num_proc=num_proc) | ||||||
| 
 | 
 | ||||||
|     return final_dataset |     return final_dataset | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jake Poznanski
						Jake Poznanski