snowclipsed committed · Commit e15c30f · Parent(s): 8f87bef

remove is_quantized completely

weights.py +17 -41
weights.py CHANGED
@@ -6,9 +6,6 @@ import re
 from contextlib import contextmanager
 from typing import Callable, List
 
-from .text import build_text_model
-from .config import TextConfig
-
 
 # Our custom linear has an module named linear, so we add linear to the name
 def add_linear_to_key(k: str) -> str:
@@ -46,7 +43,6 @@ def safetensors_open(safetensors_file: str):
 def _load_weights(
     get_tensor: Callable[[str], torch.Tensor],
     model: nn.Module,
-    is_quantized: bool = False,
 ) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)
@@ -111,42 +107,23 @@ def _load_weights(
             }
         )
 
-    if not is_quantized:
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.weight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
-    else:  # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )
 
     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
@@ -175,7 +152,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
         _load_weights(
             lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
            model,
-            # is_quantized,
         )
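The following is a minimal, hedged sketch (the toy blocks and layer sizes are invented; only the key names come from this diff) of the weight map the now-unconditional block loop builds: plain .weight/.bias keys under text_model.transformer.h.{i}, with no .qweight variants left for bitblas-style quantized loading.

# Illustrative sketch, not the repository's code: toy stand-ins for model.text["blocks"]
# with made-up sizes, used only to show the key layout _load_weights now expects.
import torch.nn as nn

blocks = [
    {
        "ln": nn.LayerNorm(8),
        "attn": {"qkv": nn.Linear(8, 24), "proj": nn.Linear(8, 8)},
        "mlp": {"fc1": nn.Linear(8, 32), "fc2": nn.Linear(32, 8)},
    }
    for _ in range(2)
]

weight_map = {}
for i, blk in enumerate(blocks):
    prefix = f"text_model.transformer.h.{i}"
    weight_map.update(
        {
            f"{prefix}.ln.weight": blk["ln"].weight,
            f"{prefix}.ln.bias": blk["ln"].bias,
            f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
            f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
            f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
            f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
            f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
            f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
            f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
            f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
        }
    )

# After this commit there is no quantized branch, so no key should end in ".qweight".
assert not any(k.endswith(".qweight") for k in weight_map)
print(f"{len(weight_map)} keys mapped")  # 2 blocks x 10 entries = 20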