snowclipsed committed
Commit 595a8a6 · 1 Parent(s): a89c592

remove dtype as input in weights

Files changed (3)
  1. layers.py +2 -2
  2. text.py +5 -6
  3. weights.py +2 -2
layers.py CHANGED
@@ -36,7 +36,7 @@ class QuantizedLinear(nn.Module):
        self,
        in_features: int,
        out_features: int,
-        dtype: torch.dtype,
+        dtype: torch.dtype = torch.uint8,
    ):
        # TODO: Take group_size as an input instead of hardcoding it here.
        super().__init__()
@@ -46,7 +46,7 @@ class QuantizedLinear(nn.Module):
            {
                "packed": nn.Parameter(
                    torch.empty(
-                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
+                        out_features * in_features // (128 * 2), 128, dtype=dtype
                    ),
                    requires_grad=False,
                ),
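With the default in place, QuantizedLinear can be constructed without passing a dtype and still gets the uint8-packed buffer. A minimal usage sketch (the feature sizes are made up for illustration, and the group size of 128 remains hardcoded as the TODO notes):

import torch
from layers import QuantizedLinear  # import path assumed from layers.py above

# dtype now defaults to torch.uint8, so the explicit argument is optional.
layer = QuantizedLinear(in_features=2048, out_features=2048)
# Equivalent explicit form:
layer_explicit = QuantizedLinear(2048, 2048, dtype=torch.uint8)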
text.py CHANGED
@@ -152,9 +152,8 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    return logits


-def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+def build_text_model(config: TextConfig, dtype: torch.dtype = torch.float16) -> nn.Module:
    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear

    text = nn.ModuleDict(
        {
@@ -165,18 +164,18 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                    "ln": nn.LayerNorm(config.dim, dtype=dtype),
                    "attn": nn.ModuleDict(
                        {
-                            "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
-                            "proj": linear_cls(
+                            "qkv": QuantizedLinear(config.dim, qkv_dim, dtype=dtype),
+                            "proj": QuantizedLinear(
                                config.dim, config.dim, dtype=dtype
                            ),
                        }
                    ),
                    "mlp": nn.ModuleDict(
                        {
-                            "fc1": linear_cls(
+                            "fc1": QuantizedLinear(
                                config.dim, config.ff_dim, dtype=dtype
                            ),
-                            "fc2": linear_cls(
+                            "fc2": QuantizedLinear(
                                config.ff_dim, config.dim, dtype=dtype
                            ),
                        }
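Since build_text_model now defaults dtype to torch.float16 and always builds QuantizedLinear layers, a caller only needs the config. A minimal sketch, assuming TextConfig and build_text_model are importable from text.py (the import path is an assumption), mirroring how weights.py passes the TextConfig class:

import torch
from text import TextConfig, build_text_model  # import path assumed

# dtype is optional; omitting it is the same as passing torch.float16.
text_model = build_text_model(TextConfig)
text_model_fp16 = build_text_model(TextConfig, dtype=torch.float16)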
weights.py CHANGED
@@ -175,7 +175,7 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    linear_dtype = torch.int8 if is_quantized else torch.float16

    model.text = build_text_model(
-        TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
+        TextConfig
    )
    if model.setup_caches_flag:
        model._setup_caches()
@@ -212,7 +212,7 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:

    linear_dtype = torch.int8 if is_quantized else torch.float16
    model.text = build_text_model(
-        TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
+        TextConfig
    )
    if model.setup_caches_flag:
        model._setup_caches()
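Both loaders now call build_text_model with the config alone, so the linear and layernorm dtypes come from the defaults shown in layers.py and text.py above. A hedged sketch of calling one of the loaders after this change (the import path, file name, and model object are placeholders):

from weights import load_weights_from_safetensors  # import path assumed

# `model` stands in for the module that owns the text backbone.
load_weights_from_safetensors("model.safetensors", model)
# model.text is now built via build_text_model(TextConfig) with default dtypes.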