import time

import torch
from tokenizers import Tokenizer

from architecture import SmallGPT


def load_tokenizer(path="smptokenizer/tokenizer.json"):
    """Load a pretrained tokenizer from its saved JSON file."""
    tokenizer = Tokenizer.from_file(path)
    return tokenizer
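
# Tokenizer round-trip sanity check (illustrative and commented out, so the
# script's behavior is unchanged):
#   tok = load_tokenizer()
#   ids = tok.encode("hello world").ids
#   print(tok.decode(ids))  # decoded text should closely match the input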


def generate_text_streaming(model, tokenizer, start_text, device, max_length=64,
                            temperature=1.0, max_new_tokens=20, repetition_penalty=1.2):
    """Generate text token by token, yielding each new token as it is sampled."""
    model.eval()

    # Encode the prompt and keep a mutable copy that the loop extends.
    input_ids = tokenizer.encode(start_text).ids
    current_ids = list(input_ids)

    print("Generated Sentence:")
    print(start_text, end="", flush=True)

    # Look up the end-of-sequence id once instead of on every iteration.
    eos_id = tokenizer.token_to_id("<eos>")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Truncate the context so it fits the model's positional window,
            # leaving one slot free for the token about to be generated.
            current_input = current_ids[-(max_length - 1):]
            input_tensor = torch.tensor([current_input], dtype=torch.long, device=device)

            logits = model(input_tensor)

            # Scale the final-position logits by the sampling temperature.
            next_token_logits = logits[0, -1, :] / temperature

            # Repetition penalty: make already-seen tokens less likely.
            # Positive logits are divided by the penalty and negative logits
            # multiplied; dividing a negative logit would *raise* its probability.
            if repetition_penalty > 1.0:
                for token_id in set(current_ids):
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty

            # Sample the next token from the adjusted distribution.
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, 1).item()

            # Stop as soon as the end-of-sequence token is sampled.
            if next_token_id == eos_id:
                break

            current_ids.append(next_token_id)

            new_token = tokenizer.decode([next_token_id])
            yield new_token
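

# A minimal top-k sampling sketch for comparison with the plain multinomial
# sampling used above. This helper is an illustrative addition, not part of
# the original pipeline, and nothing in this script calls it.
def sample_top_k(next_token_logits, k=50):
    """Sample a token id restricted to the k highest-scoring logits."""
    top_logits, top_indices = torch.topk(next_token_logits, k)
    probs = torch.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, 1).item()
    return top_indices[choice].item()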


def main(seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer and look up the pad token id (fall back to 0).
    tokenizer_path = "smptokenizer/tokenizer.json"
    tokenizer = load_tokenizer(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    pad_id = tokenizer.token_to_id("<pad>") or 0

    # Model hyperparameters; these must match the checkpoint being loaded.
    d_model = 256
    n_heads = 8
    n_layers = 6
    max_length = 172

    # Build the model and move it to the target device.
    model = SmallGPT(
        vocab_size=vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        max_length=max_length,
        pad_idx=pad_id,
    ).to(device)

    # Load the trained weights; exit early if the checkpoint is missing.
    model_path = "models/pytorch_model.bin"
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model loaded from {model_path}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        print("Please ensure the model is trained and the path is correct.")
        return

    while True:
        # Re-seed each round so the same prompt reproduces the same output.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        start_text = input("Enter a starting word or phrase (or 'quit' to exit): ")
        if start_text.lower() == "quit":
            break

        if not start_text.strip():
            print("No text entered; using the current timestamp as a starting prompt.")
            start_text = str(time.time())

        print("Generating...")

        token_count = 0
        start_time = time.time()

        # Stream tokens from the generator and print each one as it arrives.
        for token in generate_text_streaming(
            model=model,
            tokenizer=tokenizer,
            start_text=start_text,
            device=device,
            max_new_tokens=1000,
            temperature=0.7,
            max_length=max_length,
            repetition_penalty=1.2,
        ):
            print(token, end="", flush=True)
            token_count += 1

        elapsed_time = time.time() - start_time
        tokens_per_sec = token_count / elapsed_time if elapsed_time > 0 else 0

        print(f"\n\nPerformance: {tokens_per_sec:.2f} tokens/sec")
        print("-" * 30)


if __name__ == "__main__":
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    main(seed)
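
# Example session (the script filename is an assumption):
#   $ python generate.py
#   Using device: cuda
#   Enter a starting word or phrase (or 'quit' to exit): Once upon a time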