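"""Interactive text generation with a trained SmallGPT model.

Loads the tokenizer and trained weights, then streams sampled tokens to
stdout for each prompt the user enters, until the user types 'quit'.
"""
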
import time

import torch
from tokenizers import Tokenizer

from architecture import SmallGPT

def load_tokenizer(path="smptokenizer/tokenizer.json"):
    tokenizer = Tokenizer.from_file(path)
    return tokenizer

def generate_text_streaming(model, tokenizer, start_text, device, max_length=64, temperature=1.0, max_new_tokens=20, repetition_penalty=1.2):
    """
    Generate text token by token, yielding each newly sampled token.

    The prompt is echoed to stdout first. Sampling stops after
    `max_new_tokens` tokens or when an <eos> token is produced.
    `max_length` caps the context fed to the model, `temperature` scales
    the logits before sampling, and `repetition_penalty` discourages
    tokens that have already appeared.
    """
    model.eval()
    
    # Encode start text
    input_ids = tokenizer.encode(start_text).ids
    generated_ids = []

    # Echo the prompt; generated tokens are streamed after it
    print("Generated Sentence:")
    print(start_text, end="", flush=True)

    current_ids = list(input_ids)
    eos_id = tokenizer.token_to_id("<eos>")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep only the most recent tokens so the input fits the model's context window
            current_input = current_ids[-(max_length - 1):]
            input_tensor = torch.tensor([current_input], dtype=torch.long, device=device)
            
            # Get output
            logits = model(input_tensor)
            
            # Scale the logits at the last position by the sampling temperature
            next_token_logits = logits[0, -1, :] / temperature

            # Apply repetition penalty (CTRL-style): penalize already-seen tokens
            # on both sides of zero, so negative logits are not accidentally boosted
            if repetition_penalty > 1.0:
                for token_id in set(current_ids):
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= repetition_penalty
                    else:
                        next_token_logits[token_id] *= repetition_penalty
            
            # Sample next token
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, 1).item()
            
            # Stop at the end-of-sequence token, if the tokenizer defines one
            if eos_id is not None and next_token_id == eos_id:
                break
                
            generated_ids.append(next_token_id)
            current_ids.append(next_token_id)
            
            # Decode and yield the new token
            new_token = tokenizer.decode([next_token_id])
            yield new_token
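
# Example usage (a minimal sketch, assuming `model`, `tokenizer`, and `device`
# are set up as in main() below; the prompt text is illustrative):
#
#     for token in generate_text_streaming(model, tokenizer, "Once upon a time",
#                                          device, max_new_tokens=50):
#         print(token, end="", flush=True)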

def main(seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    tokenizer_path = "smptokenizer/tokenizer.json"
    tokenizer = load_tokenizer(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    pad_id = tokenizer.token_to_id("<pad>")
    if pad_id is None:
        pad_id = 0  # fall back to id 0 if the tokenizer defines no <pad> token

    # Model hyperparameters; these must match the configuration used during training
    d_model = 256
    n_heads = 8
    n_layers = 6
    max_length = 172

    # Instantiate the model
    model = SmallGPT(
        vocab_size=vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        n_layers=n_layers,
        max_length=max_length,
        pad_idx=pad_id,
    ).to(device)

    # Load the trained model weights (a PyTorch .bin state dict; a .safetensors
    # checkpoint would need to be loaded with the safetensors library instead)
    model_path = "models/pytorch_model.bin"
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Model loaded from {model_path}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        print("Please ensure the model is trained and the path is correct.")
        return

    while True:
        # Reset the seed so the same prompt reproduces the same output
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        start_text = input("Enter a starting word or phrase (or 'quit' to exit): ")
        if start_text.lower() == 'quit':
            break

        if not start_text.strip():
            print("No text entered; using the current timestamp as a fallback prompt.")
            start_text = str(time.time())

        print("Generating...")
        
        token_count = 0
        start_time = time.time()

        for token in generate_text_streaming(
            model=model,
            tokenizer=tokenizer,
            start_text=start_text,
            device=device,
            max_new_tokens=1000,
            temperature=0.7,
            max_length=max_length,
            repetition_penalty=1.2
        ):
            print(token, end="", flush=True)
            token_count += 1
        
        end_time = time.time()
        elapsed_time = end_time - start_time
        tokens_per_sec = token_count / elapsed_time if elapsed_time > 0 else 0

        print(f"\n\nPerformance: {tokens_per_sec:.2f} tokens/sec")
        print("-" * 30)

if __name__ == "__main__":
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    main(seed)