snowclipsed commited on
Commit
4f9fd1b
·
1 Parent(s): 1db79dd

fix group size input in build text model

Browse files
Files changed (1) hide show
  1. text.py +4 -4
text.py CHANGED
@@ -164,19 +164,19 @@ def build_text_model(config: TextConfig, dtype: torch.dtype = torch.float16) ->
164
  "ln": nn.LayerNorm(config.dim, dtype=dtype),
165
  "attn": nn.ModuleDict(
166
  {
167
- "qkv": QuantizedLinear(config.dim, qkv_dim, group_size=config.text.group_size, dtype=dtype),
168
  "proj": QuantizedLinear(
169
- config.dim, config.dim, group_size=config.text.group_size, dtype=dtype
170
  ),
171
  }
172
  ),
173
  "mlp": nn.ModuleDict(
174
  {
175
  "fc1": QuantizedLinear(
176
- config.dim, config.ff_dim, group_size=config.text.group_size, dtype=dtype
177
  ),
178
  "fc2": QuantizedLinear(
179
- config.ff_dim, config.dim, group_size=config.text.group_size, dtype=dtype
180
  ),
181
  }
182
  ),
 
164
  "ln": nn.LayerNorm(config.dim, dtype=dtype),
165
  "attn": nn.ModuleDict(
166
  {
167
+ "qkv": QuantizedLinear(config.dim, qkv_dim, group_size=config.group_size, dtype=dtype),
168
  "proj": QuantizedLinear(
169
+ config.dim, config.dim, group_size=config.group_size, dtype=dtype
170
  ),
171
  }
172
  ),
173
  "mlp": nn.ModuleDict(
174
  {
175
  "fc1": QuantizedLinear(
176
+ config.dim, config.ff_dim, group_size=config.group_size, dtype=dtype
177
  ),
178
  "fc2": QuantizedLinear(
179
+ config.ff_dim, config.dim, group_size=config.group_size, dtype=dtype
180
  ),
181
  }
182
  ),