jerryzh168 committed on
Commit
e225287
·
verified ·
1 Parent(s): c10a6f4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -99,7 +99,7 @@ from transformers import (
99
  from torchao.quantization.quant_api import (
100
  IntxWeightOnlyConfig,
101
  Int8DynamicActivationIntxWeightConfig,
102
- AOPerModuleConfig,
103
  quantize_,
104
  )
105
  from torchao.quantization.granularity import PerGroup, PerAxis
@@ -121,7 +121,7 @@ linear_config = Int8DynamicActivationIntxWeightConfig(
121
  weight_granularity=PerGroup(32),
122
  weight_scale_dtype=torch.bfloat16,
123
  )
124
- quant_config = AOPerModuleConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
125
  quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
126
 
127
  # either use `untied_model_id` or `untied_model_local_path`
@@ -130,7 +130,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
130
 
131
  # Push to hub
132
  MODEL_NAME = model_id.split("/")[-1]
133
- save_to = f"{USER_ID}/{MODEL_NAME}-untied-8da4w"
134
  quantized_model.push_to_hub(save_to, safe_serialization=False)
135
  tokenizer.push_to_hub(save_to)
136
 
 
99
  from torchao.quantization.quant_api import (
100
  IntxWeightOnlyConfig,
101
  Int8DynamicActivationIntxWeightConfig,
102
+ ModuleFqnToConfig,
103
  quantize_,
104
  )
105
  from torchao.quantization.granularity import PerGroup, PerAxis
 
121
  weight_granularity=PerGroup(32),
122
  weight_scale_dtype=torch.bfloat16,
123
  )
124
+ quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
125
  quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
126
 
127
  # either use `untied_model_id` or `untied_model_local_path`
 
130
 
131
  # Push to hub
132
  MODEL_NAME = model_id.split("/")[-1]
133
+ save_to = f"{USER_ID}/{MODEL_NAME}-8da4w"
134
  quantized_model.push_to_hub(save_to, safe_serialization=False)
135
  tokenizer.push_to_hub(save_to)
136