Commit 595a8a6 by snowclipsed
Parent: a89c592

remove dtype as input in weights

Files changed:
- layers.py +2 -2
- text.py +5 -6
- weights.py +2 -2
layers.py
CHANGED
@@ -36,7 +36,7 @@ class QuantizedLinear(nn.Module):
         self,
         in_features: int,
         out_features: int,
-        dtype: torch.dtype,
+        dtype: torch.dtype = torch.uint8,
     ):
         # TODO: Take group_size as an input instead of hardcoding it here.
         super().__init__()
@@ -46,7 +46,7 @@ class QuantizedLinear(nn.Module):
             {
                 "packed": nn.Parameter(
                     torch.empty(
-                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
+                        out_features * in_features // (128 * 2), 128, dtype=dtype
                     ),
                     requires_grad=False,
                 ),
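Not part of the commit, but a hedged note on what the layers.py change implies: a packed buffer of out_features * in_features // (128 * 2) rows by 128 bytes suggests two 4-bit values per uint8 byte with the hardcoded group size of 128 that the TODO mentions, which is why torch.uint8 is a reasonable default for dtype. The sketch below only illustrates that shape arithmetic; the nibble layout and the pack_int4 helper are assumptions, not the repo's actual quantization code.

import torch

def pack_int4(w_int4: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pack pairs of 4-bit values (0..15) into single uint8 bytes,
    # then view the result as rows of 128 bytes (the hardcoded group size).
    flat = w_int4.to(torch.uint8).flatten()
    packed = flat[0::2] | (flat[1::2] << 4)
    return packed.reshape(-1, 128)

out_features, in_features = 2048, 2048
w = torch.randint(0, 16, (out_features, in_features), dtype=torch.uint8)
packed = pack_int4(w)

# Matches the buffer QuantizedLinear allocates, now uint8 by default.
assert packed.shape == (out_features * in_features // (128 * 2), 128)
assert packed.dtype == torch.uint8

# Unpacking recovers both nibbles of every byte.
low, high = packed & 0xF, packed >> 4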
text.py
CHANGED
@@ -152,9 +152,8 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits


-def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+def build_text_model(config: TextConfig, dtype: torch.dtype = torch.float16) -> nn.Module:
     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear

     text = nn.ModuleDict(
         {
@@ -165,18 +164,18 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                             "ln": nn.LayerNorm(config.dim, dtype=dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
-                                    "proj": linear_cls(
+                                    "qkv": QuantizedLinear(config.dim, qkv_dim, dtype=dtype),
+                                    "proj": QuantizedLinear(
                                         config.dim, config.dim, dtype=dtype
                                     ),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1": linear_cls(
+                                    "fc1": QuantizedLinear(
                                         config.dim, config.ff_dim, dtype=dtype
                                     ),
-                                    "fc2": linear_cls(
+                                    "fc2": QuantizedLinear(
                                         config.ff_dim, config.dim, dtype=dtype
                                     ),
                                 }
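Likewise not in the diff, just a hedged usage sketch of the new text.py signature: with the linear_cls switch removed, build_text_model always constructs QuantizedLinear layers, and the dtype argument can now be omitted. The import path and helper name below are assumptions about the module layout, not code from the repo.

import torch
from text import build_text_model  # assumes text.py is importable as the module `text`

def build_default_text(config):
    # `config` is a TextConfig built elsewhere; its fields are not shown in this diff.
    # After this commit the dtype argument is optional, so this call is equivalent to
    # build_text_model(config, dtype=torch.float16).
    return build_text_model(config)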
weights.py
CHANGED
@@ -175,7 +175,7 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     linear_dtype = torch.int8 if is_quantized else torch.float16

     model.text = build_text_model(
-        TextConfig
+        TextConfig
     )
     if model.setup_caches_flag:
         model._setup_caches()
@@ -212,7 +212,7 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:

     linear_dtype = torch.int8 if is_quantized else torch.float16
     model.text = build_text_model(
-        TextConfig
+        TextConfig
     )
     if model.setup_caches_flag:
         model._setup_caches()
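Finally, a hedged sanity check consistent with the commit title (weights.py apparently no longer forwards linear_dtype into build_text_model): the "packed" parameters created by QuantizedLinear should then fall back to the torch.uint8 default. The helper below is illustrative only; the parameter naming is inferred from the layers.py hunk above.

import torch

def assert_uint8_packed(text_module) -> None:
    # Scan registered parameters for the "packed" buffers created by QuantizedLinear
    # (registered with requires_grad=False) and confirm they use the uint8 default.
    found = 0
    for name, param in text_module.named_parameters():
        if name.endswith("packed"):
            assert param.dtype == torch.uint8, (name, param.dtype)
            found += 1
    assert found > 0, "no packed buffers found; is this a quantized text model?"

# e.g. after loading: assert_uint8_packed(model.text)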