| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | import math | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import modules.scripts as scripts | 
					
						
							|  |  |  | import gradio as gr | 
					
						
							|  |  |  | from PIL import Image, ImageDraw | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-11 23:24:24 +03:00
										 |  |  | from modules import images, processing, devices | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | from modules.processing import Processed, process_images | 
					
						
							|  |  |  | from modules.shared import opts, cmd_opts, state | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Script(scripts.Script): | 
					
						
							|  |  |  |     def title(self): | 
					
						
							|  |  |  |         return "Poor man's outpainting" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def show(self, is_img2img): | 
					
						
							|  |  |  |         return is_img2img | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def ui(self, is_img2img): | 
					
						
							|  |  |  |         if not is_img2img: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |         pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) | 
					
						
							| 
									
										
										
										
											2022-11-04 08:38:11 +03:00
										 |  |  |         mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) | 
					
						
							|  |  |  |         inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         return [pixels, mask_blur, inpainting_fill, direction] | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |     def run(self, p, pixels, mask_blur, inpainting_fill, direction): | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |         initial_seed = None | 
					
						
							|  |  |  |         initial_info = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |         p.mask_blur = mask_blur * 2 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |         p.inpainting_fill = inpainting_fill | 
					
						
							|  |  |  |         p.inpaint_full_res = False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         left = pixels if "left" in direction else 0 | 
					
						
							|  |  |  |         right = pixels if "right" in direction else 0 | 
					
						
							|  |  |  |         up = pixels if "up" in direction else 0 | 
					
						
							|  |  |  |         down = pixels if "down" in direction else 0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |         init_img = p.init_images[0] | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         target_w = math.ceil((init_img.width + left + right) / 64) * 64 | 
					
						
							|  |  |  |         target_h = math.ceil((init_img.height + up + down) / 64) * 64 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if left > 0: | 
					
						
							|  |  |  |             left = left * (target_w - init_img.width) // (left + right) | 
					
						
							| 
									
										
										
										
											2022-09-07 19:22:45 +03:00
										 |  |  |         if right > 0: | 
					
						
							|  |  |  |             right = target_w - init_img.width - left | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         if up > 0: | 
					
						
							|  |  |  |             up = up * (target_h - init_img.height) // (up + down) | 
					
						
							| 
									
										
										
										
											2022-09-07 19:22:45 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if down > 0: | 
					
						
							|  |  |  |             down = target_h - init_img.height - up | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         img = Image.new("RGB", (target_w, target_h)) | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         img.paste(init_img, (left, up)) | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         mask = Image.new("L", (img.width, img.height), "white") | 
					
						
							|  |  |  |         draw = ImageDraw.Draw(mask) | 
					
						
							| 
									
										
										
										
											2022-09-06 14:21:10 +03:00
										 |  |  |         draw.rectangle(( | 
					
						
							|  |  |  |             left + (mask_blur * 2 if left > 0 else 0), | 
					
						
							|  |  |  |             up + (mask_blur * 2 if up > 0 else 0), | 
					
						
							|  |  |  |             mask.width - right - (mask_blur * 2 if right > 0 else 0), | 
					
						
							|  |  |  |             mask.height - down - (mask_blur * 2 if down > 0 else 0) | 
					
						
							|  |  |  |         ), fill="black") | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         latent_mask = Image.new("L", (img.width, img.height), "white") | 
					
						
							|  |  |  |         latent_draw = ImageDraw.Draw(latent_mask) | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |         latent_draw.rectangle(( | 
					
						
							|  |  |  |              left + (mask_blur//2 if left > 0 else 0), | 
					
						
							|  |  |  |              up + (mask_blur//2 if up > 0 else 0), | 
					
						
							|  |  |  |              mask.width - right - (mask_blur//2 if right > 0 else 0), | 
					
						
							|  |  |  |              mask.height - down - (mask_blur//2 if down > 0 else 0) | 
					
						
							|  |  |  |         ), fill="black") | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-11 23:24:24 +03:00
										 |  |  |         devices.torch_gc() | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) | 
					
						
							|  |  |  |         grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |         grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         p.n_iter = 1 | 
					
						
							|  |  |  |         p.batch_size = 1 | 
					
						
							|  |  |  |         p.do_not_save_grid = True | 
					
						
							|  |  |  |         p.do_not_save_samples = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         work = [] | 
					
						
							|  |  |  |         work_mask = [] | 
					
						
							|  |  |  |         work_latent_mask = [] | 
					
						
							|  |  |  |         work_results = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |         for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |             for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |                 x, w = tiledata[0:2] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |                 work.append(tiledata[2]) | 
					
						
							|  |  |  |                 work_mask.append(tiledata_mask[2]) | 
					
						
							|  |  |  |                 work_latent_mask.append(tiledata_latent_mask[2]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         batch_count = len(work) | 
					
						
							|  |  |  |         print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 02:09:01 +03:00
										 |  |  |         state.job_count = batch_count | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |         for i in range(batch_count): | 
					
						
							|  |  |  |             p.init_images = [work[i]] | 
					
						
							|  |  |  |             p.image_mask = work_mask[i] | 
					
						
							|  |  |  |             p.latent_mask = work_latent_mask[i] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             state.job = f"Batch {i + 1} out of {batch_count}" | 
					
						
							|  |  |  |             processed = process_images(p) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if initial_seed is None: | 
					
						
							|  |  |  |                 initial_seed = processed.seed | 
					
						
							|  |  |  |                 initial_info = processed.info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             p.seed = processed.seed + 1 | 
					
						
							|  |  |  |             work_results += processed.images | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-06 02:09:01 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |         image_index = 0 | 
					
						
							|  |  |  |         for y, h, row in grid.tiles: | 
					
						
							|  |  |  |             for tiledata in row: | 
					
						
							| 
									
										
										
										
											2022-09-07 17:00:51 +03:00
										 |  |  |                 x, w = tiledata[0:2] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  |                 tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) | 
					
						
							|  |  |  |                 image_index += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         combined_image = images.combine_grid(grid) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if opts.samples_save: | 
					
						
							| 
									
										
										
										
											2022-09-12 15:41:30 +03:00
										 |  |  |             images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p) | 
					
						
							| 
									
										
										
										
											2022-09-04 01:29:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         processed = Processed(p, [combined_image], initial_seed, initial_info) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return processed | 
					
						
							|  |  |  | 
 |