|
from ssllm_hf import SSLLMForCausalLM, SSLLMConfig |
|
import tiktoken |
|
import torch |
|
from safetensors.torch import load_file |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# --- Model construction -----------------------------------------------------
# Build the architecture from the configuration published on the Hub; the
# weights are loaded separately below.
config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)

# --- Weights ----------------------------------------------------------------
# Fetch the safetensors checkpoint from the same repository and load it into
# the freshly constructed model.
model_path = hf_hub_download(
    repo_id='sausheong/ssllm_hf',
    filename='model.safetensors',
)
state_dict = load_file(model_path)
# strict=False: keys that are missing from, or unexpected in, the checkpoint
# do not raise here.
model.load_state_dict(state_dict, strict=False)

# --- Device placement -------------------------------------------------------
# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model to the chosen device and switch it to evaluation mode.
model = model.to(device)
model = model.eval()

# --- Tokenizer --------------------------------------------------------------
tokenizer = tiktoken.get_encoding('cl100k_base')
|
|
|
def generate_text(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 40,
) -> str:
    """Sample a continuation of ``prompt``, print it, and return it.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt: Text to continue. Must encode to at least one token.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded continuation only (the prompt is not included), with
        any trailing end-of-text token removed.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        # An empty id list would otherwise become a float (1, 0) tensor and
        # fail deep inside the model with an unrelated-looking error.
        raise ValueError("prompt must encode to at least one token")

    # End-of-text id of the active encoding (100257 for cl100k_base); it is
    # used both as the stop token and as the padding token.
    eos_id = tokenizer.eot_token

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )

    # Keep only what was generated after the prompt.
    new_tokens = outputs[0][input_ids.shape[1]:].tolist()

    # Drop the terminator so "<|endoftext|>" is neither printed nor counted.
    while new_tokens and new_tokens[-1] == eos_id:
        new_tokens.pop()

    generated = tokenizer.decode(new_tokens)

    print(f"{prompt}{generated}")
    print(f"\nTokens generated: {len(new_tokens)}")

    return generated
|
|
|
if __name__ == "__main__":
    # Demo run: echo the prompt with a separator, then print a sampled
    # continuation using the default sampling settings.
    prompt = "In a small village nestled between mountains,"
    print(f"PROMPT: {prompt}\n--")
    generate_text(prompt)
|
|