tclf90 committed (verified) · Commit 0928ede · 1 parent: 7da9560

Delete awq_marlin.py

Files changed (1):
  awq_marlin.py  +0 -526
awq_marlin.py DELETED
@@ -1,526 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional

import torch
from torch.nn import Parameter

import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
    UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                         is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    check_marlin_supports_layer, check_moe_marlin_supports_layer,
    marlin_make_empty_g_idx, marlin_make_workspace_new,
    marlin_moe_permute_scales, marlin_permute_scales,
    moe_awq_to_marlin_zero_points, verify_marlin_supported,
    verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin"""

    # num_bits -> type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
                 lm_head_quantized: bool,
                 modules_to_not_convert: Optional[list[str]],
                 full_config: dict[str, Any]) -> None:
        super().__init__()
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
                             f"Supported num_bits = {self.TYPE_MAP.keys()}")

        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(self.quant_type,
                                group_size=self.group_size,
                                has_zp=self.zero_point)

    def __repr__(self) -> str:
        return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"lm_head_quantized={self.lm_head_quantized}, "
                f"modules_to_not_convert={self.modules_to_not_convert})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
                   modules_to_not_convert, config)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "awq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info("Detected that the model can run with awq_marlin"
                        ", however you specified quantization=awq explicitly,"
                        " so forcing awq. Use quantization=awq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if (isinstance(layer, LinearBase) or
            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            # Check if the layer is supported by AWQMarlin.
            if not check_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
                    prefix,
                )
                return AWQConfig.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped_awq(
                    prefix, getattr(self, "modules_to_not_convert", [])):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return AWQMoEMethod(self)
        return None

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        zero_point = quant_config.get("zero_point")

        if not current_platform.is_cuda():
            return False

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or zero_point is None):
            return False

        if num_bits not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                      group_size=group_size,
                                      has_zp=zero_point)


class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # TODO: Update these docs
    # Checkpoints are serialized in AutoAWQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                           requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                          requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data,
                                          requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace_new(device)

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = ops.awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits)
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size)
        replace_parameter(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits)
        replace_parameter(layer, "qzeros", marlin_zp)

        # Not-used
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)


class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQMoEMethod only supports 4bit now.")
        self.quant_type = scalar_types.uint4

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        w13_qweight = Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        2 * intermediate_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32),
            requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = Parameter(torch.empty(num_experts,
                                           intermediate_size_per_partition,
                                           hidden_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = (intermediate_size_per_partition //
                         self.quant_config.group_size)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           intermediate_size_per_partition * 2,
                                           dtype=params_dtype),
                               requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size,
                                          dtype=params_dtype),
                              requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = Parameter(
            torch.empty(num_experts,
                        num_groups_w13,
                        2 * intermediate_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32),
            requires_grad=False)
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size //
                                          self.quant_config.pack_factor,
                                          dtype=torch.int32),
                              requires_grad=False)
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )

        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `AWQMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            workspace=layer.workspace)
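
For readers skimming the deleted file, the packing arithmetic it relies on can be reproduced in isolation. The sketch below uses made-up layer sizes (K = 4096, N = 11008, group_size = 128 are assumptions, not values from this commit) and only mirrors the shape relationships that AWQMarlinLinearMethod.create_weights allocates before the Marlin repacking step.

# Minimal standalone sketch (assumed example sizes, not part of the deleted file)
# of the AWQ-format parameter shapes that create_weights registers.
weight_bits = 4
group_size = 128
pack_factor = 32 // weight_bits  # 8 int4 values packed into each int32

input_size_per_partition = 4096      # K, hypothetical
output_size_per_partition = 11008    # N, hypothetical
num_groups = input_size_per_partition // group_size

qweight_shape = (input_size_per_partition,
                 output_size_per_partition // pack_factor)  # int32, packed on dim 1
qzeros_shape = (num_groups,
                output_size_per_partition // pack_factor)   # int32, packed on dim 1
scales_shape = (num_groups, output_size_per_partition)      # params_dtype (fp16/bf16)

print(qweight_shape, qzeros_shape, scales_shape)
# (4096, 1376) (32, 1376) (32, 11008)

These shapes describe only the checkpoint-side (AutoAWQ) layout; process_weights_after_loading then repacks qweight with ops.awq_marlin_repack and permutes the scales and zero points into the Marlin layout.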