| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | import unittest | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | from torch.utils.data import DataLoader | 
					
						
							|  |  |  | from tqdm import tqdm | 
					
						
							|  |  |  | from functools import partial | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from transformers import AutoProcessor | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-27 18:30:41 +00:00
										 |  |  | from olmocr.train.dataloader import ( | 
					
						
							| 
									
										
										
										
											2024-10-16 18:06:27 +00:00
										 |  |  |     build_finetuning_dataset, | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  |     extract_openai_batch_response, | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |     load_jsonl_into_ds, | 
					
						
							|  |  |  |     list_dataset_files | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-27 18:30:41 +00:00
										 |  |  | from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | class TestBatchQueryResponseDataset(unittest.TestCase): | 
					
						
							|  |  |  |     def testLoadS3(self): | 
					
						
							| 
									
										
										
										
											2024-10-07 07:49:16 -07:00
										 |  |  |         ds = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3) | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         print(f"Loaded {len(ds)} entries") | 
					
						
							|  |  |  |         print(ds) | 
					
						
							|  |  |  |         print(ds["train"]) | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 18:06:27 +00:00
										 |  |  |     def testFinetuningDS(self): | 
					
						
							|  |  |  |         ds = build_finetuning_dataset( | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json", | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         print(ds) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000)) | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         print(ds[0]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |     def testPlotSequenceLengthHistogram(self): | 
					
						
							|  |  |  |         import plotly.express as px   | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  |         ds = build_finetuning_dataset( | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json", | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  |         ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000)) | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							| 
									
										
										
										
											2024-09-27 15:48:56 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-16 18:26:25 +00:00
										 |  |  |         initial_len = len(ds) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False) | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         max_seen_len = 0 | 
					
						
							|  |  |  |         steps = 0 | 
					
						
							|  |  |  |         sequence_lengths = []  # List to store sequence lengths | 
					
						
							|  |  |  |         for entry in tqdm(train_dataloader): | 
					
						
							|  |  |  |             num_input_tokens = entry["input_ids"].shape[1] | 
					
						
							|  |  |  |             max_seen_len = max(max_seen_len, num_input_tokens) | 
					
						
							|  |  |  |             sequence_lengths.append(num_input_tokens)  # Collecting sequence lengths | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if steps % 100 == 0: | 
					
						
							|  |  |  |                 print(f"Max input len {max_seen_len}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             steps += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()}) | 
					
						
							|  |  |  |         print(f"Max input len {max_seen_len}") | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         print(f"Total elements before filtering: {initial_len}") | 
					
						
							|  |  |  |         print(f"Total elements after filtering: {steps}") | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Plotting the histogram using Plotly | 
					
						
							|  |  |  |         fig = px.histogram( | 
					
						
							|  |  |  |             sequence_lengths, | 
					
						
							|  |  |  |             nbins=100, | 
					
						
							|  |  |  |             title="Distribution of Input Sequence Lengths", | 
					
						
							|  |  |  |             labels={'value': 'Sequence Length', 'count': 'Frequency'} | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         fig.write_image("sequence_lengths_histogram.png") |