wuzhiying committed
Commit 4d1208f
1 Parent(s): ddad89a

sync base to chat

Files changed (1)
  1. modeling_baichuan.py +14 -14
modeling_baichuan.py CHANGED
@@ -528,7 +528,6 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         self.model = BaichuanModel(config)
 
         self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
-        #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
         if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
             try:
                 from .quantizer import quantize_offline, init_model_weight_int4
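
For context, a minimal sketch (not part of the commit) of why the stricter check is safer: config.quantization_config may be a non-dict object or may lack the 'load_in_4bit' key, and the old subscript access raised an exception in those cases. The config value below is hypothetical.

# Minimal sketch: the old subscript-style check raises on configs that
# lack the key, while the isinstance/.get combination from this commit
# degrades to False instead of raising.
quantization_config = {"load_in_8bit": True}  # hypothetical config: no 'load_in_4bit' key

# Old check would raise KeyError here:
# quantization_config['load_in_4bit']

# New check: evaluates to False instead of raising.
load_4bit = isinstance(quantization_config, dict) and quantization_config.get('load_in_4bit', False)
print(load_4bit)  # False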
@@ -609,22 +608,23 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
             state_dict = torch.load(model_file, map_location="cpu")
             model.is_quantized = True
-
+
             device_map = kwargs.pop("device_map", None)
             torch_dtype = kwargs.pop("torch_dtype", None)
 
-            kwargs = {"no_split_module_classes": model._no_split_modules}
-            target_dtype = CustomDtype.INT4
-            max_memory = get_balanced_memory(
-                model,
-                dtype=target_dtype,
-                low_zero=(device_map == "balanced_low_0"),
-                max_memory=None,
-                **kwargs,
-            )
-            kwargs["max_memory"] = max_memory
-
-            device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+            if device_map is not None:
+                kwargs = {"no_split_module_classes": model._no_split_modules}
+                target_dtype = CustomDtype.INT4
+                max_memory = get_balanced_memory(
+                    model,
+                    dtype=target_dtype,
+                    low_zero=(device_map == "balanced_low_0"),
+                    max_memory=None,
+                    **kwargs,
+                )
+                kwargs["max_memory"] = max_memory
+                device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+
             model = init_model_weight_int4(config, model, state_dict)
 
             # Set model in evaluation mode to deactivate DropOut modules by default
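
As a usage note, a hedged sketch of how the newly guarded path is reached; the checkpoint name below is an assumption, not taken from the commit. Passing an explicit device_map triggers the get_balanced_memory / infer_auto_device_map placement, while omitting it leaves device_map as None and now skips that work entirely.

# Illustrative sketch only; the checkpoint name is an assumption.
# Loading a 4-bit checkpoint with an explicit device_map runs the
# accelerate-based placement guarded by this commit; without device_map,
# that branch is skipped.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat-4bits",  # assumed 4-bit checkpoint
    trust_remote_code=True,                   # required for the custom modeling code
    device_map="auto",                        # drop this to skip device-map inference
)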