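"""Interactive, streaming text generation with a trained SmallGPT model.

Loads the tokenizer and checkpoint, then repeatedly prompts the user and
streams the model's continuation token by token, reporting tokens/sec.
"""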
import torch
import time

from architecture import SmallGPT
from tokenizers import Tokenizer


def load_tokenizer(path="smptokenizer/tokenizer.json"):
    tokenizer = Tokenizer.from_file(path)
    return tokenizer


def generate_text_streaming(model, tokenizer, start_text, device, max_length=64,
                            temperature=1.0, max_new_tokens=20, repetition_penalty=1.2):
    """
    Generates text token by token, yielding each new token as a decoded string.
    """
    model.eval()
    # Encode start text and look up the EOS id once, outside the loop
    input_ids = tokenizer.encode(start_text).ids
    eos_id = tokenizer.token_to_id("<eos>")
    # Print the starting text, then let the model continue from it
    print("Generated Sentence:")
    print(start_text, end="", flush=True)
    current_ids = list(input_ids)  # copy so the encoded list is not mutated in place
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Prepare input (truncate to the context window, leaving room for the new token)
            current_input = current_ids[-max_length + 1:] if len(current_ids) >= max_length else current_ids
            input_tensor = torch.tensor([current_input], dtype=torch.long, device=device)
            # Get output
            logits = model(input_tensor)
            # Get logits for the last position, scaled by temperature
            next_token_logits = logits[0, -1, :] / temperature
            # Apply repetition penalty, if needed
            if repetition_penalty > 1.0:
                for token_id in set(current_ids):
                    # Shrink positive logits and amplify negative ones; dividing
                    # a negative logit would make a repeated token MORE likely
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty
            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, 1).item()
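            # NOTE: this samples stochastically; torch.argmax(probs).item()
            # here would give deterministic greedy decoding instead.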
            # Check for EOS
            if next_token_id == eos_id:
                break
            current_ids.append(next_token_id)
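            # NOTE (assumption): decoding ids one at a time can drop the leading
            # whitespace some BPE tokenizers attach to tokens; if streamed words
            # run together, decode the accumulated ids and print only the new suffix.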
            # Decode and yield the new token
            new_token = tokenizer.decode([next_token_id])
            yield new_token


def main(seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Load tokenizer
    tokenizer_path = "smptokenizer/tokenizer.json"
    tokenizer = load_tokenizer(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    pad_id = tokenizer.token_to_id("<pad>")
    if pad_id is None:  # fall back to 0 if the tokenizer has no <pad> token
        pad_id = 0
    # Model parameters from training
    d_model = 256
    n_heads = 8
    n_layers = 6
    max_length = 172
    # Instantiate the model
    model = SmallGPT(
        vocab_size=vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        max_length=max_length,
        pad_idx=pad_id,
    ).to(device)
    # Load the trained model weights (expects a state_dict saved with
    # torch.save; .safetensors checkpoints are not handled here)
    model_path = "models/pytorch_model.bin"
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model loaded from {model_path}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        print("Please ensure the model is trained and the path is correct.")
        return
    while True:
        # Reset the seed so the same prompt reproduces the same output
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        start_text = input("Enter a starting word or phrase (or 'quit' to exit): ")
        if start_text.lower() == 'quit':
            break
        if not start_text.strip():
            print("No text entered; using the current timestamp as the starting prompt.")
            start_text = str(time.time())
        print("Generating...")
        token_count = 0
        start_time = time.time()
        for token in generate_text_streaming(
            model=model,
            tokenizer=tokenizer,
            start_text=start_text,
            device=device,
            max_new_tokens=1000,
            temperature=0.7,
            max_length=max_length,
            repetition_penalty=1.2,
        ):
            print(token, end="", flush=True)
            token_count += 1
        end_time = time.time()
        elapsed_time = end_time - start_time
        tokens_per_sec = token_count / elapsed_time if elapsed_time > 0 else 0
        print(f"\n\nPerformance: {tokens_per_sec:.2f} tokens/sec")
        print("-" * 30)


if __name__ == "__main__":
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    main(seed)