from typing import Literal, Optional
import torch
import torch.nn as nn
from ...modules.utils import convert_module_to_f16, convert_module_to_f32, convert_module_to_bf16
from ...modules import sparse as sp
from ...modules.transformer import AbsolutePositionEmbedder
from ...modules.sparse.transformer import SparseTransformerBlock


def block_attn_config(self):
    """
    Return the attention configuration of the model.

    Yields one tuple per block, unpacked as
    (attn_mode, window_size, shift_sequence, shift_window, serialize_mode)
    when the SparseTransformerBlocks are constructed.
    """
    for i in range(self.num_blocks):
        if self.attn_mode == "shift_window":
            yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_sequence":
            yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_order":
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
        elif self.attn_mode == "full":
            yield "full", None, None, None, None
        elif self.attn_mode == "swin":
            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None


class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serves as the base class for encoders and decoders.
    See the illustrative subclass sketch at the end of this file.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_bf16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_bf16 = use_bf16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
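        # Compute dtype of the transformer torso; convert_to_fp16 / convert_to_bf16 /
        # convert_to_fp32 below keep this attribute in sync with the block parameters.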
        if use_fp16:
            self.dtype = torch.float16
        elif use_bf16:
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32
        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
        self.input_layer = sp.SparseLinear(in_channels, model_channels)
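        # One SparseTransformerBlock per configuration yielded by block_attn_config.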
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.use_bf16 = False
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)

    def convert_to_bf16(self) -> None:
        """
        Convert the torso of the model to bfloat16.
        """
        self.use_fp16 = False
        self.use_bf16 = True
        self.dtype = torch.bfloat16
        self.blocks.apply(convert_module_to_bf16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.use_bf16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            # x.coords[:, 0] is the batch index; the absolute position embedding
            # is computed from the remaining spatial coordinates.
            h = h + self.pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
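

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original module): how an
# encoder/decoder head is typically built on SparseTransformerBase. The class
# name and the `out_channels` projection are hypothetical; they only show how
# a subclass attaches an output layer to the shared torso.
# ---------------------------------------------------------------------------
class _ExampleSparseHead(SparseTransformerBase):
    def __init__(self, in_channels: int, model_channels: int, num_blocks: int, out_channels: int, **kwargs):
        super().__init__(in_channels, model_channels, num_blocks, **kwargs)
        self.out_layer = sp.SparseLinear(model_channels, out_channels)
        self.initialize_weights()

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = super().forward(x)   # input projection + optional APE + transformer blocks
        h = h.type(x.dtype)      # cast back from the torso dtype before the output head
        return self.out_layer(h)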