jerryzh168's picture
Update README.md
6a5dd37 verified
metadata
library_name: transformers
tags: []

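Model Card for jerryzh168/opt-125m-int4wo-per-module

facebook/opt-125m with torchao int4 weight-only quantization (group size 128, HQQ) applied per module to the `fc1`/`fc2` MLP projections of every decoder layer; the script below reproduces the checkpoint.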

"""
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)
"""


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Set to False to skip pushing the quantized checkpoint and tokenizer to the Hub;
# the manual generation test at the bottom runs either way.
SAVE = True

model_id = "facebook/opt-125m"

from torchao.quantization import ModuleFqnToConfig
# Alternative scheme (not used): float8 dynamic activation + float8 per-row weights.
# from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
# fp8_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
from torchao.quantization import Int4WeightOnlyConfig

# int4 weight-only quantization, one set of quantization parameters per group of
# 128 weights; use_hqq=True enables torchao's HQQ-based parameter selection.
int4_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)

# opt-125m has 12 decoder layers, indexed 0..11 (see "(0-11): 12 x OPTDecoderLayer"
# in the module printout). Only the MLP projections fc1/fc2 get a config; modules
# without an entry in the FQN -> config mapping are not quantized by
# ModuleFqnToConfig and stay in the bf16 load dtype.
NUM_LAYERS = 12
qconfig_dict = {
    f"model.decoder.layers.{idx}.{proj}": int4_config
    for idx in range(NUM_LAYERS)
    for proj in ("fc1", "fc2")
}

quant_config = ModuleFqnToConfig(qconfig_dict)

# Quantization is applied by transformers while loading the pretrained weights.
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print(quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Push to hub. Replace USER_ID with your own Hugging Face user/org name.
USER_ID = "jerryzh168"
save_to = f"{USER_ID}/opt-125m-int4wo-per-module"
if SAVE:
    # safe_serialization=False: torchao stores quantized weights as tensor
    # subclasses, which (presumably for the torchao version used here) cannot be
    # written as safetensors, so the checkpoint is saved in the pickle format.
    quantized_model.push_to_hub(save_to, safe_serialization=False)
    tokenizer.push_to_hub(save_to)

    # To save locally instead of pushing to the Hub:
    # quantized_model.save_pretrained(save_to, safe_serialization=False)
    # tokenizer.save_pretrained(save_to)

# Manual testing: generate a short completion with the quantized model.
prompt = "Hey, are you conscious? Can you talk to me?"
# Move inputs to wherever device_map="auto" placed the model rather than
# hard-coding "cuda".
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)