cleanup-inference-code #2
Opened by snowclipsed

Files changed:
- layers.py (+7 -6)
- moondream.py (+12 -14)
- text.py (+8 -9)
- weights.py (+19 -77)
layers.py

@@ -36,26 +36,27 @@ class QuantizedLinear(nn.Module):
         self,
         in_features: int,
         out_features: int,
-
+        group_size: int = 128,
+        dtype: torch.dtype = torch.uint8,
     ):
-        # TODO: Take group_size as an input instead of hardcoding it here.
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
+        self.group_size = group_size
         self.weight = nn.ParameterDict(
             {
                 "packed": nn.Parameter(
                     torch.empty(
-                        out_features * in_features // (
+                        out_features * in_features // (group_size * 2), group_size, dtype=dtype
                     ),
                     requires_grad=False,
                 ),
                 "scale": nn.Parameter(
-                    torch.empty(out_features * in_features //
+                    torch.empty(out_features * in_features // group_size, 1),
                     requires_grad=False,
                 ),
                 "zero_point": nn.Parameter(
-                    torch.empty(out_features * in_features //
+                    torch.empty(out_features * in_features // group_size, 1),
                     requires_grad=False,
                 ),
             }
@@ -86,7 +87,7 @@ class QuantizedLinear(nn.Module):
         )

         del self.weight, self.bias
-        quantize_(self, int4_weight_only(group_size=
+        quantize_(self, int4_weight_only(group_size=self.group_size))
         self.unpacked = True
         torch.cuda.empty_cache()
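For orientation, a small sketch of the shape arithmetic behind the new group_size/dtype arguments; the formulas follow the constructor in the diff, but the concrete dimensions below are made up:

```python
import torch

# Hypothetical layer dimensions; group_size=128 matches the new default above.
out_features, in_features, group_size = 2048, 2048, 128

# Each uint8 byte packs two int4 values, so the packed buffer needs
# out_features * in_features // (group_size * 2) rows of group_size bytes.
packed = torch.empty(
    out_features * in_features // (group_size * 2), group_size, dtype=torch.uint8
)

# One scale and one zero point per quantization group of `group_size` weights.
scale = torch.empty(out_features * in_features // group_size, 1)
zero_point = torch.empty(out_features * in_features // group_size, 1)

assert packed.numel() * 2 == out_features * in_features  # two int4 weights per byte
assert scale.shape == zero_point.shape == (out_features * in_features // group_size, 1)
```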
moondream.py

@@ -77,38 +77,36 @@ class MoondreamModel(nn.Module):
         self.vision = build_vision_model(config.vision, dtype)
         self.text = build_text_model(config.text, dtype)

-        # Region Model
-        linear_cls = (
-            QuantizedLinear if config.region.group_size is not None else nn.Linear
-        )
         self.region = nn.ModuleDict(
             {
-                "coord_encoder":
-                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
+                "coord_encoder": QuantizedLinear(
+                    config.region.coord_feat_dim, config.region.dim, group_size=config.text.group_size, dtype=dtype
                 ),
                 "coord_decoder": nn.ModuleDict(
                     {
-                        "fc1":
-                            config.region.dim, config.region.inner_dim, dtype=dtype
+                        "fc1": QuantizedLinear(
+                            config.region.dim, config.region.inner_dim, group_size=config.text.group_size, dtype=dtype
                         ),
-                        "fc2":
+                        "fc2": QuantizedLinear(
                             config.region.inner_dim,
                             config.region.coord_out_dim,
+                            group_size=config.text.group_size,
                             dtype=dtype,
                         ),
                     }
                 ),
-                "size_encoder":
-                    config.region.size_feat_dim, config.region.dim, dtype=dtype
+                "size_encoder": QuantizedLinear(
+                    config.region.size_feat_dim, config.region.dim, group_size=config.text.group_size, dtype=dtype
                 ),
                 "size_decoder": nn.ModuleDict(
                     {
-                        "fc1":
-                            config.region.dim, config.region.inner_dim, dtype=dtype
+                        "fc1": QuantizedLinear(
+                            config.region.dim, config.region.inner_dim, group_size=config.text.group_size, dtype=dtype
                         ),
-                        "fc2":
+                        "fc2": QuantizedLinear(
                             config.region.inner_dim,
                             config.region.size_out_dim,
+                            group_size=config.text.group_size,
                             dtype=dtype,
                         ),
                     }
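The region heads now always use QuantizedLinear and take their group size from the text config (config.text.group_size) instead of a per-region check. A minimal sketch of the fields these constructors read follows; the dataclass layout and default values are illustrative assumptions, only the attribute names come from the diff above:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the real config classes in config.py; values are placeholders.
@dataclass
class TextConfig:
    dim: int = 2048
    ff_dim: int = 8192
    n_heads: int = 32
    n_kv_heads: int = 32
    group_size: int = 128  # shared by the text blocks and, after this change, the region heads

@dataclass
class RegionConfig:
    dim: int = 2048
    inner_dim: int = 8192
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024
    size_feat_dim: int = 512
    size_out_dim: int = 2048
    # the old `config.region.group_size is not None` check is gone;
    # group_size is now read from TextConfig instead
```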
text.py

@@ -152,9 +152,8 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits


-def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+def build_text_model(config: TextConfig, dtype: torch.dtype = torch.float16) -> nn.Module:
     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear

     text = nn.ModuleDict(
         {
@@ -165,19 +164,19 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                             "ln": nn.LayerNorm(config.dim, dtype=dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv":
-                                    "proj":
-                                        config.dim, config.dim, dtype=dtype
+                                    "qkv": QuantizedLinear(config.dim, qkv_dim, group_size=config.group_size, dtype=dtype),
+                                    "proj": QuantizedLinear(
+                                        config.dim, config.dim, group_size=config.group_size, dtype=dtype
                                     ),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1":
-                                        config.dim, config.ff_dim, dtype=dtype
+                                    "fc1": QuantizedLinear(
+                                        config.dim, config.ff_dim, group_size=config.group_size, dtype=dtype
                                     ),
-                                    "fc2":
-                                        config.ff_dim, config.dim, dtype=dtype
+                                    "fc2": QuantizedLinear(
+                                        config.ff_dim, config.dim, group_size=config.group_size, dtype=dtype
                                     ),
                                 }
                             ),
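As a worked example of the qkv_dim formula that build_text_model keeps using (the head counts below are placeholders, not Moondream's actual values):

```python
# qkv_dim = dim * (1 + 2 * n_kv_heads / n_heads): one query projection plus
# key and value projections sized for n_kv_heads heads.
dim, n_heads, n_kv_heads = 2048, 32, 32      # placeholder multi-head attention values
assert int(dim * (1 + 2 * n_kv_heads / n_heads)) == 3 * dim  # 6144

dim, n_heads, n_kv_heads = 2048, 32, 8       # placeholder grouped-query values
assert int(dim * (1 + 2 * n_kv_heads / n_heads)) == 3072     # 2048 * 1.5
```

With the new default, build_text_model(config) builds the text model in float16 unless a dtype is passed explicitly.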
weights.py

@@ -6,9 +6,6 @@ import re
 from contextlib import contextmanager
 from typing import Callable, List

-from .text import build_text_model
-from .config import TextConfig
-

 # Our custom linear has an module named linear, so we add linear to the name
 def add_linear_to_key(k: str) -> str:
@@ -46,7 +43,6 @@ def safetensors_open(safetensors_file: str):
 def _load_weights(
     get_tensor: Callable[[str], torch.Tensor],
     model: nn.Module,
-    is_quantized: bool = False,
 ) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)
@@ -111,42 +107,23 @@
         }
     )

-    else:  # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )

     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
@@ -162,24 +139,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     with safetensors_open(weights_file) as get_tensor:
         all_keys = get_tensor.keys()

-        is_quantized = any(
-            ".qweight" in key or "_quantized" in key or "quant." in key
-            for key in all_keys
-        )
-
-        if "text_model.transformer.h.0.ln.weight" in all_keys:
-            layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
-        else:
-            layernorm_dtype = torch.float16
-
-        linear_dtype = torch.int8 if is_quantized else torch.float16
-
-        model.text = build_text_model(
-            TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
-        )
-        if model.setup_caches_flag:
-            model._setup_caches()
-
         if (
             "vision.blocks.0.attn.proj.bias" in all_keys
             or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -193,7 +152,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
         _load_weights(
             lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
             model,
-            is_quantized,
         )


@@ -201,22 +159,6 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
     """Load weights from a PyTorch file into a MoondreamModel instance."""
     tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
     all_keys = tensors.keys()
-    is_quantized = any(
-        ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
-    )
-
-    if "text.blocks.0.ln.weight" in all_keys:
-        layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
-    else:
-        layernorm_dtype = torch.float16
-
-    linear_dtype = torch.int8 if is_quantized else torch.float16
-    model.text = build_text_model(
-        TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
-    )
-    if model.setup_caches_flag:
-        model._setup_caches()
-
     if (
         "vision.blocks.0.attn.proj.bias" in all_keys
         or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -228,7 +170,7 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
         k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
         for k, v in tensors.items()
     }
-    _load_weights(lambda x: tensors[x], model
+    _load_weights(lambda x: tensors[x], model)


 def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
@@ -246,4 +188,4 @@ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:

     # Make all parameters contiguous
     for param in model.parameters():
-        param.data = param.data.contiguous()
+        param.data = param.data.contiguous()