cleanup-inference-code #2
Files changed (4)
  1. layers.py +7 -6
  2. moondream.py +12 -14
  3. text.py +8 -9
  4. weights.py +19 -77
layers.py CHANGED
@@ -36,26 +36,27 @@ class QuantizedLinear(nn.Module):
         self,
         in_features: int,
         out_features: int,
-        dtype: torch.dtype,
+        group_size: int = 128,
+        dtype: torch.dtype = torch.uint8,
     ):
-        # TODO: Take group_size as an input instead of hardcoding it here.
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
+        self.group_size = group_size
         self.weight = nn.ParameterDict(
             {
                 "packed": nn.Parameter(
                     torch.empty(
-                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
+                        out_features * in_features // (group_size * 2), group_size, dtype=dtype
                     ),
                     requires_grad=False,
                 ),
                 "scale": nn.Parameter(
-                    torch.empty(out_features * in_features // 128, 1),
+                    torch.empty(out_features * in_features // group_size, 1),
                     requires_grad=False,
                 ),
                 "zero_point": nn.Parameter(
-                    torch.empty(out_features * in_features // 128, 1),
+                    torch.empty(out_features * in_features // group_size, 1),
                     requires_grad=False,
                 ),
             }
@@ -86,7 +87,7 @@ class QuantizedLinear(nn.Module):
         )
 
         del self.weight, self.bias
-        quantize_(self, int4_weight_only(group_size=128))
+        quantize_(self, int4_weight_only(group_size=self.group_size))
         self.unpacked = True
         torch.cuda.empty_cache()
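The shape bookkeeping behind the new group_size parameter is easy to sanity-check; below is a minimal sketch with hypothetical sizes (the defaults group_size=128 and dtype=torch.uint8 reproduce the previously hard-coded layout):

```python
import torch

# Hypothetical layer dimensions; any sizes divisible by the group size work.
out_features, in_features, group_size = 2048, 2048, 128

# Two int4 values are packed per uint8 byte, so the packed buffer holds
# out_features * in_features // (group_size * 2) rows of group_size bytes.
packed = torch.empty(
    out_features * in_features // (group_size * 2), group_size, dtype=torch.uint8
)

# One scale and one zero point per quantization group of `group_size` weights.
scale = torch.empty(out_features * in_features // group_size, 1)

assert packed.numel() * 2 == out_features * in_features  # two int4 per byte
assert scale.shape[0] * group_size == out_features * in_features
```

Threading the same value into quantize_(self, int4_weight_only(group_size=self.group_size)) keeps the unpacking step consistent with whatever group size the buffers were allocated for.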
 
moondream.py CHANGED
@@ -77,38 +77,36 @@ class MoondreamModel(nn.Module):
         self.vision = build_vision_model(config.vision, dtype)
         self.text = build_text_model(config.text, dtype)
 
-        # Region Model
-        linear_cls = (
-            QuantizedLinear if config.region.group_size is not None else nn.Linear
-        )
         self.region = nn.ModuleDict(
             {
-                "coord_encoder": linear_cls(
-                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
+                "coord_encoder": QuantizedLinear(
+                    config.region.coord_feat_dim, config.region.dim, group_size=config.text.group_size, dtype=dtype
                 ),
                 "coord_decoder": nn.ModuleDict(
                     {
-                        "fc1": linear_cls(
-                            config.region.dim, config.region.inner_dim, dtype=dtype
+                        "fc1": QuantizedLinear(
+                            config.region.dim, config.region.inner_dim, group_size=config.text.group_size, dtype=dtype
                         ),
-                        "fc2": linear_cls(
+                        "fc2": QuantizedLinear(
                             config.region.inner_dim,
                             config.region.coord_out_dim,
+                            group_size=config.text.group_size,
                             dtype=dtype,
                         ),
                     }
                 ),
-                "size_encoder": linear_cls(
-                    config.region.size_feat_dim, config.region.dim, dtype=dtype
+                "size_encoder": QuantizedLinear(
+                    config.region.size_feat_dim, config.region.dim, group_size=config.text.group_size, dtype=dtype
                 ),
                 "size_decoder": nn.ModuleDict(
                     {
-                        "fc1": linear_cls(
-                            config.region.dim, config.region.inner_dim, dtype=dtype
+                        "fc1": QuantizedLinear(
+                            config.region.dim, config.region.inner_dim, group_size=config.text.group_size, dtype=dtype
                         ),
-                        "fc2": linear_cls(
+                        "fc2": QuantizedLinear(
                             config.region.inner_dim,
                             config.region.size_out_dim,
+                            group_size=config.text.group_size,
                             dtype=dtype,
                         ),
                     }
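Every region-head projection is now constructed the same way, so the repeated pattern above can be read as the following helper (hypothetical, not part of the PR; it assumes QuantizedLinear is importable from layers.py via the package-relative path used elsewhere in the repo):

```python
import torch
import torch.nn as nn

from .layers import QuantizedLinear  # assumed package-relative import path


def region_linear(in_features: int, out_features: int, config, dtype: torch.dtype) -> nn.Module:
    # All region layers quantize with the text model's group size instead of
    # branching on config.region.group_size between QuantizedLinear and nn.Linear.
    return QuantizedLinear(
        in_features, out_features, group_size=config.text.group_size, dtype=dtype
    )
```

One consequence worth flagging in review: the region head is now always quantized, since the nn.Linear fallback gated by config.region.group_size is gone.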
text.py CHANGED
@@ -152,9 +152,8 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits
 
 
-def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+def build_text_model(config: TextConfig, dtype: torch.dtype = torch.float16) -> nn.Module:
     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
 
     text = nn.ModuleDict(
         {
@@ -165,19 +164,19 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                     "ln": nn.LayerNorm(config.dim, dtype=dtype),
                     "attn": nn.ModuleDict(
                         {
-                            "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
-                            "proj": linear_cls(
-                                config.dim, config.dim, dtype=dtype
+                            "qkv": QuantizedLinear(config.dim, qkv_dim, group_size=config.group_size, dtype=dtype),
+                            "proj": QuantizedLinear(
+                                config.dim, config.dim, group_size=config.group_size, dtype=dtype
                             ),
                         }
                     ),
                     "mlp": nn.ModuleDict(
                         {
-                            "fc1": linear_cls(
-                                config.dim, config.ff_dim, dtype=dtype
+                            "fc1": QuantizedLinear(
+                                config.dim, config.ff_dim, group_size=config.group_size, dtype=dtype
                             ),
-                            "fc2": linear_cls(
-                                config.ff_dim, config.dim, dtype=dtype
+                            "fc2": QuantizedLinear(
+                                config.ff_dim, config.dim, group_size=config.group_size, dtype=dtype
                             ),
                         }
                     ),
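A usage sketch of the new build_text_model signature; the imports mirror the ones removed from weights.py, and the default-constructed TextConfig is an assumption for illustration:

```python
import torch

from .config import TextConfig  # same import path weights.py used before
from .text import build_text_model

config = TextConfig()  # hypothetical: assumes usable defaults

text = build_text_model(config)                             # dtype now defaults to float16
text_fp16 = build_text_model(config, dtype=torch.float16)   # equivalent explicit call

# Every qkv/proj/fc1/fc2 projection is a QuantizedLinear built with
# group_size=config.group_size; the nn.Linear fallback is gone.
```

As in moondream.py, this assumes config.group_size is always set; a config with group_size=None previously selected nn.Linear, and that branch no longer exists.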
weights.py CHANGED
@@ -6,9 +6,6 @@ import re
 from contextlib import contextmanager
 from typing import Callable, List
 
-from .text import build_text_model
-from .config import TextConfig
-
 
 # Our custom linear has an module named linear, so we add linear to the name
 def add_linear_to_key(k: str) -> str:
@@ -46,7 +43,6 @@ def safetensors_open(safetensors_file: str):
 def _load_weights(
     get_tensor: Callable[[str], torch.Tensor],
     model: nn.Module,
-    is_quantized: bool = False,
 ) -> None:
     """Internal function to load weights using a tensor getter function."""
     model = model.to(dtype=torch.float16)
@@ -111,42 +107,23 @@ def _load_weights(
         }
     )
 
-    if not is_quantized:
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.weight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
-    else:  # add special quantized path. this is specific to how bitblas expects weights to be loaded (.qweight)
-        for i in range(len(model.text["blocks"])):
-            prefix = f"text_model.transformer.h.{i}"
-            blk = model.text["blocks"][i]
-            weight_map.update(
-                {
-                    f"{prefix}.ln.qweight": blk["ln"].weight,
-                    f"{prefix}.ln.bias": blk["ln"].bias,
-                    f"{prefix}.mixer.Wqkv.qweight": blk["attn"]["qkv"].weight,
-                    f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
-                    f"{prefix}.mixer.out_proj.qweight": blk["attn"]["proj"].weight,
-                    f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
-                    f"{prefix}.mlp.fc1.qweight": blk["mlp"]["fc1"].weight,
-                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
-                    f"{prefix}.mlp.fc2.qweight": blk["mlp"]["fc2"].weight,
-                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
-                }
-            )
+    for i in range(len(model.text["blocks"])):
+        prefix = f"text_model.transformer.h.{i}"
+        blk = model.text["blocks"][i]
+        weight_map.update(
+            {
+                f"{prefix}.ln.weight": blk["ln"].weight,
+                f"{prefix}.ln.bias": blk["ln"].bias,
+                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
+                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
+                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
+                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
+                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
+                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
+                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
+                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
+            }
+        )
 
     for key, tensor in weight_map.items():
         tensor.data.copy_(get_tensor(key))
@@ -162,24 +139,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
     with safetensors_open(weights_file) as get_tensor:
         all_keys = get_tensor.keys()
 
-        is_quantized = any(
-            ".qweight" in key or "_quantized" in key or "quant." in key
-            for key in all_keys
-        )
-
-        if "text_model.transformer.h.0.ln.weight" in all_keys:
-            layernorm_dtype = get_tensor("text_model.transformer.h.0.ln.weight").dtype
-        else:
-            layernorm_dtype = torch.float16
-
-        linear_dtype = torch.int8 if is_quantized else torch.float16
-
-        model.text = build_text_model(
-            TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
-        )
-        if model.setup_caches_flag:
-            model._setup_caches()
-
         if (
             "vision.blocks.0.attn.proj.bias" in all_keys
             or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -193,7 +152,6 @@ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
         _load_weights(
             lambda x: get_tensor(name_map[x]).to(dtype=torch.float16),
             model,
-            is_quantized,
         )
 
 
@@ -201,22 +159,6 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
     """Load weights from a PyTorch file into a MoondreamModel instance."""
     tensors = torch.load(weights_file, map_location="cpu", weights_only=True)
     all_keys = tensors.keys()
-    is_quantized = any(
-        ".qweight" in key or "_quantized" in key or "quant." in key for key in all_keys
-    )
-
-    if "text.blocks.0.ln.weight" in all_keys:
-        layernorm_dtype = tensors["text.blocks.0.ln.weight"].dtype
-    else:
-        layernorm_dtype = torch.float16
-
-    linear_dtype = torch.int8 if is_quantized else torch.float16
-    model.text = build_text_model(
-        TextConfig, linear_dtype=linear_dtype, layernorm_dtype=layernorm_dtype
-    )
-    if model.setup_caches_flag:
-        model._setup_caches()
-
     if (
         "vision.blocks.0.attn.proj.bias" in all_keys
         or "model.vision.blocks.0.attn.proj.bias" in all_keys
@@ -228,7 +170,7 @@ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
         k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
         for k, v in tensors.items()
     }
-    _load_weights(lambda x: tensors[x], model, is_quantized)
+    _load_weights(lambda x: tensors[x], model)
 
 
 def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
@@ -246,4 +188,4 @@ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
 
     # Make all parameters contiguous
     for param in model.parameters():
-        param.data = param.data.contiguous()
+        param.data = param.data.contiguous()
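Finally, an end-to-end loading sketch after the cleanup. Module paths and the MoondreamConfig constructor are assumptions based on the file names in this PR, not verified against the repo layout:

```python
from .config import MoondreamConfig  # assumed config class and location
from .moondream import MoondreamModel
from .weights import load_weights_into_model

# The model is built with QuantizedLinear text/region blocks up front, so the
# loader no longer inspects keys for .qweight or rebuilds model.text.
model = MoondreamModel(MoondreamConfig())
load_weights_into_model("model.safetensors", model)  # a .pt checkpoint works too

# Internally, _load_weights builds a single weight_map (no is_quantized branch)
# and copies each tensor in place via tensor.data.copy_(get_tensor(key)).
```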