| 
									
										
										
										
											2022-11-01 19:13:59 +03:00
										 |  |  | import base64 | 
					
						
							|  |  |  | import io | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2022-10-30 03:45:29 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2023-01-03 09:45:16 -05:00
										 |  |  | import datetime | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | import uvicorn | 
					
						
							| 
									
										
										
										
											2023-08-20 21:41:27 +08:00
										 |  |  | import ipaddress | 
					
						
							|  |  |  | import requests | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  | import gradio as gr | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | from threading import Lock | 
					
						
							| 
									
										
										
										
											2022-11-23 17:43:58 +08:00
										 |  |  | from io import BytesIO | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  | from fastapi import APIRouter, Depends, FastAPI, Request, Response | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  | from fastapi.security import HTTPBasic, HTTPBasicCredentials | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  | from fastapi.exceptions import HTTPException | 
					
						
							|  |  |  | from fastapi.responses import JSONResponse | 
					
						
							|  |  |  | from fastapi.encoders import jsonable_encoder | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  | from secrets import compare_digest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | import modules.shared as shared | 
					
						
							| 
									
										
										
										
											2023-10-15 09:41:02 +03:00
										 |  |  | from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  | from modules.api import models | 
					
						
							|  |  |  | from modules.shared import opts | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  | from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  | from modules.textual_inversion.textual_inversion import create_embedding, train_embedding | 
					
						
							|  |  |  | from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork | 
					
						
							| 
									
										
										
										
											2023-10-15 09:41:02 +03:00
										 |  |  | from PIL import PngImagePlugin, Image | 
					
						
							| 
									
										
										
										
											2023-01-27 11:28:12 +03:00
										 |  |  | from modules.sd_models_config import find_checkpoint_config_near_filename | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | from modules.realesrgan_model import get_realesrgan_models | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  | from modules import devices | 
					
						
							| 
									
										
										
										
											2023-08-25 01:58:19 -06:00
										 |  |  | from typing import Any | 
					
						
							| 
									
										
										
										
											2023-01-23 10:10:59 -05:00
										 |  |  | import piexif | 
					
						
							|  |  |  | import piexif.helper | 
					
						
							| 
									
										
										
										
											2023-07-03 20:17:47 +03:00
										 |  |  | from contextlib import closing | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-05 21:21:48 +00:00
										 |  |  | def script_name_to_index(name, scripts): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return [script.title().lower() for script in scripts].index(name.lower()) | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-30 09:10:22 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 12:01:51 +03:00
										 |  |  | def validate_sampler_name(name): | 
					
						
							|  |  |  |     config = sd_samplers.all_samplers_map.get(name, None) | 
					
						
							|  |  |  |     if config is None: | 
					
						
							|  |  |  |         raise HTTPException(status_code=404, detail="Sampler not found") | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 12:01:51 +03:00
										 |  |  |     return name | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-23 16:01:16 -03:00
										 |  |  | def setUpscalers(req: dict): | 
					
						
							|  |  |  |     reqDict = vars(req) | 
					
						
							| 
									
										
										
										
											2023-01-23 09:24:43 +03:00
										 |  |  |     reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) | 
					
						
							|  |  |  |     reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) | 
					
						
							| 
									
										
										
										
											2022-10-23 16:01:16 -03:00
										 |  |  |     return reqDict | 
					
						
							| 
									
										
										
										
											2022-10-27 15:20:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-21 07:38:07 +03:00
										 |  |  | def verify_url(url): | 
					
						
							|  |  |  |     """Returns True if the url refers to a global resource.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     import socket | 
					
						
							|  |  |  |     from urllib.parse import urlparse | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         parsed_url = urlparse(url) | 
					
						
							|  |  |  |         domain_name = parsed_url.netloc | 
					
						
							|  |  |  |         host = socket.gethostbyname_ex(domain_name) | 
					
						
							|  |  |  |         for ip in host[2]: | 
					
						
							|  |  |  |             ip_addr = ipaddress.ip_address(ip) | 
					
						
							|  |  |  |             if not ip_addr.is_global: | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  |     except Exception: | 
					
						
							|  |  |  |         return False | 
					
						
							| 
									
										
										
										
											2023-08-20 21:41:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-21 07:38:07 +03:00
										 |  |  |     return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def decode_base64_to_image(encoding): | 
					
						
							| 
									
										
										
										
											2023-08-19 12:19:21 +08:00
										 |  |  |     if encoding.startswith("http://") or encoding.startswith("https://"): | 
					
						
							| 
									
										
										
										
											2023-08-21 07:38:07 +03:00
										 |  |  |         if not opts.api_enable_requests: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail="Requests not allowed") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if opts.api_forbid_local_requests and not verify_url(encoding): | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail="Request to local resource not allowed") | 
					
						
							| 
									
										
										
										
											2023-08-20 21:41:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-21 07:38:07 +03:00
										 |  |  |         headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} | 
					
						
							|  |  |  |         response = requests.get(encoding, timeout=30, headers=headers) | 
					
						
							| 
									
										
										
										
											2023-08-19 12:19:21 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             image = Image.open(BytesIO(response.content)) | 
					
						
							|  |  |  |             return image | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail="Invalid image url") from e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-24 13:10:40 +08:00
										 |  |  |     if encoding.startswith("data:image/"): | 
					
						
							|  |  |  |         encoding = encoding.split(";")[1].split(",")[1] | 
					
						
							| 
									
										
										
										
											2023-01-23 17:11:22 -05:00
										 |  |  |     try: | 
					
						
							|  |  |  |         image = Image.open(BytesIO(base64.b64decode(encoding))) | 
					
						
							|  |  |  |         return image | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  |     except Exception as e: | 
					
						
							|  |  |  |         raise HTTPException(status_code=500, detail="Invalid encoded image") from e | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-01 19:13:59 +03:00
										 |  |  | def encode_pil_to_base64(image): | 
					
						
							| 
									
										
										
										
											2022-11-02 22:37:45 +08:00
										 |  |  |     with io.BytesIO() as output_bytes: | 
					
						
							| 
									
										
										
										
											2023-10-01 18:06:48 +03:00
										 |  |  |         if isinstance(image, str): | 
					
						
							|  |  |  |             return image | 
					
						
							| 
									
										
										
										
											2023-01-23 10:10:59 -05:00
										 |  |  |         if opts.samples_format.lower() == 'png': | 
					
						
							|  |  |  |             use_metadata = False | 
					
						
							|  |  |  |             metadata = PngImagePlugin.PngInfo() | 
					
						
							|  |  |  |             for key, value in image.info.items(): | 
					
						
							|  |  |  |                 if isinstance(key, str) and isinstance(value, str): | 
					
						
							|  |  |  |                     metadata.add_text(key, value) | 
					
						
							|  |  |  |                     use_metadata = True | 
					
						
							|  |  |  |             image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): | 
					
						
							| 
									
										
										
										
											2023-07-06 18:43:17 +08:00
										 |  |  |             if image.mode == "RGBA": | 
					
						
							|  |  |  |                 image = image.convert("RGB") | 
					
						
							| 
									
										
										
										
											2023-01-23 10:10:59 -05:00
										 |  |  |             parameters = image.info.get('parameters', None) | 
					
						
							|  |  |  |             exif_bytes = piexif.dump({ | 
					
						
							|  |  |  |                 "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } | 
					
						
							|  |  |  |             }) | 
					
						
							|  |  |  |             if opts.samples_format.lower() in ("jpg", "jpeg"): | 
					
						
							|  |  |  |                 image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail="Invalid image format") | 
					
						
							| 
									
										
										
										
											2022-11-02 22:37:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         bytes_data = output_bytes.getvalue() | 
					
						
							| 
									
										
										
										
											2023-01-23 10:10:59 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-02 22:37:45 +08:00
										 |  |  |     return base64.b64encode(bytes_data) | 
					
						
							| 
									
										
										
										
											2022-11-01 19:13:59 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:19:16 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 10:58:52 -05:00
										 |  |  | def api_middleware(app: FastAPI): | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  |     rich_available = False | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  |         if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: | 
					
						
							|  |  |  |             import anyio  # importing just so it can be placed on silent list | 
					
						
							|  |  |  |             import starlette  # importing just so it can be placed on silent list | 
					
						
							|  |  |  |             from rich.console import Console | 
					
						
							|  |  |  |             console = Console() | 
					
						
							|  |  |  |             rich_available = True | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     except Exception: | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  |         pass | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 09:45:16 -05:00
										 |  |  |     @app.middleware("http") | 
					
						
							|  |  |  |     async def log_and_time(req: Request, call_next): | 
					
						
							|  |  |  |         ts = time.time() | 
					
						
							|  |  |  |         res: Response = await call_next(req) | 
					
						
							|  |  |  |         duration = str(round(time.time() - ts, 4)) | 
					
						
							|  |  |  |         res.headers["X-Process-Time"] = duration | 
					
						
							| 
									
										
										
										
											2023-01-03 10:58:52 -05:00
										 |  |  |         endpoint = req.scope.get('path', 'err') | 
					
						
							|  |  |  |         if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): | 
					
						
							|  |  |  |             print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  |                 t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), | 
					
						
							|  |  |  |                 code=res.status_code, | 
					
						
							|  |  |  |                 ver=req.scope.get('http_version', '0.0'), | 
					
						
							|  |  |  |                 cli=req.scope.get('client', ('0:0.0.0', 0))[0], | 
					
						
							|  |  |  |                 prot=req.scope.get('scheme', 'err'), | 
					
						
							|  |  |  |                 method=req.scope.get('method', 'err'), | 
					
						
							|  |  |  |                 endpoint=endpoint, | 
					
						
							|  |  |  |                 duration=duration, | 
					
						
							| 
									
										
										
										
											2023-01-03 09:45:16 -05:00
										 |  |  |             )) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  |     def handle_exception(request: Request, e: Exception): | 
					
						
							|  |  |  |         err = { | 
					
						
							|  |  |  |             "error": type(e).__name__, | 
					
						
							|  |  |  |             "detail": vars(e).get('detail', ''), | 
					
						
							|  |  |  |             "body": vars(e).get('body', ''), | 
					
						
							|  |  |  |             "errors": str(e), | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-07-15 07:44:37 +03:00
										 |  |  |         if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions | 
					
						
							| 
									
										
										
										
											2023-05-29 08:54:13 +03:00
										 |  |  |             message = f"API error: {request.method}: {request.url} {err}" | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  |             if rich_available: | 
					
						
							| 
									
										
										
										
											2023-05-29 08:54:13 +03:00
										 |  |  |                 print(message) | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  |                 console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) | 
					
						
							|  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2023-05-31 19:56:37 +03:00
										 |  |  |                 errors.report(message, exc_info=True) | 
					
						
							| 
									
										
										
										
											2023-03-15 15:11:04 -04:00
										 |  |  |         return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @app.middleware("http") | 
					
						
							|  |  |  |     async def exception_handling(request: Request, call_next): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             return await call_next(request) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             return handle_exception(request, e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @app.exception_handler(Exception) | 
					
						
							|  |  |  |     async def fastapi_exception_handler(request: Request, e: Exception): | 
					
						
							|  |  |  |         return handle_exception(request, e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @app.exception_handler(HTTPException) | 
					
						
							|  |  |  |     async def http_exception_handler(request: Request, e: HTTPException): | 
					
						
							|  |  |  |         return handle_exception(request, e) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-01 19:13:59 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | class Api: | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  |     def __init__(self, app: FastAPI, queue_lock: Lock): | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |         if shared.cmd_opts.api_auth: | 
					
						
							| 
									
										
										
										
											2023-05-10 11:55:09 +03:00
										 |  |  |             self.credentials = {} | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |             for auth in shared.cmd_opts.api_auth.split(","): | 
					
						
							|  |  |  |                 user, password = auth.split(":") | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |                 self.credentials[user] = password | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  |         self.router = APIRouter() | 
					
						
							| 
									
										
										
										
											2022-10-18 06:51:53 +00:00
										 |  |  |         self.app = app | 
					
						
							|  |  |  |         self.queue_lock = queue_lock | 
					
						
							| 
									
										
										
										
											2023-01-04 06:36:57 -05:00
										 |  |  |         api_middleware(self.app) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |         self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2022-11-19 20:13:07 +08:00
										 |  |  |         self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) | 
					
						
							| 
									
										
										
										
											2023-08-25 01:58:19 -06:00
										 |  |  |         self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem]) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) | 
					
						
							| 
									
										
										
										
											2022-12-11 19:16:44 +00:00
										 |  |  |         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-07-24 19:45:08 +08:00
										 |  |  |         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) | 
					
						
							| 
									
										
										
										
											2023-03-09 07:56:19 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) | 
					
						
							| 
									
										
										
										
											2023-08-25 01:58:19 -06:00
										 |  |  |         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo]) | 
					
						
							|  |  |  |         self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem]) | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-29 14:21:28 +09:00
										 |  |  |         if shared.cmd_opts.api_server_stop: | 
					
						
							| 
									
										
										
										
											2023-06-14 18:51:47 +09:00
										 |  |  |             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) | 
					
						
							|  |  |  |             self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2023-06-14 19:53:08 +09:00
										 |  |  |             self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         self.default_script_arg_txt2img = [] | 
					
						
							|  |  |  |         self.default_script_arg_img2img = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |     def add_api_route(self, path: str, endpoint, **kwargs): | 
					
						
							|  |  |  |         if shared.cmd_opts.api_auth: | 
					
						
							|  |  |  |             return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) | 
					
						
							|  |  |  |         return self.app.add_api_route(path, endpoint, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |     def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())): | 
					
						
							|  |  |  |         if credentials.username in self.credentials: | 
					
						
							|  |  |  |             if compare_digest(credentials.password, self.credentials[credentials.username]): | 
					
						
							| 
									
										
										
										
											2022-11-15 16:12:34 +08:00
										 |  |  |                 return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |     def get_selectable_script(self, script_name, script_runner): | 
					
						
							|  |  |  |         if script_name is None or script_name == "": | 
					
						
							| 
									
										
										
										
											2023-01-08 16:14:38 +03:00
										 |  |  |             return None, None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) | 
					
						
							|  |  |  |         script = script_runner.selectable_scripts[script_idx] | 
					
						
							|  |  |  |         return script, script_idx | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-04 11:46:07 +08:00
										 |  |  |     def get_scripts_list(self): | 
					
						
							| 
									
										
										
										
											2023-05-17 22:43:24 +03:00
										 |  |  |         t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None] | 
					
						
							|  |  |  |         i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None] | 
					
						
							| 
									
										
										
										
											2023-03-04 11:46:07 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist) | 
					
						
							| 
									
										
										
										
											2023-01-07 14:21:31 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-17 22:43:24 +03:00
										 |  |  |     def get_script_info(self): | 
					
						
							|  |  |  |         res = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]: | 
					
						
							|  |  |  |             res += [script.api_info for script in script_list if script.api_info is not None] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |     def get_script(self, script_name, script_runner): | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         if script_name is None or script_name == "": | 
					
						
							|  |  |  |             return None, None | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         script_idx = script_name_to_index(script_name, script_runner.scripts) | 
					
						
							|  |  |  |         return script_runner.scripts[script_idx] | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |     def init_default_script_args(self, script_runner): | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |         #find max idx from the scripts in runner and generate a none array to init script_args | 
					
						
							|  |  |  |         last_arg_index = 1 | 
					
						
							|  |  |  |         for script in script_runner.scripts: | 
					
						
							|  |  |  |             if last_arg_index < script.args_to: | 
					
						
							|  |  |  |                 last_arg_index = script.args_to | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         # None everywhere except position 0 to initialize script args | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |         script_args = [None]*last_arg_index | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         script_args[0] = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get default values | 
					
						
							|  |  |  |         with gr.Blocks(): # will throw errors calling ui function without this | 
					
						
							|  |  |  |             for script in script_runner.scripts: | 
					
						
							|  |  |  |                 if script.ui(script.is_img2img): | 
					
						
							|  |  |  |                     ui_default_values = [] | 
					
						
							|  |  |  |                     for elem in script.ui(script.is_img2img): | 
					
						
							|  |  |  |                         ui_default_values.append(elem.value) | 
					
						
							|  |  |  |                     script_args[script.args_from:script.args_to] = ui_default_values | 
					
						
							|  |  |  |         return script_args | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): | 
					
						
							|  |  |  |         script_args = default_script_args.copy() | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() | 
					
						
							|  |  |  |         if selectable_scripts: | 
					
						
							|  |  |  |             script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args | 
					
						
							|  |  |  |             script_args[0] = selectable_idx + 1 | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Now check for always on scripts | 
					
						
							| 
									
										
										
										
											2023-06-02 14:58:10 +03:00
										 |  |  |         if request.alwayson_scripts: | 
					
						
							| 
									
										
										
										
											2023-03-11 12:21:33 -05:00
										 |  |  |             for alwayson_script_name in request.alwayson_scripts.keys(): | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |                 alwayson_script = self.get_script(alwayson_script_name, script_runner) | 
					
						
							| 
									
										
										
										
											2023-05-10 07:52:45 +03:00
										 |  |  |                 if alwayson_script is None: | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |                     raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") | 
					
						
							|  |  |  |                 # Selectable script in always on script param check | 
					
						
							| 
									
										
										
										
											2023-05-10 07:52:45 +03:00
										 |  |  |                 if alwayson_script.alwayson is False: | 
					
						
							|  |  |  |                     raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params") | 
					
						
							| 
									
										
										
										
											2023-03-11 12:21:33 -05:00
										 |  |  |                 # always on script with no arg should always run so you don't really need to add them to the requests | 
					
						
							|  |  |  |                 if "args" in request.alwayson_scripts[alwayson_script_name]: | 
					
						
							| 
									
										
										
										
											2023-03-28 23:52:51 -04:00
										 |  |  |                     # min between arg length in scriptrunner and arg length in the request | 
					
						
							|  |  |  |                     for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))): | 
					
						
							|  |  |  |                         script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         return script_args | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         script_runner = scripts.scripts_txt2img | 
					
						
							|  |  |  |         if not script_runner.scripts: | 
					
						
							|  |  |  |             script_runner.initialize_scripts(False) | 
					
						
							|  |  |  |             ui.create_ui() | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         if not self.default_script_arg_txt2img: | 
					
						
							|  |  |  |             self.default_script_arg_txt2img = self.init_default_script_args(script_runner) | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-11 13:22:59 +03:00
										 |  |  |         populate = txt2imgreq.copy(update={  # Override __init__ params | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |             "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), | 
					
						
							| 
									
										
										
										
											2023-03-11 13:22:59 +03:00
										 |  |  |             "do_not_save_samples": not txt2imgreq.save_images, | 
					
						
							|  |  |  |             "do_not_save_grid": not txt2imgreq.save_images, | 
					
						
							|  |  |  |         }) | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         if populate.sampler_name: | 
					
						
							|  |  |  |             populate.sampler_index = None  # prevent a warning later on | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         args = vars(populate) | 
					
						
							|  |  |  |         args.pop('script_name', None) | 
					
						
							|  |  |  |         args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them | 
					
						
							| 
									
										
										
										
											2023-03-11 12:21:33 -05:00
										 |  |  |         args.pop('alwayson_scripts', None) | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) | 
					
						
							| 
									
										
										
										
											2023-01-07 14:21:31 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-11 13:22:59 +03:00
										 |  |  |         send_images = args.pop('send_images', True) | 
					
						
							|  |  |  |         args.pop('save_images', None) | 
					
						
							| 
									
										
										
										
											2023-03-03 09:00:52 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-18 06:51:53 +00:00
										 |  |  |         with self.queue_lock: | 
					
						
							| 
									
										
										
										
											2023-07-03 20:17:47 +03:00
										 |  |  |             with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: | 
					
						
							| 
									
										
										
										
											2023-08-14 10:43:18 +03:00
										 |  |  |                 p.is_api = True | 
					
						
							| 
									
										
										
										
											2023-07-03 20:02:30 +03:00
										 |  |  |                 p.scripts = script_runner | 
					
						
							|  |  |  |                 p.outpath_grids = opts.outdir_txt2img_grids | 
					
						
							|  |  |  |                 p.outpath_samples = opts.outdir_txt2img_samples | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-22 07:03:21 +03:00
										 |  |  |                 try: | 
					
						
							|  |  |  |                     shared.state.begin(job="scripts_txt2img") | 
					
						
							|  |  |  |                     if selectable_scripts is not None: | 
					
						
							|  |  |  |                         p.script_args = script_args | 
					
						
							|  |  |  |                         processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         p.script_args = tuple(script_args) # Need to pass args as tuple here | 
					
						
							|  |  |  |                         processed = process_images(p) | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     shared.state.end() | 
					
						
							| 
									
										
										
										
											2023-07-28 11:40:10 +08:00
										 |  |  |                     shared.total_tqdm.clear() | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-03 09:00:52 -05:00
										 |  |  |         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  |         init_images = img2imgreq.init_images | 
					
						
							|  |  |  |         if init_images is None: | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  |             raise HTTPException(status_code=404, detail="Init image not found") | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 15:42:00 -04:00
										 |  |  |         mask = img2imgreq.mask | 
					
						
							|  |  |  |         if mask: | 
					
						
							| 
									
										
										
										
											2022-11-24 13:10:40 +08:00
										 |  |  |             mask = decode_base64_to_image(mask) | 
					
						
							| 
									
										
										
										
											2022-10-22 15:42:00 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  |         script_runner = scripts.scripts_img2img | 
					
						
							|  |  |  |         if not script_runner.scripts: | 
					
						
							|  |  |  |             script_runner.initialize_scripts(True) | 
					
						
							|  |  |  |             ui.create_ui() | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         if not self.default_script_arg_img2img: | 
					
						
							|  |  |  |             self.default_script_arg_img2img = self.init_default_script_args(script_runner) | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-11 14:34:56 -05:00
										 |  |  |         populate = img2imgreq.copy(update={  # Override __init__ params | 
					
						
							| 
									
										
										
										
											2022-11-27 21:12:37 +08:00
										 |  |  |             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), | 
					
						
							| 
									
										
										
										
											2023-03-11 13:22:59 +03:00
										 |  |  |             "do_not_save_samples": not img2imgreq.save_images, | 
					
						
							|  |  |  |             "do_not_save_grid": not img2imgreq.save_images, | 
					
						
							|  |  |  |             "mask": mask, | 
					
						
							|  |  |  |         }) | 
					
						
							| 
									
										
										
										
											2022-11-27 21:19:47 +08:00
										 |  |  |         if populate.sampler_name: | 
					
						
							|  |  |  |             populate.sampler_index = None  # prevent a warning later on | 
					
						
							| 
									
										
										
										
											2022-12-03 09:15:24 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         args = vars(populate) | 
					
						
							|  |  |  |         args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. | 
					
						
							| 
									
										
										
										
											2023-01-05 21:21:48 +00:00
										 |  |  |         args.pop('script_name', None) | 
					
						
							| 
									
										
										
										
											2023-02-27 23:27:33 -05:00
										 |  |  |         args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them | 
					
						
							| 
									
										
										
										
											2023-03-11 12:21:33 -05:00
										 |  |  |         args.pop('alwayson_scripts', None) | 
					
						
							| 
									
										
										
										
											2023-02-26 19:17:58 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-25 14:16:35 -04:00
										 |  |  |         script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) | 
					
						
							| 
									
										
										
										
											2022-10-30 09:10:22 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-11 13:22:59 +03:00
										 |  |  |         send_images = args.pop('send_images', True) | 
					
						
							|  |  |  |         args.pop('save_images', None) | 
					
						
							| 
									
										
										
										
											2023-03-03 09:00:52 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  |         with self.queue_lock: | 
					
						
							| 
									
										
										
										
											2023-07-03 20:17:47 +03:00
										 |  |  |             with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: | 
					
						
							| 
									
										
										
										
											2023-07-03 20:02:30 +03:00
										 |  |  |                 p.init_images = [decode_base64_to_image(x) for x in init_images] | 
					
						
							| 
									
										
										
										
											2023-08-14 10:43:18 +03:00
										 |  |  |                 p.is_api = True | 
					
						
							| 
									
										
										
										
											2023-07-03 20:02:30 +03:00
										 |  |  |                 p.scripts = script_runner | 
					
						
							|  |  |  |                 p.outpath_grids = opts.outdir_img2img_grids | 
					
						
							|  |  |  |                 p.outpath_samples = opts.outdir_img2img_samples | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-22 07:03:21 +03:00
										 |  |  |                 try: | 
					
						
							|  |  |  |                     shared.state.begin(job="scripts_img2img") | 
					
						
							|  |  |  |                     if selectable_scripts is not None: | 
					
						
							|  |  |  |                         p.script_args = script_args | 
					
						
							|  |  |  |                         processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         p.script_args = tuple(script_args) # Need to pass args as tuple here | 
					
						
							|  |  |  |                         processed = process_images(p) | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     shared.state.end() | 
					
						
							| 
									
										
										
										
											2023-07-28 11:40:10 +08:00
										 |  |  |                     shared.total_tqdm.clear() | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-03 09:00:52 -05:00
										 |  |  |         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] | 
					
						
							| 
									
										
										
										
											2022-10-21 19:27:40 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-03 09:15:24 +03:00
										 |  |  |         if not img2imgreq.include_init_images: | 
					
						
							| 
									
										
										
										
											2022-10-24 11:16:07 -04:00
										 |  |  |             img2imgreq.init_images = None | 
					
						
							|  |  |  |             img2imgreq.mask = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def extras_single_image_api(self, req: models.ExtrasSingleImageRequest): | 
					
						
							| 
									
										
										
										
											2022-10-23 16:01:16 -03:00
										 |  |  |         reqDict = setUpscalers(req) | 
					
						
							| 
									
										
										
										
											2022-10-22 23:13:32 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-23 16:01:16 -03:00
										 |  |  |         reqDict['image'] = decode_base64_to_image(reqDict['image']) | 
					
						
							| 
									
										
										
										
											2022-10-22 23:13:32 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with self.queue_lock: | 
					
						
							| 
									
										
										
										
											2023-01-23 09:24:43 +03:00
										 |  |  |             result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) | 
					
						
							| 
									
										
										
										
											2022-10-22 23:13:32 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) | 
					
						
							| 
									
										
										
										
											2022-10-23 13:07:59 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest): | 
					
						
							| 
									
										
										
										
											2022-10-23 16:01:16 -03:00
										 |  |  |         reqDict = setUpscalers(req) | 
					
						
							| 
									
										
										
										
											2022-10-23 13:07:59 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-29 09:17:35 +03:00
										 |  |  |         image_list = reqDict.pop('imageList', []) | 
					
						
							|  |  |  |         image_folder = [decode_base64_to_image(x.data) for x in image_list] | 
					
						
							| 
									
										
										
										
											2022-10-23 13:07:59 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with self.queue_lock: | 
					
						
							| 
									
										
										
										
											2023-04-29 09:17:35 +03:00
										 |  |  |             result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict) | 
					
						
							| 
									
										
										
										
											2022-10-23 13:07:59 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def pnginfoapi(self, req: models.PNGInfoRequest): | 
					
						
							| 
									
										
										
										
											2023-01-04 20:36:30 +00:00
										 |  |  |         image = decode_base64_to_image(req.image.strip()) | 
					
						
							|  |  |  |         if image is None: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.PNGInfoResponse(info="") | 
					
						
							| 
									
										
										
										
											2023-01-04 20:36:30 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         geninfo, items = images.read_info_from_image(image) | 
					
						
							|  |  |  |         if geninfo is None: | 
					
						
							|  |  |  |             geninfo = "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-26 06:52:18 +03:00
										 |  |  |         params = generation_parameters_copypaste.parse_generation_parameters(geninfo) | 
					
						
							|  |  |  |         script_callbacks.infotext_pasted_callback(geninfo, params) | 
					
						
							| 
									
										
										
										
											2022-10-29 16:09:19 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-26 06:52:18 +03:00
										 |  |  |         return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def progressapi(self, req: models.ProgressRequest = Depends()): | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  |         # copy from check_progress_call of ui.py | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if shared.state.job_count == 0: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # avoid dividing zero | 
					
						
							|  |  |  |         progress = 0.01 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if shared.state.job_count > 0: | 
					
						
							|  |  |  |             progress += shared.state.job_no / shared.state.job_count | 
					
						
							|  |  |  |         if shared.state.sampling_steps > 0: | 
					
						
							|  |  |  |             progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         time_since_start = time.time() - shared.state.time_start | 
					
						
							|  |  |  |         eta = (time_since_start/progress) | 
					
						
							|  |  |  |         eta_relative = eta-time_since_start | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         progress = min(progress, 1) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-02 12:12:32 +03:00
										 |  |  |         shared.state.set_current_image() | 
					
						
							| 
									
										
										
										
											2022-10-30 17:02:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-30 05:19:17 +08:00
										 |  |  |         current_image = None | 
					
						
							| 
									
										
										
										
											2022-10-30 06:03:32 +08:00
										 |  |  |         if shared.state.current_image and not req.skip_current_image: | 
					
						
							| 
									
										
										
										
											2022-10-30 05:19:17 +08:00
										 |  |  |             current_image = encode_pil_to_base64(shared.state.current_image) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) | 
					
						
							| 
									
										
										
										
											2022-10-26 22:33:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |     def interrogateapi(self, interrogatereq: models.InterrogateRequest): | 
					
						
							| 
									
										
										
										
											2022-10-27 15:20:15 -04:00
										 |  |  |         image_b64 = interrogatereq.image | 
					
						
							|  |  |  |         if image_b64 is None: | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |             raise HTTPException(status_code=404, detail="Image not found") | 
					
						
							| 
									
										
										
										
											2022-10-27 15:20:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-07 02:32:06 +08:00
										 |  |  |         img = decode_base64_to_image(image_b64) | 
					
						
							|  |  |  |         img = img.convert('RGB') | 
					
						
							| 
									
										
										
										
											2022-10-27 15:20:15 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Override object param | 
					
						
							|  |  |  |         with self.queue_lock: | 
					
						
							| 
									
										
										
										
											2022-11-07 02:32:06 +08:00
										 |  |  |             if interrogatereq.model == "clip": | 
					
						
							|  |  |  |                 processed = shared.interrogator.interrogate(img) | 
					
						
							|  |  |  |             elif interrogatereq.model == "deepdanbooru": | 
					
						
							| 
									
										
										
										
											2022-11-20 16:39:20 +03:00
										 |  |  |                 processed = deepbooru.model.tag(img) | 
					
						
							| 
									
										
										
										
											2022-11-07 02:32:06 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 raise HTTPException(status_code=404, detail="Model not found") | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |         return models.InterrogateResponse(caption=processed) | 
					
						
							| 
									
										
										
										
											2022-10-17 06:58:42 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-30 18:08:40 +08:00
										 |  |  |     def interruptapi(self): | 
					
						
							|  |  |  |         shared.state.interrupt() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return {} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-09 07:56:19 +03:00
										 |  |  |     def unloadapi(self): | 
					
						
							| 
									
										
										
										
											2023-10-15 09:41:02 +03:00
										 |  |  |         sd_models.unload_model_weights() | 
					
						
							| 
									
										
										
										
											2023-03-09 07:56:19 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reloadapi(self): | 
					
						
							| 
									
										
										
										
											2023-10-15 09:41:02 +03:00
										 |  |  |         sd_models.send_model_to_device(shared.sd_model) | 
					
						
							| 
									
										
										
										
											2023-03-09 07:56:19 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return {} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-05 19:05:15 -03:00
										 |  |  |     def skip(self): | 
					
						
							|  |  |  |         shared.state.skip() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  |     def get_config(self): | 
					
						
							|  |  |  |         options = {} | 
					
						
							|  |  |  |         for key in shared.opts.data.keys(): | 
					
						
							|  |  |  |             metadata = shared.opts.data_labels.get(key) | 
					
						
							|  |  |  |             if(metadata is not None): | 
					
						
							|  |  |  |                 options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)}) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 options.update({key: shared.opts.data.get(key, None)}) | 
					
						
							| 
									
										
										
										
											2022-11-05 01:43:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  |         return options | 
					
						
							| 
									
										
										
										
											2022-11-05 01:43:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 01:58:19 -06:00
										 |  |  |     def set_config(self, req: dict[str, Any]): | 
					
						
							| 
									
										
										
										
											2023-06-27 09:26:18 +03:00
										 |  |  |         checkpoint_name = req.get("sd_model_checkpoint", None) | 
					
						
							| 
									
										
										
										
											2023-10-15 09:41:02 +03:00
										 |  |  |         if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: | 
					
						
							| 
									
										
										
										
											2023-06-27 09:26:18 +03:00
										 |  |  |             raise RuntimeError(f"model {checkpoint_name!r} not found") | 
					
						
							| 
									
										
										
										
											2023-06-12 15:22:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 15:15:24 +03:00
										 |  |  |         for k, v in req.items(): | 
					
						
							| 
									
										
										
										
											2023-08-21 07:59:57 +03:00
										 |  |  |             shared.opts.set(k, v, is_api=True) | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         shared.opts.save(shared.config_filename) | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_cmd_flags(self): | 
					
						
							|  |  |  |         return vars(shared.cmd_opts) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_samplers(self): | 
					
						
							| 
									
										
										
										
											2022-11-19 15:15:24 +03:00
										 |  |  |         return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_upscalers(self): | 
					
						
							| 
									
										
										
										
											2023-01-24 10:05:45 +03:00
										 |  |  |         return [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "name": upscaler.name, | 
					
						
							|  |  |  |                 "model_name": upscaler.scaler.model_name, | 
					
						
							|  |  |  |                 "model_path": upscaler.data_path, | 
					
						
							| 
									
										
										
										
											2023-01-24 10:09:30 +03:00
										 |  |  |                 "model_url": None, | 
					
						
							| 
									
										
										
										
											2023-01-24 10:05:45 +03:00
										 |  |  |                 "scale": upscaler.scale, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for upscaler in shared.sd_upscalers | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2022-11-05 01:43:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-04 16:59:23 +03:00
										 |  |  |     def get_latent_upscale_modes(self): | 
					
						
							|  |  |  |         return [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "name": upscale_mode, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for upscale_mode in [*(shared.latent_upscale_modes or {})] | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  |     def get_sd_models(self): | 
					
						
							| 
									
										
										
										
											2023-08-17 20:48:17 -05:00
										 |  |  |         import modules.sd_models as sd_models | 
					
						
							|  |  |  |         return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-29 22:25:43 +01:00
										 |  |  |     def get_sd_vaes(self): | 
					
						
							| 
									
										
										
										
											2023-08-17 20:48:17 -05:00
										 |  |  |         import modules.sd_vae as sd_vae | 
					
						
							|  |  |  |         return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()] | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_hypernetworks(self): | 
					
						
							|  |  |  |         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_face_restorers(self): | 
					
						
							|  |  |  |         return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_realesrgan_models(self): | 
					
						
							|  |  |  |         return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)] | 
					
						
							| 
									
										
										
										
											2022-11-05 01:43:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |     def get_prompt_styles(self): | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  |         styleList = [] | 
					
						
							|  |  |  |         for k in shared.prompt_styles.styles: | 
					
						
							| 
									
										
										
										
											2022-11-05 01:43:02 +08:00
										 |  |  |             style = shared.prompt_styles.styles[k] | 
					
						
							| 
									
										
										
										
											2022-11-22 14:02:59 +00:00
										 |  |  |             styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]}) | 
					
						
							| 
									
										
										
										
											2022-11-03 00:51:22 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return styleList | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-02 10:17:33 +11:00
										 |  |  |     def get_embeddings(self): | 
					
						
							|  |  |  |         db = sd_hijack.model_hijack.embedding_db | 
					
						
							| 
									
										
										
										
											2023-01-02 12:21:22 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  |         def convert_embedding(embedding): | 
					
						
							|  |  |  |             return { | 
					
						
							|  |  |  |                 "step": embedding.step, | 
					
						
							|  |  |  |                 "sd_checkpoint": embedding.sd_checkpoint, | 
					
						
							|  |  |  |                 "sd_checkpoint_name": embedding.sd_checkpoint_name, | 
					
						
							|  |  |  |                 "shape": embedding.shape, | 
					
						
							|  |  |  |                 "vectors": embedding.vectors, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def convert_embeddings(embeddings): | 
					
						
							|  |  |  |             return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-02 10:17:33 +11:00
										 |  |  |         return { | 
					
						
							| 
									
										
										
										
											2023-01-02 12:21:22 +11:00
										 |  |  |             "loaded": convert_embeddings(db.word_embeddings), | 
					
						
							|  |  |  |             "skipped": convert_embeddings(db.skipped_embeddings), | 
					
						
							| 
									
										
										
										
											2023-01-02 10:17:33 +11:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-11 19:16:44 +00:00
										 |  |  |     def refresh_checkpoints(self): | 
					
						
							| 
									
										
										
										
											2023-07-10 23:10:14 +09:00
										 |  |  |         with self.queue_lock: | 
					
						
							|  |  |  |             shared.refresh_checkpoints() | 
					
						
							| 
									
										
										
										
											2022-10-30 18:08:40 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-24 19:45:08 +08:00
										 |  |  |     def refresh_vae(self): | 
					
						
							|  |  |  |         with self.queue_lock: | 
					
						
							|  |  |  |             shared_items.refresh_vae_list() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |     def create_embedding(self, args: dict): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:31 +03:00
										 |  |  |             shared.state.begin(job="create_embedding") | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             filename = create_embedding(**args) # create empty embedding | 
					
						
							|  |  |  |             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.CreateResponse(info=f"create embedding filename: {filename}") | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |         except AssertionError as e: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.TrainResponse(info=f"create embedding error: {e}") | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:49 +03:00
										 |  |  |         finally: | 
					
						
							|  |  |  |             shared.state.end() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def create_hypernetwork(self, args: dict): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:31 +03:00
										 |  |  |             shared.state.begin(job="create_hypernetwork") | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             filename = create_hypernetwork(**args) # create empty embedding | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.CreateResponse(info=f"create hypernetwork filename: {filename}") | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |         except AssertionError as e: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.TrainResponse(info=f"create hypernetwork error: {e}") | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:49 +03:00
										 |  |  |         finally: | 
					
						
							|  |  |  |             shared.state.end() | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def train_embedding(self, args: dict): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:31 +03:00
										 |  |  |             shared.state.begin(job="train_embedding") | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             apply_optimizations = shared.opts.training_xattention_optimizations | 
					
						
							|  |  |  |             error = None | 
					
						
							|  |  |  |             filename = '' | 
					
						
							|  |  |  |             if not apply_optimizations: | 
					
						
							|  |  |  |                 sd_hijack.undo_optimizations() | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 embedding, filename = train_embedding(**args) # can take a long time to complete | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 error = e | 
					
						
							|  |  |  |             finally: | 
					
						
							|  |  |  |                 if not apply_optimizations: | 
					
						
							|  |  |  |                     sd_hijack.apply_optimizations() | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:49 +03:00
										 |  |  |         except Exception as msg: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.TrainResponse(info=f"train embedding error: {msg}") | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:49 +03:00
										 |  |  |         finally: | 
					
						
							|  |  |  |             shared.state.end() | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def train_hypernetwork(self, args: dict): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:31 +03:00
										 |  |  |             shared.state.begin(job="train_hypernetwork") | 
					
						
							| 
									
										
										
										
											2023-01-21 08:36:07 +03:00
										 |  |  |             shared.loaded_hypernetworks = [] | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             apply_optimizations = shared.opts.training_xattention_optimizations | 
					
						
							|  |  |  |             error = None | 
					
						
							|  |  |  |             filename = '' | 
					
						
							|  |  |  |             if not apply_optimizations: | 
					
						
							|  |  |  |                 sd_hijack.undo_optimizations() | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2023-02-10 17:58:35 +09:00
										 |  |  |                 hypernetwork, filename = train_hypernetwork(**args) | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 error = e | 
					
						
							|  |  |  |             finally: | 
					
						
							|  |  |  |                 shared.sd_model.cond_stage_model.to(devices.device) | 
					
						
							|  |  |  |                 shared.sd_model.first_stage_model.to(devices.device) | 
					
						
							|  |  |  |                 if not apply_optimizations: | 
					
						
							|  |  |  |                     sd_hijack.apply_optimizations() | 
					
						
							|  |  |  |                 shared.state.end() | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") | 
					
						
							| 
									
										
										
										
											2023-06-30 13:11:49 +03:00
										 |  |  |         except Exception as exc: | 
					
						
							|  |  |  |             return models.TrainResponse(info=f"train embedding error: {exc}") | 
					
						
							|  |  |  |         finally: | 
					
						
							| 
									
										
										
										
											2022-12-24 18:02:22 -05:00
										 |  |  |             shared.state.end() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |     def get_memory(self): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             import os | 
					
						
							|  |  |  |             import psutil | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |             process = psutil.Process(os.getpid()) | 
					
						
							| 
									
										
										
										
											2023-01-09 16:54:12 -05:00
										 |  |  |             res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values | 
					
						
							|  |  |  |             ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe | 
					
						
							|  |  |  |             ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total } | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |         except Exception as err: | 
					
						
							|  |  |  |             ram = { 'error': f'{err}' } | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             import torch | 
					
						
							|  |  |  |             if torch.cuda.is_available(): | 
					
						
							|  |  |  |                 s = torch.cuda.mem_get_info() | 
					
						
							| 
									
										
										
										
											2023-01-09 16:54:12 -05:00
										 |  |  |                 system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] } | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |                 s = dict(torch.cuda.memory_stats(shared.device)) | 
					
						
							| 
									
										
										
										
											2023-01-09 16:54:12 -05:00
										 |  |  |                 allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] } | 
					
						
							|  |  |  |                 reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] } | 
					
						
							|  |  |  |                 active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] } | 
					
						
							|  |  |  |                 inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] } | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |                 warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } | 
					
						
							|  |  |  |                 cuda = { | 
					
						
							|  |  |  |                     'system': system, | 
					
						
							|  |  |  |                     'active': active, | 
					
						
							|  |  |  |                     'allocated': allocated, | 
					
						
							|  |  |  |                     'reserved': reserved, | 
					
						
							|  |  |  |                     'inactive': inactive, | 
					
						
							|  |  |  |                     'events': warnings, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |                 cuda = {'error': 'unavailable'} | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  |         except Exception as err: | 
					
						
							| 
									
										
										
										
											2023-05-10 08:25:25 +03:00
										 |  |  |             cuda = {'error': f'{err}'} | 
					
						
							|  |  |  |         return models.MemoryResponse(ram=ram, cuda=cuda) | 
					
						
							| 
									
										
										
										
											2023-08-25 22:23:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 22:15:35 +08:00
										 |  |  |     def get_extensions_list(self): | 
					
						
							|  |  |  |         from modules import extensions | 
					
						
							|  |  |  |         extensions.list_extensions() | 
					
						
							|  |  |  |         ext_list = [] | 
					
						
							|  |  |  |         for ext in extensions.extensions: | 
					
						
							|  |  |  |             ext: extensions.Extension | 
					
						
							|  |  |  |             ext.read_info_from_repo() | 
					
						
							|  |  |  |             if ext.remote is not None: | 
					
						
							|  |  |  |                 ext_list.append({ | 
					
						
							|  |  |  |                     "name": ext.name, | 
					
						
							|  |  |  |                     "remote": ext.remote, | 
					
						
							|  |  |  |                     "branch": ext.branch, | 
					
						
							|  |  |  |                     "commit_hash":ext.commit_hash, | 
					
						
							|  |  |  |                     "commit_date":ext.commit_date, | 
					
						
							|  |  |  |                     "version":ext.version, | 
					
						
							|  |  |  |                     "enabled":ext.enabled | 
					
						
							|  |  |  |                 }) | 
					
						
							|  |  |  |         return ext_list | 
					
						
							| 
									
										
										
										
											2023-01-07 07:51:35 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 15:19:10 +03:00
										 |  |  |     def launch(self, server_name, port, root_path): | 
					
						
							| 
									
										
										
										
											2022-10-18 06:51:53 +00:00
										 |  |  |         self.app.include_router(self.router) | 
					
						
							| 
									
										
										
										
											2023-07-25 15:19:10 +03:00
										 |  |  |         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) | 
					
						
							| 
									
										
										
										
											2023-06-10 23:36:34 +09:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-14 18:51:47 +09:00
										 |  |  |     def kill_webui(self): | 
					
						
							| 
									
										
										
										
											2023-06-10 23:36:34 +09:00
										 |  |  |         restart.stop_program() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def restart_webui(self): | 
					
						
							|  |  |  |         if restart.is_restartable(): | 
					
						
							|  |  |  |             restart.restart_program() | 
					
						
							| 
									
										
										
										
											2023-06-14 19:52:12 +09:00
										 |  |  |         return Response(status_code=501) | 
					
						
							| 
									
										
										
										
											2023-06-12 18:15:27 +09:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-14 19:53:08 +09:00
										 |  |  |     def stop_webui(request): | 
					
						
							| 
									
										
										
										
											2023-06-12 18:15:27 +09:00
										 |  |  |         shared.state.server_command = "stop" | 
					
						
							|  |  |  |         return Response("Stopping.") | 
					
						
							| 
									
										
										
										
											2023-07-13 15:21:39 +03:00
										 |  |  | 
 |