# Lizard-17m / generate.py
import time

import torch

from architecture import SmallGPT
from tokenizers import Tokenizer


def load_tokenizer(path="smptokenizer/tokenizer.json"):
    tokenizer = Tokenizer.from_file(path)
    return tokenizer
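
# Quick round-trip sanity check for the tokenizer (an illustrative usage sketch,
# not part of the original script):
#   tok = load_tokenizer()
#   ids = tok.encode("hello world").ids
#   print(ids, "->", tok.decode(ids))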


def generate_text_streaming(model, tokenizer, start_text, device, max_length=64,
                            temperature=1.0, max_new_tokens=20, repetition_penalty=1.2):
    """
    Generates text token by token, yielding each new token as it is sampled.
    """
    model.eval()

    # Encode the starting prompt
    input_ids = tokenizer.encode(start_text).ids
    generated_ids = []

    # Print the starting text; the model's continuation is streamed after it
    print("Generated Sentence:")
    print(start_text, end="", flush=True)

    current_ids = input_ids
    # Resolve the end-of-sequence id once, outside the sampling loop
    eos_id = tokenizer.token_to_id("<eos>")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Prepare the input: keep at most max_length - 1 trailing tokens so
            # the newly sampled token still fits within the context window
            current_input = current_ids[-max_length + 1:] if len(current_ids) >= max_length else current_ids
            input_tensor = torch.tensor([current_input], dtype=torch.long, device=device)

            # Forward pass; take the logits at the last position and apply temperature
            logits = model(input_tensor)
            next_token_logits = logits[0, -1, :] / temperature

            # Apply repetition penalty to tokens already seen: shrink positive
            # logits and grow negative ones, so the penalty always lowers the
            # token's probability (dividing a negative logit would raise it)
            if repetition_penalty > 1.0:
                for token_id in set(current_ids):
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty

            # Sample the next token from the temperature-scaled distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, 1).item()

            # Stop at the end-of-sequence token
            if next_token_id == eos_id:
                break

            generated_ids.append(next_token_id)
            current_ids.append(next_token_id)

            # Decode and yield the new token
            new_token = tokenizer.decode([next_token_id])
            yield new_token
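

# Convenience wrapper (an added sketch, not in the original script): collects
# the streamed tokens into a single string. Note that generate_text_streaming
# also prints the prompt as a side effect.
def generate_text(model, tokenizer, start_text, device, **kwargs):
    # kwargs are forwarded unchanged (max_length, temperature,
    # max_new_tokens, repetition_penalty)
    return "".join(
        generate_text_streaming(model, tokenizer, start_text, device, **kwargs)
    )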


def main(seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    tokenizer_path = "smptokenizer/tokenizer.json"
    tokenizer = load_tokenizer(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    # Fall back to id 0 if the tokenizer has no <pad> token
    pad_id = tokenizer.token_to_id("<pad>") or 0

    # Model hyperparameters (must match the values used in training)
    d_model = 256
    n_heads = 8
    n_layers = 6
    max_length = 172

    # Instantiate the model
    model = SmallGPT(
        vocab_size=vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        max_length=max_length,
        pad_idx=pad_id,
    ).to(device)
    # Load the trained model weights (note: loading from safetensors is untested)
    model_path = "models/pytorch_model.bin"
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model loaded from {model_path}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        print("Please ensure the model is trained and the path is correct.")
        return
    while True:
        # Reset the seed so identical prompts produce identical outputs
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        start_text = input("Enter a starting word or phrase (or 'quit' to exit): ")
        if start_text.lower() == 'quit':
            break
        if not start_text.strip():
            print("No text entered; using the current timestamp as the prompt instead.")
            start_text = str(time.time())

        print("Generating...")
        token_count = 0
        start_time = time.time()
        for token in generate_text_streaming(
            model=model,
            tokenizer=tokenizer,
            start_text=start_text,
            device=device,
            max_new_tokens=1000,
            temperature=0.7,
            max_length=max_length,
            repetition_penalty=1.2,
        ):
            print(token, end="", flush=True)
            token_count += 1

        end_time = time.time()
        elapsed_time = end_time - start_time
        tokens_per_sec = token_count / elapsed_time if elapsed_time > 0 else 0
        print(f"\n\nPerformance: {tokens_per_sec:.2f} tokens/sec")
        print("-" * 30)


if __name__ == "__main__":
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    main(seed)