import unittest

import pytest
import requests
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    GenerationConfig,
)
					
						
@pytest.mark.nonci
class MolmoProcessorTest(unittest.TestCase):
    def test_molmo_demo(self):
        # load the processor
        processor = AutoProcessor.from_pretrained(
            "allenai/Molmo-7B-O-0924",
            trust_remote_code=True,
            torch_dtype="auto",
        )

        # load the model
        model = AutoModelForCausalLM.from_pretrained(
            "allenai/Molmo-7B-O-0924",
            trust_remote_code=True,
            torch_dtype="auto",
        )
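
        # both from_pretrained calls pass trust_remote_code=True because the
        # Molmo checkpoint ships its processor and model classes as custom
        # code on the Hugging Face Hub rather than inside transformers itself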
					
						
        device = "cuda:0"

        model = model.to(device)
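
        # note: this demo assumes a CUDA device is available; a possible
        # CPU fallback sketch (would also need an `import torch` at the top):
        #     device = "cuda:0" if torch.cuda.is_available() else "cpu"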
					
						
        # process the image and text
        inputs = processor.process(
            images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
            text="Describe this image.",
        )

        # move inputs to the correct device and make a batch of size 1
        inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
					
						
        print("Raw inputs")
        print(inputs)

        print("\nShapes")
        # expected for this image: {('input_ids', torch.Size([1, 589])), ('images', torch.Size([1, 5, 576, 588])), ('image_masks', torch.Size([1, 5, 576])), ('image_input_idx', torch.Size([1, 5, 144]))}
        print({(x, y.shape) for x, y in inputs.items()})
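        # reading the shapes above: the 5 appears to be the number of image
        # crops Molmo's preprocessing produces for this picture, each crop
        # holding 576 patches of 14 * 14 * 3 = 588 pixel values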
					
						
        print("\nTokens")
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))
					
						
        # generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )
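        # generate_from_batch is part of the checkpoint's remote code; it
        # appears to wrap the usual transformers generate() so that Molmo's
        # extra image tensors are carried along with input_ids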
					
						
        # keep only the newly generated tokens; decode them to text
        generated_tokens = output[0, inputs["input_ids"].size(1) :]
        generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
					
						
        # print the generated text
        print(generated_text)
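
        # minimal, illustrative sanity assertion: the demo should produce
        # some non-empty text (the exact description varies run to run, so
        # nothing stronger is checked)
        self.assertTrue(generated_text.strip())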