pkumc HandH1998 committed on
Commit
d476539
·
verified ·
1 Parent(s): e48a514

Update inference/bf16_cast_block_int8.py (#7)

Browse files

- Update inference/bf16_cast_block_int8.py (ca68670207d276a38c4d91eb8054a6d1653c2c23)


Co-authored-by: HandH1998 <[email protected]>

Files changed (1) hide show
  1. inference/bf16_cast_block_int8.py +30 -5
inference/bf16_cast_block_int8.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  from argparse import ArgumentParser
4
  from glob import glob
5
  from tqdm import tqdm
 
6
 
7
  import torch
8
  from safetensors.torch import load_file, save_file
@@ -14,15 +15,40 @@ def main(bf16_path, int8_path, model_name="deepseek-ai/DeepSeek-R1"):
14
  torch.set_default_dtype(torch.bfloat16)
15
  os.makedirs(int8_path, exist_ok=True)
16
  model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
17
-
18
- if not os.path.exists(model_index_file):
 
19
  snapshot_download(
20
  repo_id=model_name,
21
- allow_patterns=["model.safetensors.index.json"],
22
  local_dir=int8_path,
23
  local_dir_use_symlinks=False
24
  )
25
- print(f"model index file downloaded to {model_index_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  with open(model_index_file, "r") as f:
28
  model_index = json.load(f)
@@ -60,4 +86,3 @@ if __name__ == "__main__":
60
  args = parser.parse_args()
61
  main(args.input_bf16_hf_path, args.output_int8_hf_path, args.model_name)
62
  print("done")
63
-
 
3
  from argparse import ArgumentParser
4
  from glob import glob
5
  from tqdm import tqdm
6
+ import json
7
 
8
  import torch
9
  from safetensors.torch import load_file, save_file
 
15
  torch.set_default_dtype(torch.bfloat16)
16
  os.makedirs(int8_path, exist_ok=True)
17
  model_index_file = os.path.join(int8_path, "model.safetensors.index.json")
18
+ config_file = os.path.join(int8_path, "config.json")
19
+
20
+ if not os.path.exists(model_index_file) or not os.path.exists(config_file):
21
  snapshot_download(
22
  repo_id=model_name,
23
+ ignore_patterns=["*.safetensors"],
24
  local_dir=int8_path,
25
  local_dir_use_symlinks=False
26
  )
27
+ print(f"model index file and config file downloaded to {int8_path}")
28
+
29
+ # modify config.json and save it
30
+ config = json.load(open(config_file))
31
+ if "quantization_config" in config:
32
+ quant_config = config["quantization_config"]
33
+ quant_config.pop("fmt", None)
34
+ quant_config["quant_method"] = "blockwise_int8"
35
+ quant_config["weight_block_size"] = [
36
+ 128,
37
+ 128
38
+ ]
39
+ quant_config["activation_scheme"] = "dynamic"
40
+ else:
41
+ config["quantization_config"] = {
42
+ "activation_scheme": "dynamic",
43
+ "quant_method": "blockwise_int8",
44
+ "weight_block_size": [
45
+ 128,
46
+ 128
47
+ ]
48
+ }
49
+ with open(config_file, "w", encoding="utf-8") as f:
50
+ json.dump(config, f, indent=2, ensure_ascii=False, sort_keys=True)
51
+ print(f"config.json modified and saved to {config_file}")
52
 
53
  with open(model_index_file, "r") as f:
54
  model_index = json.load(f)
 
86
  args = parser.parse_args()
87
  main(args.input_bf16_hf_path, args.output_int8_hf_path, args.model_name)
88
  print("done")