ssllm_hf / generate.py
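"""Generate text with the SSLLM model hosted on the Hugging Face Hub.

Downloads the config and safetensors weights from the sausheong/ssllm_hf repo,
tokenizes a prompt with tiktoken's cl100k_base encoding, and samples a
continuation with temperature, top-p, and top-k sampling.
"""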
import tiktoken
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from ssllm_hf import SSLLMForCausalLM, SSLLMConfig

# Initialize the model from its config
config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)

# Download the safetensors weights from the Hub and load them
# (strict=False tolerates keys that differ between checkpoint and model)
model_path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
state_dict = load_file(model_path)
model.load_state_dict(state_dict, strict=False)

# Set up the device and switch to eval mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

# Initialize the tokenizer (cl100k_base BPE encoding)
tokenizer = tiktoken.get_encoding('cl100k_base')


def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
    # Encode the prompt and build a matching attention mask
    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
    attention_mask = torch.ones_like(input_ids)

    # Sample a continuation from the model
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=100257,  # <|endoftext|> in cl100k_base
            eos_token_id=100257,
        )

    # Decode only the newly generated tokens, not the prompt
    new_tokens = outputs[0][input_ids.shape[1]:].tolist()
    generated = tokenizer.decode(new_tokens)
    print(f"{prompt}{generated}")
    print(f"\nTokens generated: {len(new_tokens)}")


if __name__ == "__main__":
    prompt = "In a small village nestled between mountains,"
    print(f"PROMPT: {prompt}\n--")
    generate_text(prompt)
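    # Illustrative variation (the parameter values here are assumptions, not
    # from the original script): lower temperature and top_p make sampling
    # more conservative, e.g.:
    # generate_text(prompt, max_new_tokens=64, temperature=0.3, top_p=0.8, top_k=20)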