snowclipsed commited on
Commit
8cb4b5e
·
1 Parent(s): e15c30f

remove gsize check, replace with quantized_linear directly

Browse files
Files changed (1) hide show
  1. moondream.py +6 -10
moondream.py CHANGED
@@ -77,36 +77,32 @@ class MoondreamModel(nn.Module):
77
  self.vision = build_vision_model(config.vision, dtype)
78
  self.text = build_text_model(config.text, dtype)
79
 
80
- # Region Model
81
- linear_cls = (
82
- QuantizedLinear if config.region.group_size is not None else nn.Linear
83
- )
84
  self.region = nn.ModuleDict(
85
  {
86
- "coord_encoder": linear_cls(
87
  config.region.coord_feat_dim, config.region.dim, dtype=dtype
88
  ),
89
  "coord_decoder": nn.ModuleDict(
90
  {
91
- "fc1": linear_cls(
92
  config.region.dim, config.region.inner_dim, dtype=dtype
93
  ),
94
- "fc2": linear_cls(
95
  config.region.inner_dim,
96
  config.region.coord_out_dim,
97
  dtype=dtype,
98
  ),
99
  }
100
  ),
101
- "size_encoder": linear_cls(
102
  config.region.size_feat_dim, config.region.dim, dtype=dtype
103
  ),
104
  "size_decoder": nn.ModuleDict(
105
  {
106
- "fc1": linear_cls(
107
  config.region.dim, config.region.inner_dim, dtype=dtype
108
  ),
109
- "fc2": linear_cls(
110
  config.region.inner_dim,
111
  config.region.size_out_dim,
112
  dtype=dtype,
 
77
  self.vision = build_vision_model(config.vision, dtype)
78
  self.text = build_text_model(config.text, dtype)
79
 
 
 
 
 
80
  self.region = nn.ModuleDict(
81
  {
82
+ "coord_encoder": QuantizedLinear(
83
  config.region.coord_feat_dim, config.region.dim, dtype=dtype
84
  ),
85
  "coord_decoder": nn.ModuleDict(
86
  {
87
+ "fc1": QuantizedLinear(
88
  config.region.dim, config.region.inner_dim, dtype=dtype
89
  ),
90
+ "fc2": QuantizedLinear(
91
  config.region.inner_dim,
92
  config.region.coord_out_dim,
93
  dtype=dtype,
94
  ),
95
  }
96
  ),
97
+ "size_encoder": QuantizedLinear(
98
  config.region.size_feat_dim, config.region.dim, dtype=dtype
99
  ),
100
  "size_decoder": nn.ModuleDict(
101
  {
102
+ "fc1": QuantizedLinear(
103
  config.region.dim, config.region.inner_dim, dtype=dtype
104
  ),
105
+ "fc2": QuantizedLinear(
106
  config.region.inner_dim,
107
  config.region.size_out_dim,
108
  dtype=dtype,