mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-24 22:41:19 +00:00 
			
		
		
		
	 e3b53fd295
			
		
	
	
		e3b53fd295
		
	
	
	
	
		
			
			Adds an "Upcast cross attention layer to float32" option to the Stable Diffusion settings. This allows images to be generated with SD 2.1 models without --no-half or xFormers. To make upcasting possible for the cross attention layer optimizations, several sections of code in sd_hijack_optimizations.py had to be indented so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, I found that most of the cross attention layer optimizations could not function unless v is upcast as well.
		
			
				
	
	
		
			215 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			215 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # original source:
 | |
| #   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
 | |
| # license:
 | |
| #   MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
 | |
| # credit:
 | |
| #   Amin Rezaei (original author)
 | |
| #   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
 | |
| #   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
 | |
| # implementation of:
 | |
| #   "Self-attention Does Not Need O(n^2) Memory":
 | |
| #   https://arxiv.org/abs/2112.05682v2
 | |
| 
 | |
| from functools import partial
 | |
| import torch
 | |
| from torch import Tensor
 | |
| from torch.utils.checkpoint import checkpoint
 | |
| import math
 | |
| from typing import Optional, NamedTuple, List
 | |
| 
 | |
| 
 | |
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Like ``torch.narrow``, but never runs past the end of ``dim``.

    Returns the slice of ``input`` along ``dim`` that starts at ``start`` and
    spans ``length`` elements. If that range would extend beyond the size of
    ``dim``, the slice is shortened so it ends at the last element.
    """
    size = input.shape[dim]
    if start + length > size:
        # Shorten the final (partial) chunk instead of letting torch.narrow raise.
        length = size - start
    return torch.narrow(input, dim, start, length)
 | |
| 
 | |
| 
 | |
# Partial attention result for one key/value chunk, as built by _summarize_chunk:
#   exp_values      -- exp(scores - max_score) @ value (unnormalized weighted values)
#   exp_weights_sum -- sum over the chunk's keys of exp(scores - max_score)
#   max_score       -- per-query maximum score within the chunk (the shift used above)
AttnChunk = NamedTuple(
    "AttnChunk",
    [
        ("exp_values", Tensor),
        ("exp_weights_sum", Tensor),
        ("max_score", Tensor),
    ],
)
 | |
| 
 | |
| 
 | |
class SummarizeChunk:
    """Call signature of a function that summarizes one key/value chunk.

    In this module it only appears as a type annotation (for the
    ``summarize_chunk`` callables built with ``functools.partial``); it is
    never instantiated or subclassed here.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk:
        """Return the AttnChunk summary of ``query`` attending over one key/value chunk."""
 | |
| 
 | |
| 
 | |
class ComputeQueryChunkAttn:
    """Call signature of a function that computes attention for one slice of queries.

    In this module it only appears as a type annotation for the ``partial``
    objects built in ``efficient_dot_product_attention``; it is never
    instantiated or subclassed here.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor:
        """Return the attention output of ``query`` against ``key``/``value``."""
 | |
| 
 | |
| 
 | |
def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    """Summarize the attention of ``query`` over a single key/value chunk.

    Computes the scaled scores ``scale * query @ key^T``. It then subtracts each
    query row's maximum score before exponentiating, which keeps ``exp`` from
    overflowing. It returns:

    - the unnormalized weighted values;
    - the per-row sum of the exponentiated weights;
    - the per-row maximum.

    The caller can later combine these partial results across chunks.
    """
    # With beta=0 baddbmm ignores its `input` argument, so a 1x1x1 uninitialised
    # placeholder suffices instead of allocating a full-size bias tensor.
    scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1, 2),
        alpha=scale,
        beta=0,
    )
    # The shift is detached; presumably because it cancels out once chunks
    # are merged, so no gradient needs to flow through it.
    row_max = torch.max(scores, -1, keepdim=True)[0].detach()
    weights = torch.exp(scores - row_max)
    if query.device.type == 'mps':
        weighted_values = torch.bmm(weights, value)
    else:
        # Multiply in the weights' dtype, then return to value's dtype.
        weighted_values = torch.bmm(weights, value.to(weights.dtype)).to(value.dtype)
    return AttnChunk(weighted_values, weights.sum(dim=-1), row_max.squeeze(-1))
 | |
| 
 | |
| 
 | |
def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Attend ``query`` over ``key``/``value`` one key/value chunk at a time.

    ``key`` and ``value`` are split along dim 1 into chunks of ``kv_chunk_size``
    tokens; the last chunk may be shorter. Each chunk is passed to
    ``summarize_chunk``, which yields:

    - the exponentiated-and-shifted weighted values;
    - the weight sums;
    - the per-query max score.

    The per-chunk results are then rescaled to the global per-query max and
    summed. The output is the weighted value sum divided by the weight sum.

    Args:
        query: the query slice; assumed ``[batch * heads, q_tokens, channels]``
            (TODO confirm against callers).
        key: keys, chunked along dim 1.
        value: values, chunked along dim 1 in lockstep with ``key``.
        summarize_chunk: callable ``(query, key_chunk, value_chunk)`` returning
            the three AttnChunk fields.
        kv_chunk_size: number of key/value tokens per chunk (must be positive).

    Returns:
        The combined attention output for ``query``.
    """
    _, k_tokens, _ = key.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = narrow_trunc(key, 1, chunk_idx, kv_chunk_size)
        value_chunk = narrow_trunc(value, 1, chunk_idx, kv_chunk_size)
        return summarize_chunk(query, key_chunk, value_chunk)

    # Iterate a plain Python range so chunk_idx really is an int, as annotated.
    # Iterating torch.arange yielded 0-dim tensors, which made the bounds check
    # in narrow_trunc a tensor comparison.
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_idx) for chunk_idx in range(0, k_tokens, kv_chunk_size)
    ]
    # Stack field-wise: each field gains a leading "chunk" dimension.
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Every chunk was shifted by its own max; rescale all of them to the global
    # max so their exponentials share one reference point before summing.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
 | |
| 
 | |
| 
 | |
| # TODO: refactor CrossAttention#get_attention_scores to share code with this
 | |
| def _get_attention_scores_no_kv_chunking(
 | |
|     query: Tensor,
 | |
|     key: Tensor,
 | |
|     value: Tensor,
 | |
|     scale: float,
 | |
| ) -> Tensor:
 | |
|     attn_scores = torch.baddbmm(
 | |
|         torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
 | |
|         query,
 | |
|         key.transpose(1,2),
 | |
|         alpha=scale,
 | |
|         beta=0,
 | |
|     )
 | |
|     attn_probs = attn_scores.softmax(dim=-1)
 | |
|     del attn_scores
 | |
|     hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
 | |
|     return hidden_states_slice
 | |
| 
 | |
| 
 | |
# Pairs a chunk index (presumably the chunk's start offset along the token
# dimension) with that chunk's AttnChunk summary.
# NOTE(review): not referenced anywhere else in this module's visible code.
ScannedChunk = NamedTuple(
    "ScannedChunk",
    [
        ("chunk_idx", int),
        ("attn_chunk", AttnChunk),
    ],
)
 | |
| 
 | |
| 
 | |
def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
) -> Tensor:
    """Computes efficient dot-product attention given query, key, and value.

    This is an efficient version of the attention presented in
    https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.

    Args:
        query: queries for calculating attention with shape of
          `[batch * num_heads, query_tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
          `[batch * num_heads, key_tokens, channels_per_head]`.
        value: values to be used in attention with shape of
          `[batch * num_heads, key_tokens, value_channels_per_head]`.
        query_chunk_size: int: number of query tokens processed per slice.
        kv_chunk_size: Optional[int]: key/value chunk size. If None (or 0): defaults
          to int(sqrt(key_tokens)). Always capped at key_tokens.
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. Changes the
          chunk size into `max(chunk_size, kv_chunk_size_min)`, to ensure our chunk sizes
          don't get too small (smaller chunks = more chunks = less concurrent work done).
          Applied after the cap, so it may exceed key_tokens; that simply selects the
          single key/value chunk path.
        use_checkpoint: bool: whether to wrap per-chunk summaries in
          torch.utils.checkpoint (recommended True for training, False for inference).

    Returns:
        Output of shape `[batch * num_heads, query_tokens, value_channels_per_head]`.
    """
    _, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return narrow_trunc(
            query,
            1,
            chunk_idx,
            min(query_chunk_size, q_tokens)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    if use_checkpoint:
        # Recompute each chunk's summary during backward instead of storing activations.
        summarize_chunk = partial(checkpoint, summarize_chunk)

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk
        # (this is just sliced attention btw)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
        )
    else:
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key=key,
            value=value,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    # Integer range over chunk start offsets: same starts as i * query_chunk_size for
    # i < ceil(q_tokens / query_chunk_size), without float division.
    return torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(chunk_start),
            key=key,
            value=value,
        ) for chunk_start in range(0, q_tokens, query_chunk_size)
    ], dim=1)
 |