Martico2432 committed on
Commit b3b6708 · verified · 1 Parent(s): 1d644dd

Upload 3 files

Files changed (3)
  1. architecture.py +129 -0
  2. config.json +7 -0
  3. generate.py +138 -0
architecture.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import xformers.ops as xops
+
+ class SmallGPT(nn.Module):
+     def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, max_length=128, pad_idx=0):
+         super().__init__()
+         self.d_model = d_model
+         self.max_length = max_length
+
+         # Embeddings
+         self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
+         self.position_embedding = nn.Embedding(max_length, d_model)
+
+         # Transformer blocks
+         self.blocks = nn.ModuleList([
+             TransformerBlock(d_model, n_heads) for _ in range(n_layers)
+         ])
+
+         # Output
+         self.ln_f = nn.LayerNorm(d_model)
+         self.head = nn.Linear(d_model, vocab_size, bias=False)
+
+         # Init weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.03)
+
+     def forward(self, x):
+         batch_size, seq_len = x.size()
+
+         # Position indices
+         pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
+         pos = pos.unsqueeze(0).expand(batch_size, seq_len)
+
+         # Embeddings
+         tok_emb = self.token_embedding(x)
+         pos_emb = self.position_embedding(pos)
+         x = tok_emb + pos_emb
+
+         # Transformer blocks
+         for block in self.blocks:
+             x = block(x)
+
+         # Final layer norm and projection
+         x = self.ln_f(x)
+         logits = self.head(x)
+
+         return logits
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model, n_heads):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.attn = CausalSelfAttention(d_model, n_heads)
+         self.ln2 = nn.LayerNorm(d_model)
+         self.mlp = MLP(d_model)
+
+     def forward(self, x):
+         # Pre-norm residual connections
+         x = x + self.attn(self.ln1(x))
+         x = x + self.mlp(self.ln2(x))
+         return x
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, d_model, n_heads):
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.n_heads = n_heads
+         self.d_model = d_model
+         self.head_dim = d_model // n_heads
+
+         self.qkv = nn.Linear(d_model, 3 * d_model)
+         self.proj = nn.Linear(d_model, d_model)
+
+     def forward(self, x):
+         batch, seq_len, d_model = x.size()
+         qkv = self.qkv(x)  # [B, S, 3*D]
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # Reshape into heads
+         q = q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [B, H, S, Hd]
+         k = k.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+
+         # Flatten for xformers: [B*H, S, Hd]
+         q = q.reshape(batch * self.n_heads, seq_len, self.head_dim)
+         k = k.reshape(batch * self.n_heads, seq_len, self.head_dim)
+         v = v.reshape(batch * self.n_heads, seq_len, self.head_dim)
+
+         # Apply memory-efficient attention with a causal mask
+         out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
+         # out: [B*H, S, Hd]
+
+         # Reshape back to [B, S, D]
+         out = out.view(batch, self.n_heads, seq_len, self.head_dim)
+         out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
+
+         return self.proj(out)
+
+
+ class MLP(nn.Module):
+     def __init__(self, d_model):
+         super().__init__()
+         self.fc1 = nn.Linear(d_model, 4 * d_model)
+         self.fc2 = nn.Linear(4 * d_model, d_model)
+         self.silu = nn.SiLU()
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.silu(x)
+         x = self.fc2(x)
+         return x
+
+
+ DEFAULT_CONFIG = {
+     "vocab_size": 24_005,
+     "d_model": 256,
+     "n_heads": 8,
+     "n_layers": 6,
+     "max_length": 128,
+ }
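
As a quick sanity check (not part of the commit), here is a minimal forward-pass sketch using the DEFAULT_CONFIG exported above. Note that xformers' memory_efficient_attention generally requires a CUDA device:

import torch
from architecture import SmallGPT, DEFAULT_CONFIG

device = torch.device("cuda")  # xformers attention typically needs CUDA
model = SmallGPT(**DEFAULT_CONFIG).to(device)

# A batch of 2 sequences of 16 token ids drawn from the vocab
x = torch.randint(0, DEFAULT_CONFIG["vocab_size"], (2, 16), device=device)
logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 16, 24005])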
config.json ADDED
@@ -0,0 +1,7 @@
+ {
+     "vocab_size": 24005,
+     "d_model": 256,
+     "n_heads": 8,
+     "n_layers": 6,
+     "max_length": 128
+ }
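
For reference, a sketch (not part of the commit) of wiring config.json into the model constructor; it assumes the file sits next to architecture.py, and relies on the fact that its keys mirror SmallGPT's arguments one-to-one:

import json
from architecture import SmallGPT

with open("config.json") as f:
    config = json.load(f)

# config.json keys (vocab_size, d_model, n_heads, n_layers, max_length)
# map directly onto SmallGPT's constructor parameters
model = SmallGPT(**config)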
generate.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import sys
+ import time
+ from architecture import SmallGPT
+ from tokenizers import Tokenizer
+
+ def load_tokenizer(path="smptokenizer/tokenizer.json"):
+     tokenizer = Tokenizer.from_file(path)
+     return tokenizer
+
+ def generate_text_streaming(model, tokenizer, start_text, device, max_length=64, temperature=1.0, max_new_tokens=20, repetition_penalty=1.2):
+     """
+     Generates text token by token, yielding each new token.
+     """
+     model.eval()
+
+     # Encode start text
+     input_ids = tokenizer.encode(start_text).ids
+     generated_ids = []
+     eos_id = tokenizer.token_to_id("<eos>")  # may be None if the tokenizer has no <eos>
+
+     # Print the starting text, then let the model continue it
+     print("Generated Sentence:")
+     print(start_text, end="", flush=True)
+
+     current_ids = input_ids
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             # Prepare input (truncate to the model's context window if too long)
+             current_input = current_ids[-max_length+1:] if len(current_ids) >= max_length else current_ids
+             input_tensor = torch.tensor([current_input], dtype=torch.long, device=device)
+
+             # Get output
+             logits = model(input_tensor)
+
+             # Get logits for the last position
+             next_token_logits = logits[0, -1, :] / temperature
+
+             # Apply repetition penalty, if needed
+             if repetition_penalty > 1.0:
+                 for token_id in set(current_ids):
+                     next_token_logits[token_id] /= repetition_penalty
+
+             # Sample next token
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token_id = torch.multinomial(probs, 1).item()
+
+             # Check for EOS
+             if eos_id is not None and next_token_id == eos_id:
+                 break
+
+             generated_ids.append(next_token_id)
+             current_ids.append(next_token_id)
+
+             # Decode and yield the new token
+             new_token = tokenizer.decode([next_token_id])
+             yield new_token
+
+ def main(seed):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     # Load tokenizer
+     tokenizer_path = "smptokenizer/tokenizer.json"
+     tokenizer = load_tokenizer(tokenizer_path)
+     vocab_size = tokenizer.get_vocab_size()
+     pad_id = tokenizer.token_to_id("<pad>") or 0  # fall back to 0 if <pad> is missing
+
+     # Model parameters from training (must match config.json and the checkpoint,
+     # or load_state_dict will fail on the position embedding)
+     d_model = 256
+     n_heads = 8
+     n_layers = 6
+     max_length = 128
+
+     # Instantiate the model
+     model = SmallGPT(
+         vocab_size=vocab_size,
+         d_model=d_model,
+         n_heads=n_heads,
+         n_layers=n_layers,
+         max_length=max_length,
+         pad_idx=pad_id,
+     ).to(device)
+
+     # Load the trained model weights
+     model_path = "models/pytorch_model.bin"  # NOTE: loading from safetensors is untested
+     try:
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.eval()
+         print(f"Model loaded from {model_path}")
+     except FileNotFoundError:
+         print(f"Error: Model file not found at {model_path}")
+         print("Please ensure the model is trained and the path is correct.")
+         return
+
+     while True:
+         # Reset the seed so each prompt generates reproducibly
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+
+         start_text = input("Enter a starting word or phrase (or 'quit' to exit): ")
+         if start_text.lower() == 'quit':
+             break
+
+         if not start_text.strip():
+             print("No text entered; using a timestamp string as the starting prompt.")
+             start_text = str(time.time())
+
+         print("Generating...")
+
+         token_count = 0
+         start_time = time.time()
+
+         for token in generate_text_streaming(
+             model=model,
+             tokenizer=tokenizer,
+             start_text=start_text,
+             device=device,
+             max_new_tokens=1000,
+             temperature=0.7,
+             max_length=max_length,
+             repetition_penalty=1.2
+         ):
+             print(token, end="", flush=True)
+             token_count += 1
+
+         end_time = time.time()
+         elapsed_time = end_time - start_time
+         tokens_per_sec = token_count / elapsed_time if elapsed_time > 0 else 0
+
+         print(f"\n\nPerformance: {tokens_per_sec:.2f} tokens/sec")
+         print("-" * 30)
+
+ if __name__ == "__main__":
+     seed = 42
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     main(seed)
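
For completeness, a hedged sketch of driving the generator programmatically instead of through the interactive loop; model, tokenizer, and device are assumed to be set up as in main():

# Collect the streamed continuation into a single string
# (generate_text_streaming also echoes the prompt to stdout).
pieces = []
for tok in generate_text_streaming(
    model=model,
    tokenizer=tokenizer,
    start_text="Once upon a time",
    device=device,
    max_length=128,
    temperature=0.7,
    max_new_tokens=50,
    repetition_penalty=1.2,
):
    pieces.append(tok)

continuation = "".join(pieces)
print(continuation)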