mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 12:03:36 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			678 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			678 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import annotations
 | 
						|
import math
 | 
						|
import psutil
 | 
						|
import platform
 | 
						|
 | 
						|
import torch
 | 
						|
from torch import einsum
 | 
						|
 | 
						|
from ldm.util import default
 | 
						|
from einops import rearrange
 | 
						|
 | 
						|
from modules import shared, errors, devices, sub_quadratic_attention
 | 
						|
from modules.hypernetworks import hypernetwork
 | 
						|
 | 
						|
import ldm.modules.attention
 | 
						|
import ldm.modules.diffusionmodules.model
 | 
						|
 | 
						|
import sgm.modules.attention
 | 
						|
import sgm.modules.diffusionmodules.model
 | 
						|
 | 
						|
# Reference to ldm's AttnBlock.forward as it is at import time, so that
# SdOptimization.undo() can put it back after an optimization replaced it.
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 | 
						|
# Reference to sgm's AttnBlock.forward as it is at import time, so that
# SdOptimization.undo() can put it back after an optimization replaced it.
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimization:
    """Base class describing one selectable attention optimization.

    Subclasses override the class attributes below and ``apply`` to install
    their own ``forward`` functions on the ldm/sgm attention classes.
    ``undo`` restores the default forwards.
    """

    # Identifier of the optimization; first part of the string built by title().
    name: str = None
    # Optional longer description; appended to the name by title() when set.
    label: str | None = None
    # Presumably the name of the matching command-line option - confirm against callers.
    cmd_opt: str | None = None
    # Ranking value; how it is consumed is decided by code outside this class.
    priority: int = 0

    def title(self):
        """Return ``name`` alone, or ``"name - label"`` when a label is set."""
        return self.name if self.label is None else f"{self.name} - {self.label}"

    def is_available(self):
        """Return whether this optimization can be used; the base class always says yes."""
        return True

    def apply(self):
        """Install the optimization; the base class installs nothing."""
        pass

    def undo(self):
        """Restore the default CrossAttention and AttnBlock forwards for ldm and sgm."""
        default_attnblock_forwards = (
            (ldm.modules, diffusionmodules_model_AttnBlock_forward),
            (sgm.modules, sgm_diffusionmodules_model_AttnBlock_forward),
        )
        for modules, attnblock_forward in default_attnblock_forwards:
            modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
            modules.diffusionmodules.model.AttnBlock.forward = attnblock_forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationXformers(SdOptimization):
    """Attention optimization that routes attention through xformers."""

    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        """Report whether the xformers forwards may be used.

        Truthy when xformers is force-enabled on the command line, or when
        ``shared.xformers_available`` is set, CUDA is available and the
        compute capability of ``shared.device`` lies between (6, 0) and
        (9, 0) inclusive.
        """
        forced = shared.cmd_opts.force_enable_xformers
        if forced:
            return forced

        return (
            shared.xformers_available
            and torch.cuda.is_available()
            and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)
        )

    def apply(self):
        """Install the xformers forwards on the ldm and sgm attention classes."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = xformers_attention_forward
            modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationSdpNoMem(SdOptimization):
    """Scaled-dot-product attention using the ``*_no_mem_*`` forward variants."""

    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        """Return True when torch exposes a callable ``scaled_dot_product_attention``."""
        sdp_function = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
        return callable(sdp_function)

    def apply(self):
        """Install the no-mem SDP forwards on the ldm and sgm attention classes."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
            modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationSdp(SdOptimizationSdpNoMem):
    """Scaled-dot-product attention; availability check is inherited from the no-mem variant."""

    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        """Install the plain SDP forwards on the ldm and sgm attention classes."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
            modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationSubQuad(SdOptimization):
    """Attention optimization using the sub-quadratic attention forwards."""

    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"

    @property
    def priority(self):
        """Return 1000 when ``shared.device`` is an mps device, otherwise 10."""
        if shared.device.type == 'mps':
            return 1000
        return 10

    def apply(self):
        """Install the sub-quadratic forwards on the ldm and sgm attention classes."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = sub_quad_attention_forward
            modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationV1(SdOptimization):
    """Original v1 split attention; replaces only the CrossAttention forward."""

    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        """Install the v1 split forward on ldm and sgm CrossAttention (AttnBlock is left alone)."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationInvokeAI(SdOptimization):
    """InvokeAI-derived split attention; replaces only the CrossAttention forward."""

    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        """Return 1000 when the device is not mps and CUDA is unavailable, otherwise 10."""
        if shared.device.type == 'mps' or torch.cuda.is_available():
            return 10
        return 1000

    def apply(self):
        """Install the InvokeAI split forward on ldm and sgm CrossAttention (AttnBlock is left alone)."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
 | 
						|
 | 
						|
 | 
						|
class SdOptimizationDoggettx(SdOptimization):
    """Doggettx-derived split attention for both CrossAttention and AttnBlock."""

    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        """Install the Doggettx split forwards on the ldm and sgm attention classes."""
        for modules in (ldm.modules, sgm.modules):
            modules.attention.CrossAttention.forward = split_cross_attention_forward
            modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
 | 
						|
 | 
						|
 | 
						|
def list_optimizers(res):
    """Append one fresh instance of every optimization defined in this module to ``res``.

    ``res`` is extended in place; nothing is returned.
    """
    optimizer_classes = (
        SdOptimizationXformers,
        SdOptimizationSdpNoMem,
        SdOptimizationSdp,
        SdOptimizationSubQuad,
        SdOptimizationV1,
        SdOptimizationInvokeAI,
        SdOptimizationDoggettx,
    )
    # Build the full list first so ``res`` is only touched once every instance exists.
    res.extend([optimizer_class() for optimizer_class in optimizer_classes])
 | 
						|
 | 
						|
 | 
						|
# Import xformers only when one of the xformers command-line flags is set.
# A failed import is reported and leaves shared.xformers_available untouched.
if any((shared.cmd_opts.xformers, shared.cmd_opts.force_enable_xformers)):
    try:
        import xformers.ops
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)
    else:
        shared.xformers_available = True
 | 
						|
 | 
						|
 | 
						|
def get_available_vram():
    """Return an estimate, in bytes, of the memory available for new allocations.

    For a CUDA ``shared.device`` this is the free device memory reported by
    ``torch.cuda.mem_get_info`` plus the bytes the torch allocator has
    reserved but that are not currently active. For any other device type
    the available system memory reported by psutil is returned.
    """
    if shared.device.type != 'cuda':
        return psutil.virtual_memory().available

    allocator_stats = torch.cuda.memory_stats(shared.device)
    active_bytes = allocator_stats['active_bytes.all.current']
    reserved_bytes = allocator_stats['reserved_bytes.all.current']
    # NOTE(review): stats come from shared.device but free memory is queried for
    # the current device; presumably these are the same device - confirm.
    driver_free_bytes, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    reusable_bytes = reserved_bytes - active_bytes
    return driver_free_bytes + reusable_bytes
 | 
						|
 | 
						|
 | 
						|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 | 
						|
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward that works through the (batch*heads) axis two rows at a time.

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``mask`` and ``kwargs`` are
    accepted but not used. Returns ``self.to_out`` applied to the attention
    result, in the dtype the query projection produced.
    """
    heads = self.heads

    query_proj = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key_proj = self.to_k(context_k)
    value_proj = self.to_v(context_v)
    del context, context_k, context_v, x

    def split_heads(t):
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    q = split_heads(query_proj)
    k = split_heads(key_proj)
    v = split_heads(value_proj)
    del query_proj, key_proj, value_proj

    original_dtype = q.dtype
    upcast = shared.opts.upcast_attn
    if upcast:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not upcast):
        attended = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for start in range(0, q.shape[0], 2):
            rows = slice(start, start + 2)
            scores = einsum('b i d, b j d -> b i j', q[rows], k[rows])
            scores *= self.scale

            weights = scores.softmax(dim=-1)
            # Drop the raw scores before the next einsum to keep peak memory down.
            del scores

            attended[rows] = einsum('b i j, b j d -> b i d', weights, v[rows])
            del weights
        del q, k, v

    attended = attended.to(original_dtype)

    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    del attended

    return self.to_out(merged)
 | 
						|
 | 
						|
 | 
						|
# taken from https://github.com/Doggettx/stable-diffusion and modified
 | 
						|
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward that slices the query tokens to fit the free memory.

    The size of the full attention matrix is compared with
    ``get_available_vram()``. When it does not fit, the query token axis is
    processed in a power-of-two number of slices (at most 64).

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``mask`` and ``kwargs`` are
    accepted but not used. Returns ``self.to_out`` applied to the attention
    result, in the dtype the query projection produced.

    Raises:
        RuntimeError: when more than 64 slices would be required, including
            the case where no free memory is reported at all.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # On mps the value tensor is deliberately left in its original dtype.
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # Fold the softmax scale into k once instead of scaling every score slice.
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            if mem_free_total > 0:
                steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            else:
                # No free memory reported: slicing cannot help, so force the
                # out-of-memory error below instead of dividing by zero.
                steps = math.inf

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(max(mem_free_total, 0) / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # Never let the slice size reach 0 (fewer query tokens than steps):
        # range() rejects a zero step.
        slice_size = max(q.shape[1] // steps, 1)
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
 | 
						|
 | 
						|
 | 
						|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
 | 
						|
# Total system memory reported by psutil, floor-divided by 2**30 (whole GiB);
# evaluated once when this module is imported.
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 | 
						|
 | 
						|
 | 
						|
def einsum_op_compvis(q, k, v):
    """Unsliced attention: ``softmax(q @ k^T)`` over the key axis, applied to ``v``.

    No scaling is applied here. Inputs are 3-D tensors laid out as
    (batch, tokens, dim), as fixed by the einsum patterns.
    """
    attn = einsum('b i d, b j d -> b i j', q, k)
    # Rebind the same name so the raw scores are released as soon as the
    # weights exist.
    attn = torch.softmax(attn, dim=-1, dtype=attn.dtype)
    return einsum('b i j, b j d -> b i d', attn, v)
 | 
						|
 | 
						|
 | 
						|
def einsum_op_slice_0(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` along axis 0 of q, k and v."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[0], slice_size):
        rows = slice(start, start + slice_size)
        out[rows] = einsum_op_compvis(q[rows], k[rows], v[rows])
    return out
 | 
						|
 | 
						|
 | 
						|
def einsum_op_slice_1(q, k, v, slice_size):
    """Attention computed in chunks of ``slice_size`` along axis 1 of q; k and v are used whole."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[1], slice_size):
        tokens = slice(start, start + slice_size)
        out[:, tokens] = einsum_op_compvis(q[:, tokens], k, v)
    return out
 | 
						|
 | 
						|
 | 
						|
def einsum_op_mps_v1(q, k, v):
    """First mps strategy: unsliced when q is small, otherwise sliced along q's axis 1."""
    q_rows = q.shape[0] * q.shape[1]
    if q_rows <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)

    slice_size = math.floor(2**30 / q_rows)
    # A slice size that is an exact multiple of 4096 is reduced by one.
    if slice_size % 4096 == 0:
        slice_size -= 1
    return einsum_op_slice_1(q, k, v, slice_size)
 | 
						|
 | 
						|
 | 
						|
def einsum_op_mps_v2(q, k, v):
    """Second mps strategy: unsliced only with more than 8 GiB RAM and a small q.

    Otherwise axis 0 of q, k and v is processed one row at a time.
    """
    small_enough = mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16
    if not small_enough:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)
 | 
						|
 | 
						|
 | 
						|
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Run attention so that the score tensor stays near ``max_tensor_mb`` MiB.

    When the full score tensor fits, attention is computed unsliced.
    Otherwise a power-of-two divisor is derived from the size overshoot.
    Axis 0 is sliced when it has at least that many rows; if not, axis 1 of q
    is sliced (with a minimum slice size of 1).
    """
    batch, q_tokens = q.shape[0], q.shape[1]
    size_mb = batch * q_tokens * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb > max_tensor_mb:
        # Smallest power of two strictly greater than int((size_mb - 1) / max_tensor_mb).
        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
        if div > batch:
            return einsum_op_slice_1(q, k, v, max(q_tokens // div, 1))
        return einsum_op_slice_0(q, k, v, batch // div)
    return einsum_op_compvis(q, k, v)
 | 
						|
 | 
						|
 | 
						|
def einsum_op_cuda(q, k, v):
    """CUDA strategy: size the attention slices from the memory free on q's device.

    Free memory is the driver-reported free bytes plus the bytes the torch
    allocator has reserved but that are not active.
    """
    allocator_stats = torch.cuda.memory_stats(q.device)
    active_bytes = allocator_stats['active_bytes.all.current']
    reserved_bytes = allocator_stats['reserved_bytes.all.current']
    driver_free_bytes, _ = torch.cuda.mem_get_info(q.device)
    free_bytes = driver_free_bytes + (reserved_bytes - active_bytes)
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, free_bytes / 3.3 / (1 << 20))
 | 
						|
 | 
						|
 | 
						|
def einsum_op(q, k, v):
    """Dispatch attention to the strategy matching the device type of ``q``."""
    device_type = q.device.type

    if device_type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if device_type != 'mps':
        # Smaller slices are faster due to L2/L3/SLC caches.
        # Tested on i7 with 8MB L3 cache.
        return einsum_op_tensor_mem(q, k, v, 32)

    use_v1 = mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18
    mps_op = einsum_op_mps_v1 if use_v1 else einsum_op_mps_v2
    return mps_op(q, k, v)
 | 
						|
 | 
						|
 | 
						|
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward that delegates the attention product to ``einsum_op``.

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``mask`` and ``kwargs`` are
    accepted but not used. Returns ``self.to_out`` applied to the attention
    result, in the dtype the query projection produced.
    """
    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    original_dtype = q.dtype
    if shared.opts.upcast_attn:
        q = q.float()
        k = k.float()
        # On mps the value tensor keeps its original dtype.
        if v.device.type != 'mps':
            v = v.float()

    def split_heads(t):
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # Scale k once so the attention op itself needs no scaling step.
        k = k * self.scale

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        attended = einsum_op(q, k, v)

    attended = attended.to(original_dtype)
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
 | 
						|
 | 
						|
# -- End of code from https://github.com/invoke-ai/InvokeAI --
 | 
						|
 | 
						|
 | 
						|
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
 | 
						|
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
 | 
						|
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward built on ``sub_quad_attention``.

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``mask`` must be None and
    ``kwargs`` are not used. The chunking parameters are read from
    ``shared.cmd_opts``. Returns the result of the output projection
    followed by dropout.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    heads = self.heads

    def split_heads(t):
        # (batch, tokens, heads * dim) -> (batch * heads, tokens, dim)
        return t.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = split_heads(q)
    k = split_heads(k)
    v = split_heads(v)

    if q.device.type == 'mps':
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()

    original_dtype = q.dtype
    if shared.opts.upcast_attn:
        # Only q and k are upcast; v keeps its dtype.
        q = q.float()
        k = k.float()

    attended = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    attended = attended.to(original_dtype)

    # (batch * heads, tokens, dim) -> (batch, tokens, heads * dim)
    attended = attended.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    return dropout(out_proj(attended))
 | 
						|
 | 
						|
 | 
						|
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Choose chunking limits and call ``efficient_dot_product_attention``.

    Args:
        q, k, v: 3-D tensors; axis 0 of q is batch*heads, axis 1 is tokens.
        q_chunk_size: forwarded as ``query_chunk_size``.
        kv_chunk_size: forwarded unless the whole q.k product fits below the
            threshold, in which case the number of key tokens is used instead.
        kv_chunk_size_min: None derives a minimum from the threshold (when
            there is one); 0 is turned into None; other values pass through.
        chunk_threshold: None picks a default (a fixed byte budget on mps,
            70% of ``get_available_vram()`` elsewhere). 0 means no threshold.
            Any other value is used as a percentage of ``get_available_vram()``.
        use_checkpoint: forwarded unchanged.
    """
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold == 0:
        chunk_threshold_bytes = None
    elif chunk_threshold is not None:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
    elif q.device.type == 'mps':
        bytes_multiplier = 2 if platform.processor() == 'i386' else bytes_per_token
        chunk_threshold_bytes = 268435456 * bytes_multiplier
    else:
        chunk_threshold_bytes = int(get_available_vram() * 0.7)

    has_threshold = chunk_threshold_bytes is not None

    if kv_chunk_size_min is None and has_threshold:
        kv_row_bytes = batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])
        kv_chunk_size_min = chunk_threshold_bytes // kv_row_bytes
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if has_threshold and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    same_dtype = q.dtype == v.dtype
    with devices.without_autocast(disable=same_dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
 | 
						|
 | 
						|
 | 
						|
def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers flash-attention op if enabled and supported for these inputs.

    Returns None when the ``xformers_flash_attention`` command-line option is
    off, when the op's forward part does not support the inputs, or when
    probing raises (the exception is shown once via ``errors.display_once``).
    """
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    selected_op = None
    try:
        candidate_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        forward_op, _backward_op = candidate_op
        probe_inputs = xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
        if forward_op.supports(probe_inputs):
            selected_op = candidate_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return selected_op
 | 
						|
 | 
						|
 | 
						|
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward using ``xformers.ops.memory_efficient_attention``.

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``mask`` and ``kwargs`` are
    accepted but not used. Returns ``self.to_out`` applied to the attention
    result, in the dtype the query projection produced.
    """
    heads = self.heads
    query_proj = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key_proj = self.to_k(context_k)
    value_proj = self.to_v(context_v)

    def split_heads(t):
        # (batch, tokens, heads * dim) -> (batch, tokens, heads, dim)
        return t.reshape(t.shape[0], t.shape[1], heads, -1)

    q = split_heads(query_proj)
    k = split_heads(key_proj)
    v = split_heads(value_proj)

    del query_proj, key_proj, value_proj

    original_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    flash_op = get_xformers_flash_attention_op(q, k, v)
    attended = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=flash_op)

    attended = attended.to(original_dtype)

    batch, tokens, out_heads, head_dim = attended.shape
    merged = attended.reshape(batch, tokens, out_heads * head_dim)
    return self.to_out(merged)
 | 
						|
 | 
						|
 | 
						|
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
 | 
						|
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
 | 
						|
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Cross-attention forward using ``torch.nn.functional.scaled_dot_product_attention``.

    ``context`` defaults to ``x``. Loaded hypernetworks are applied to the
    context before the key/value projections. ``kwargs`` are accepted but not
    used. When ``mask`` is given it is passed through
    ``self.prepare_attention_mask`` and reshaped to one mask per head.
    NOTE(review): that method must exist on the module this is patched onto -
    confirm before relying on masks.

    Returns the output projection followed by dropout (``self.to_out[0]`` and
    ``self.to_out[1]``), in the dtype the query projection produced.
    """
    batch_size, sequence_length, _ = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    # Take the per-head width from the projected query, not from x. The input
    # width and the projection width are separate quantities, and using the
    # wrong one makes the views below silently fold tokens into the wrong shape.
    head_dim = q_in.shape[-1] // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
 | 
						|
 | 
						|
 | 
						|
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Run ``scaled_dot_product_attention_forward`` with the memory-efficient SDP kernel disabled.

    The flash and math kernels stay enabled. ``kwargs`` are accepted but not
    forwarded.
    """
    kernel_selection = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False)
    with kernel_selection:
        hidden_states = scaled_dot_product_attention_forward(self, x, context, mask)
    return hidden_states
 | 
						|
 | 
						|
 | 
						|
def cross_attention_attnblock_forward(self, x):
    """AttnBlock forward that slices the query positions to fit the free memory.

    ``x`` is a 4-D tensor (batch, channels, height, width), as fixed by the
    shape unpacking below. The attention matrix size is compared with
    ``get_available_vram()``. When it does not fit, the flattened query
    positions are processed in a power-of-two number of slices. Returns
    ``self.proj_out`` of the attention result plus ``x`` (residual).
    """
    h_ = self.norm(x)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)
    del h_

    # compute attention
    b, c, h, w = q1.shape

    q = q1.reshape(b, c, h * w).permute(0, 2, 1)  # b,hw,c
    del q1

    k = k1.reshape(b, c, h * w)  # b,c,hw
    del k1

    # The value layout does not change between slices, so reshape it once.
    v1 = v.reshape(b, c, h * w)
    del v

    out = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    # Round the slice size up so slicing stays in effect even when the number
    # of positions is not a multiple of ``steps``. Each query row has its own
    # softmax, so uneven slices give the same result. Slice indexing clamps
    # the last, shorter slice.
    slice_size = max(math.ceil(q.shape[1] / steps), 1)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        out[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4

    del v1

    h2 = out.reshape(b, c, h, w)
    del out

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
 | 
						|
 | 
						|
 | 
						|
def xformers_attnblock_forward(self, x):
    """AttnBlock forward using ``xformers.ops.memory_efficient_attention``.

    ``x`` is a 4-D tensor (batch, channels, height, width), as fixed by the
    rearrange patterns. Returns ``x`` plus the projected attention result.
    If the xformers path raises ``NotImplementedError``, the computation is
    redone with ``cross_attention_attnblock_forward``.
    """
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            # Upcast all three tensors: xformers expects q, k and v to share a
            # dtype, and the other xformers/SDP forwards in this file upcast v too.
            q, k, v = q.float(), k.float(), v.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
 | 
						|
 | 
						|
 | 
						|
def sdp_attnblock_forward(self, x):
    """AttnBlock forward using ``torch.nn.functional.scaled_dot_product_attention``.

    ``x`` is a 4-D tensor (batch, channels, height, width), as fixed by the
    rearrange patterns. Returns ``x`` plus the projected attention result.
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    _batch, _channels, height, _width = q.shape

    def flatten_spatial(t):
        return rearrange(t, 'b c h w -> b (h w) c')

    q, k, v = flatten_spatial(q), flatten_spatial(k), flatten_spatial(v)

    original_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    attended = attended.to(original_dtype)
    attended = rearrange(attended, 'b (h w) c -> b c h w', h=height)
    return x + self.proj_out(attended)
 | 
						|
 | 
						|
 | 
						|
def sdp_no_mem_attnblock_forward(self, x):
    """Run ``sdp_attnblock_forward`` with the memory-efficient SDP kernel disabled.

    The flash and math kernels stay enabled.
    """
    kernel_selection = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False)
    with kernel_selection:
        result = sdp_attnblock_forward(self, x)
    return result
 | 
						|
 | 
						|
 | 
						|
def sub_quad_attnblock_forward(self, x):
    """AttnBlock forward built on ``sub_quad_attention``.

    ``x`` is a 4-D tensor (batch, channels, height, width), as fixed by the
    rearrange patterns. The chunking parameters are read from
    ``shared.cmd_opts``. Returns ``x`` plus the projected attention result.
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    _batch, _channels, height, _width = q.shape

    def flatten_spatial(t):
        return rearrange(t, 'b c h w -> b (h w) c')

    q, k, v = flatten_spatial(q), flatten_spatial(k), flatten_spatial(v)
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    attended = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    attended = rearrange(attended, 'b (h w) c -> b c h w', h=height)
    return x + self.proj_out(attended)
 |