import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)

try:
    import xformers.ops
except ImportError:
    xformers = None

# hypernetworks registered here are applied to the cross-attention context
loaded_networks = []


def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
    return context_k, context_v


def apply_hypernetworks(context_k, context_v, layer=None):
    if len(loaded_networks) == 0:
        # no hypernetworks loaded: key and value projections both use the
        # unmodified context (encoder_hidden_states)
        return context_v, context_v

    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    # keep the key/value contexts on a common dtype
    context_v = context_v.to(dtype=context_k.dtype)

    return context_k, context_v
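

# A minimal sketch of the interface assumed for entries of `loaded_networks`:
# each hypernetwork exposes `forward(context_k, context_v)` and returns the
# remapped pair that feeds the key/value projections. The class below is purely
# illustrative; real hypernetworks typically derive both streams from the text
# context and select their weights by its last dimension.
class _ToyHypernetwork(torch.nn.Module):
    def __init__(self, context_dim: int):
        super().__init__()
        self.proj_k = torch.nn.Linear(context_dim, context_dim)
        self.proj_v = torch.nn.Linear(context_dim, context_dim)

    def forward(self, context_k, context_v):
        # this toy ignores context_k and remaps the (text) context for both
        # the key and the value streams
        return self.proj_k(context_v), self.proj_v(context_v)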


def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # route the context through any loaded hypernetworks before the k/v projections
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    query = attn.head_to_batch_dim(query).contiguous()
    key = attn.head_to_batch_dim(key).contiguous()
    value = attn.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # route the context through any loaded hypernetworks before the k/v projections
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # compute attention slice by slice along the (batch * heads) dimension
    for i in range(batch_size_attention // self.slice_size):
        start_idx = i * self.slice_size
        end_idx = (i + 1) * self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    inner_dim = hidden_states.shape[-1]

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # route the context through any loaded hypernetworks before the k/v projections
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    head_dim = inner_dim // attn.heads
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def replace_attentions_for_hypernetwork():
    import diffusers.models.attention_processor

    diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
        xformers_forward
    )
    diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
        sliced_attn_forward
    )
    diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
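

# Usage sketch (assumption: `MyHypernetwork` stands in for any module that
# implements the `forward(context_k, context_v)` interface expected by
# `apply_hypernetworks`, e.g. the toy class above):
#
#   replace_attentions_for_hypernetwork()
#   loaded_networks.append(MyHypernetwork(...))
#
# After patching, any pipeline that uses XFormersAttnProcessor,
# SlicedAttnProcessor or AttnProcessor2_0 routes its cross-attention context
# through the registered hypernetworks.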