mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 03:55:05 +00:00 
			
		
		
		
	Merge pull request #10274 from akx/torch-cpu-for-tests
Use CPU Torch in CI, etc.
This commit is contained in:
		
						commit
						fb366891ab
					
				
							
								
								
									
										6
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							@ -17,8 +17,14 @@ jobs:
 | 
				
			|||||||
          cache: pip
 | 
					          cache: pip
 | 
				
			||||||
          cache-dependency-path: |
 | 
					          cache-dependency-path: |
 | 
				
			||||||
            **/requirements*txt
 | 
					            **/requirements*txt
 | 
				
			||||||
 | 
					            launch.py
 | 
				
			||||||
      - name: Run tests
 | 
					      - name: Run tests
 | 
				
			||||||
        run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
 | 
					        run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
 | 
				
			||||||
 | 
					        env:
 | 
				
			||||||
 | 
					          PIP_DISABLE_PIP_VERSION_CHECK: "1"
 | 
				
			||||||
 | 
					          PIP_PROGRESS_BAR: "off"
 | 
				
			||||||
 | 
					          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
 | 
				
			||||||
 | 
					          WEBUI_LAUNCH_LIVE_OUTPUT: "1"
 | 
				
			||||||
      - name: Upload main app stdout-stderr
 | 
					      - name: Upload main app stdout-stderr
 | 
				
			||||||
        uses: actions/upload-artifact@v3
 | 
					        uses: actions/upload-artifact@v3
 | 
				
			||||||
        if: always()
 | 
					        if: always()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										54
									
								
								launch.py
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								launch.py
									
									
									
									
									
								
							@ -22,6 +22,9 @@ stored_commit_hash = None
 | 
				
			|||||||
stored_git_tag = None
 | 
					stored_git_tag = None
 | 
				
			||||||
dir_repos = "repositories"
 | 
					dir_repos = "repositories"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Whether to default to printing command output
 | 
				
			||||||
 | 
					default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
 | 
					if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
 | 
				
			||||||
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 | 
					    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,32 +88,36 @@ def git_tag():
 | 
				
			|||||||
    return stored_git_tag
 | 
					    return stored_git_tag
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
 | 
					def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
 | 
				
			||||||
    if desc is not None:
 | 
					    if desc is not None:
 | 
				
			||||||
        print(desc)
 | 
					        print(desc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if live:
 | 
					    run_kwargs = {
 | 
				
			||||||
        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
 | 
					        "args": command,
 | 
				
			||||||
        if result.returncode != 0:
 | 
					        "shell": True,
 | 
				
			||||||
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
 | 
					        "env": os.environ if custom_env is None else custom_env,
 | 
				
			||||||
Command: {command}
 | 
					        "encoding": 'utf8',
 | 
				
			||||||
Error code: {result.returncode}""")
 | 
					        "errors": 'ignore',
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ""
 | 
					    if not live:
 | 
				
			||||||
 | 
					        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
 | 
					    result = subprocess.run(**run_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if result.returncode != 0:
 | 
					    if result.returncode != 0:
 | 
				
			||||||
 | 
					        error_bits = [
 | 
				
			||||||
 | 
					            f"{errdesc or 'Error running command'}.",
 | 
				
			||||||
 | 
					            f"Command: {command}",
 | 
				
			||||||
 | 
					            f"Error code: {result.returncode}",
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        if result.stdout:
 | 
				
			||||||
 | 
					            error_bits.append(f"stdout: {result.stdout}")
 | 
				
			||||||
 | 
					        if result.stderr:
 | 
				
			||||||
 | 
					            error_bits.append(f"stderr: {result.stderr}")
 | 
				
			||||||
 | 
					        raise RuntimeError("\n".join(error_bits))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        message = f"""{errdesc or 'Error running command'}.
 | 
					    return (result.stdout or "")
 | 
				
			||||||
Command: {command}
 | 
					 | 
				
			||||||
Error code: {result.returncode}
 | 
					 | 
				
			||||||
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
 | 
					 | 
				
			||||||
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
        raise RuntimeError(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return result.stdout.decode(encoding="utf8", errors="ignore")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def check_run(command):
 | 
					def check_run(command):
 | 
				
			||||||
@ -135,7 +142,7 @@ def run_python(code, desc=None, errdesc=None):
 | 
				
			|||||||
    return run(f'"{python}" -c "{code}"', desc, errdesc)
 | 
					    return run(f'"{python}" -c "{code}"', desc, errdesc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_pip(command, desc=None, live=False):
 | 
					def run_pip(command, desc=None, live=default_command_live):
 | 
				
			||||||
    if args.skip_install:
 | 
					    if args.skip_install:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -237,13 +244,14 @@ def run_extensions_installers(settings_file):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def prepare_environment():
 | 
					def prepare_environment():
 | 
				
			||||||
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
 | 
					    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
 | 
				
			||||||
 | 
					    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
 | 
				
			||||||
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 | 
					    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
 | 
					    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
 | 
				
			||||||
    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
 | 
					    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
 | 
				
			||||||
    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
 | 
					    clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
 | 
				
			||||||
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
 | 
					    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
 | 
					    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
 | 
				
			||||||
    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
 | 
					    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user