mirror of https://github.com/allenai/olmocr.git (synced 2025-10-31 10:04:26 +00:00)
	Generating parquets for hugging face
This commit is contained in:

parent 84c0c71393
commit 6ed6f85c42

olmocr/train/convertjsontoparquet.py (new normal file, 186 lines added)
@@ -0,0 +1,186 @@
# Script to generate parquet dataset files to upload to Hugging Face
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}

# The script takes a path to an input dataset and a SQLite database location,
# and builds a parquet file with rows of the form: "id", "url", "page_number", "response",
# where "id" is the output of parse_pdf_hash plus "-" plus the page number,
# "url" is the result of get_uri_from_db, and
# "response" is NormalizedEntry.text.

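# Illustrative example of the mapping this script performs (values other than the
# custom_id documented above are hypothetical):
#   custom_id   : s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
#   id          : de80a57e6c57b45796d2e020173227f7eae44232-1
#   url         : looked up from the SQLite database for that pdf hash
#   response    : the page text extracted from the model output
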
import argparse
import glob
import json
import re
import sqlite3
from dataclasses import dataclass
from typing import Optional

import pandas as pd
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
    Extracts a hash from a pretty PDF S3 URL.
    For example, given:
      s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None


def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Looks up the URL for the given pdf_hash in the sqlite database.
    Assumes there is a table called 'pdf_mapping' with a column 'uri'.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None
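# Note: the schema of the mapping database is not included in this commit; the lookup
# above only assumes something equivalent to the following (illustrative) table:
#   CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT);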


@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        """
        Constructs a NormalizedEntry from a goldkey string.
        The goldkey is expected to be of the format:
          <s3_path>-<page_number>
        """
        s3_path = goldkey[: goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1 :])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"
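
    # Note: rindex("-") splits on the last dash, so S3 paths that themselves contain
    # dashes still parse correctly. For example, the goldkey
    # "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1" yields
    # s3_path="s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf" and pagenum=1.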


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
    It supports:
      - Birr: looks for an "outputs" field.
      - Already normalized entries: if they contain s3_path, pagenum, etc.
      - OpenAI: where the response is in data["response"]["body"]["choices"].
      - SGLang: where the response is in data["response"]["choices"].
    """
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None),
        )
    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
        # Already normalized
        return NormalizedEntry(**data)
    elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
        # OpenAI batch case
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["body"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
        )
    elif "response" in data and "choices" in data["response"]:
        # SGLang case (assumes the same OpenAI-style choices structure, per the docstring)
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["choices"][0]["finish_reason"],
        )
    else:
        # Raising here lets the caller's try/except skip the malformed line instead of crashing later.
        raise ValueError("Unsupported JSON entry format")
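
# Illustrative OpenAI batch-output line (structure only; the page text is hypothetical):
# {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1",
#  "response": {"body": {"choices": [{"message": {"content": "...page text..."},
#                                     "finish_reason": "stop"}]}}}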


def main():
    parser = argparse.ArgumentParser(
        description="Generate a Parquet dataset file for HuggingFace upload."
    )
    parser.add_argument(
        "input_dataset",
        help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
    )
    parser.add_argument(
        "db_path", help="Path to the SQLite database file."
    )
    parser.add_argument(
        "--output", default="output.parquet", help="Output Parquet file path."
    )
    args = parser.parse_args()

    rows = []
    files = glob.glob(args.input_dataset)
    print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

    for file_path in tqdm(files):
        print(f"Processing file: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
                    continue

                try:
                    normalized = normalize_json_entry(data)
                except Exception as e:
                    print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                    continue

                # Use the s3_path from the normalized entry to extract the pdf hash.
                pdf_hash = parse_pdf_hash(normalized.s3_path)
                if pdf_hash is None:
                    print(
                        f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}"
                    )
                    continue

                # The output id is the pdf hash plus '-' plus the page number.
                combined_id = f"{pdf_hash}-{normalized.pagenum}"

                # Look up the corresponding URL from the sqlite database.
                url = get_uri_from_db(args.db_path, pdf_hash)
                if url is None:
                    print(
                        f"No URL found in DB for pdf hash {pdf_hash} at {file_path}:{line_num}"
                    )
                    continue

                row = {
                    "id": combined_id,
                    "url": url,
                    "page_number": normalized.pagenum,
                    "response": normalized.text,
                }
                rows.append(row)

    if rows:
        df = pd.DataFrame(rows)
        df.to_parquet(args.output, index=False)
        print(f"Successfully wrote {len(df)} rows to {args.output}")
    else:
        print("No rows to write. Exiting.")


if __name__ == "__main__":
    main()
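
# Example invocation (the glob pattern comes from the help text above; the database and
# output paths are illustrative):
#   python olmocr/train/convertjsontoparquet.py \
#       '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json' \
#       pdf_mapping.sqlite --output train.parquet
# Note: pandas.DataFrame.to_parquet requires a parquet engine such as pyarrow or
# fastparquet to be installed.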
Jake Poznanski