import os
import unittest
from functools import partial

from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from pdelfin.train.dataloader import (
    build_batch_query_response_vision_dataset,
    extract_openai_batch_query,
    extract_openai_batch_response,
    load_jsonl_into_ds,
    list_dataset_files,
)

from pdelfin.train.dataprep import (
    batch_prepare_data_for_qwen2_training,
    filter_by_max_seq_len,
    prepare_data_for_qwen2_training,
)


							|  |  |  | class TestBatchQueryResponseDataset(unittest.TestCase): | 
					
						
							|  |  |  |     def testLoadS3(self): | 
					
						
							| 
									
										
										
										
											2024-10-07 07:49:16 -07:00
										 |  |  |         ds = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3) | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         print(f"Loaded {len(ds)} entries") | 
					
						
							|  |  |  |         print(ds) | 
					
						
							|  |  |  |         print(ds["train"]) | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  |     def testCombinedQueryResponse(self): | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  |         ds = build_batch_query_response_vision_dataset( | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |             query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl", | 
					
						
							|  |  |  |             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json", | 
					
						
							| 
									
										
										
										
											2024-09-18 22:52:42 +00:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         print(ds) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							|  |  |  |         from pdelfin.train.dataprep import filter_by_max_seq_len | 
					
						
							|  |  |  |         ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print(ds[0]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-07 12:59:27 -07:00
										 |  |  |     def testLocalDS(self): | 
					
						
							|  |  |  |         ds = build_batch_query_response_vision_dataset( | 
					
						
							|  |  |  |             query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl", | 
					
						
							|  |  |  |             response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print(ds) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ds.to_parquet("/root/trainds_parquet/bigds.parquet") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							|  |  |  |         from pdelfin.train.dataprep import filter_by_max_seq_len | 
					
						
							|  |  |  |         ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print(ds[0]) | 
					
						
							| 
									
										
										
										
											2024-10-07 07:49:16 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |     def testPlotSequenceLengthHistogram(self): | 
					
						
							|  |  |  |         import plotly.express as px   | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ds = build_batch_query_response_vision_dataset( | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |             query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl", | 
					
						
							|  |  |  |             response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json", | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |         ) | 
					
						
							|  |  |  |         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         initial_len = len(ds) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-27 15:48:56 +00:00
										 |  |  |         from pdelfin.train.dataprep import filter_by_max_seq_len | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=2200, max_response_len=2200)) | 
					
						
							| 
									
										
										
										
											2024-09-27 15:48:56 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  |         formatted_dataset = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor)) | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=30, shuffle=False) | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         max_seen_len = 0 | 
					
						
							|  |  |  |         steps = 0 | 
					
						
							|  |  |  |         sequence_lengths = []  # List to store sequence lengths | 
					
						
							|  |  |  |         for entry in tqdm(train_dataloader): | 
					
						
							|  |  |  |             num_input_tokens = entry["input_ids"].shape[1] | 
					
						
							|  |  |  |             max_seen_len = max(max_seen_len, num_input_tokens) | 
					
						
							|  |  |  |             sequence_lengths.append(num_input_tokens)  # Collecting sequence lengths | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if steps % 100 == 0: | 
					
						
							|  |  |  |                 print(f"Max input len {max_seen_len}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             steps += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()}) | 
					
						
							|  |  |  |         print(f"Max input len {max_seen_len}") | 
					
						
							| 
									
										
										
										
											2024-10-02 22:45:40 +00:00
										 |  |  |         print(f"Total elements before filtering: {initial_len}") | 
					
						
							|  |  |  |         print(f"Total elements after filtering: {steps}") | 
					
						
							| 
									
										
										
										
											2024-09-25 09:05:11 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Plotting the histogram using Plotly | 
					
						
							|  |  |  |         fig = px.histogram( | 
					
						
							|  |  |  |             sequence_lengths, | 
					
						
							|  |  |  |             nbins=100, | 
					
						
							|  |  |  |             title="Distribution of Input Sequence Lengths", | 
					
						
							|  |  |  |             labels={'value': 'Sequence Length', 'count': 'Frequency'} | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         fig.write_image("sequence_lengths_histogram.png") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  |     def testExtractBatch(self): | 
					
						
							| 
									
										
										
										
											2024-10-07 07:49:16 -07:00
										 |  |  |         query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3) | 
					
						
							| 
									
										
										
										
											2024-09-18 21:42:09 +00:00
										 |  |  |         query_data = query_data["train"] | 
					
						
							|  |  |  |         query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print(query_data) | 
					
						
							| 
									
										
										
										
											2024-09-18 22:48:38 +00:00
										 |  |  |         print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def testExtractResponse(self): | 
					
						
							| 
									
										
										
										
											2024-10-07 07:49:16 -07:00
										 |  |  |         response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3) | 
					
						
							| 
									
										
										
										
											2024-09-18 22:48:38 +00:00
										 |  |  |         response_data = response_data["train"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         print(response_data) | 
					
						
							|  |  |  |         print(response_data[0]) | 
					
						
							| 
									
										
										
										
											2024-09-26 19:55:54 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |     def testPyArrowDirectJson(self): | 
					
						
							| 
									
										
										
										
											2024-10-09 18:11:18 +00:00
										 |  |  |         query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl" | 
					
						
							|  |  |  |         response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json" | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |          | 
					
						
							|  |  |  |         all_files = list_dataset_files(query_glob_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         import pyarrow as pa | 
					
						
							|  |  |  |         import pyarrow.json as paj | 
					
						
							|  |  |  |         import pyarrow.compute as pc | 
					
						
							|  |  |  |         import pyarrow.fs as fs | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         s3 = fs.S3FileSystem() | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         block_size = 200 * 1024**2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for file in all_files: | 
					
						
							|  |  |  |             with s3.open_input_stream(file.replace("s3://", "")) as f: | 
					
						
							|  |  |  |                 table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size)) | 
					
						
							| 
									
										
										
										
											2024-09-26 20:44:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-08 22:10:18 +00:00
										 |  |  |                 print(f"file {file} rows {table.num_rows}") | 
					
						
							|  |  |  |                 print(table.schema) | 
					
						
							| 
									
										
										
										
											2024-09-26 20:44:47 +00:00
										 |  |  | 
 |