Vishwas1 committed
Commit 0a660ad · verified · 1 Parent(s): 00d1598

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +1 -0
  2. best_model.pt +3 -0
  3. model_cfg.json +11 -0
  4. ssw_model.py +193 -0
  5. tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1 @@
+ # Structured State Weaving (SSW) – toy LM
best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e26d92b032202260a30bb88b5c9adc9e6c6ae4999d340332c6845ae830737905
+ size 174862196
model_cfg.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "vocab_size": 32000,
+   "d_model": 384,
+   "n_layers": 6,
+   "ffn_mult": 4,
+   "dropout": 0.1,
+   "ltc_kernel": 7,
+   "gsp_state": 128,
+   "cbs_topk": 4,
+   "max_seq_len": 512
+ }
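
These config keys mirror the `SSWLM` constructor arguments in `ssw_model.py` (next file); the one exception is `ffn_mult`, which is stored here but not accepted by `SSWLM.__init__` in this commit (the `WeaverBlock` hardcodes `mult=4`). A minimal sketch, assuming both files sit in the working directory, of turning the config into a model:

```python
# Minimal sketch (not part of this commit): build SSWLM from model_cfg.json.
import json
from ssw_model import SSWLM  # module added below in this same commit

with open("model_cfg.json") as f:
    cfg = json.load(f)

cfg.pop("ffn_mult", None)  # present in the config but not an SSWLM.__init__ argument here
model = SSWLM(**cfg)       # vocab_size=32000, d_model=384, n_layers=6, ...
```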
ssw_model.py ADDED
@@ -0,0 +1,193 @@
+ import math
+ from typing import Optional
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # ---------------------
+ # Utility Layers
+ # ---------------------
+ class RMSNorm(nn.Module):
+     def __init__(self, d: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d))
+     def forward(self, x):
+         norm = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(norm + self.eps)
+         return self.weight * x
+
+ class FeedForward(nn.Module):
+     def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0):
+         super().__init__()
+         inner = d_model * mult
+         self.net = nn.Sequential(
+             nn.Linear(d_model, inner * 2),  # doubled width for the GLU gate below
+             nn.GLU(dim=-1),
+             nn.Linear(inner, d_model),
+             nn.Dropout(dropout),
+         )
+     def forward(self, x):
+         return self.net(x)
+
+ # ---------------------
+ # SSW Components
+ # ---------------------
+ class LocalTextureConv(nn.Module):
+     """Depthwise 1D conv + GLU gate. Causal padding. O(n * d * k) with small k."""
+     def __init__(self, d_model: int, kernel_size: int = 7):
+         super().__init__()
+         assert kernel_size % 2 == 1, "kernel_size should be odd for simple causal pad"
+         self.dw = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model, padding=kernel_size - 1)
+         self.pw = nn.Conv1d(d_model, 2 * d_model, 1)
+     def forward(self, x):
+         # x: (B, T, C)
+         x_c = x.transpose(1, 2)  # (B, C, T)
+         y = self.dw(x_c)
+         T = x.size(1)
+         y = y[..., :T]  # causal crop
+         y = self.pw(y).transpose(1, 2)  # (B, T, 2C)
+         y = F.glu(y, dim=-1)  # (B, T, C)
+         return y
+
+ class GlobalStatePropagation(nn.Module):
+     """Simplified selective SSM-like recurrence (toy, readable)."""
+     def __init__(self, d_model: int, state_size: int = 128):
+         super().__init__()
+         self.state_size = state_size
+         self.inp = nn.Linear(d_model, state_size * 3)
+         self.out = nn.Linear(state_size, d_model)
+     def forward(self, x):
+         B, T, _ = x.size()
+         u, f, r = self.inp(x).chunk(3, dim=-1)
+         f = torch.sigmoid(f)
+         r = torch.sigmoid(r)
+         u = torch.tanh(u)
+         h = torch.zeros(B, self.state_size, device=x.device, dtype=x.dtype)
+         outs = []
+         for t in range(T):
+             h = f[:, t] * h + (1 - f[:, t]) * u[:, t]
+             outs.append(r[:, t] * h)
+         y = torch.stack(outs, dim=1)  # (B, T, S)
+         return self.out(y)  # (B, T, C)
+
+ class ContentBasedSummarizer(nn.Module):
+     """Top-k sparse attention over history (causal)."""
+     def __init__(self, d_model: int, top_k: int = 8):
+         super().__init__()
+         self.k = top_k
+         self.q = nn.Linear(d_model, d_model, bias=False)
+         self.kv = nn.Linear(d_model, 2 * d_model, bias=False)
+         self.scale = 1.0 / math.sqrt(d_model)
+         self.scorer = nn.Linear(d_model, 1, bias=False)
+     def forward(self, x):
+         B, T, C = x.size()
+         q = self.q(x)
+         k, v = self.kv(x).chunk(2, dim=-1)
+         imp = self.scorer(x).squeeze(-1)  # (B, T)
+         out = torch.zeros_like(x)
+         for t in range(T):
+             topk = min(self.k, t + 1)
+             _, idx = torch.topk(imp[:, :t+1], k=topk, dim=-1)
+             k_sel = torch.gather(k[:, :t+1, :], 1, idx.unsqueeze(-1).expand(-1, -1, C))
+             v_sel = torch.gather(v[:, :t+1, :], 1, idx.unsqueeze(-1).expand(-1, -1, C))
+             q_t = q[:, t:t+1, :]
+             att = torch.matmul(q_t, k_sel.transpose(1, 2)) * self.scale
+             att = F.softmax(att, dim=-1)
+             out[:, t:t+1, :] = torch.matmul(att, v_sel)
+         return out
+
+ class WeaverBlock(nn.Module):
+     def __init__(self, d_model: int, ltc_kernel: int, gsp_state: int, cbs_topk: int, dropout: float):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.ltc = LocalTextureConv(d_model, kernel_size=ltc_kernel)
+         self.gsp = GlobalStatePropagation(d_model, state_size=gsp_state)
+         self.cbs = ContentBasedSummarizer(d_model, top_k=cbs_topk)
+         self.mix = nn.Linear(d_model * 3, d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.norm2 = RMSNorm(d_model)
+         self.ff = FeedForward(d_model, mult=4, dropout=dropout)
+     def forward(self, x):
+         h = self.norm1(x)
+         a = self.ltc(h)
+         b = self.gsp(h)
+         c = self.cbs(h)
+         h = self.mix(torch.cat([a, b, c], dim=-1))
+         x = x + self.dropout(h)
+         x = x + self.ff(self.norm2(x))
+         return x
+
+ class SSWLM(nn.Module):
+     def __init__(self, vocab_size: int, d_model: int = 512, n_layers: int = 8,
+                  ltc_kernel: int = 7, gsp_state: int = 128, cbs_topk: int = 8,
+                  dropout: float = 0.1, max_seq_len: int = 1024):
+         super().__init__()
+         self.tok_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_emb = nn.Embedding(max_seq_len, d_model)
+         self.layers = nn.ModuleList([
+             WeaverBlock(d_model, ltc_kernel, gsp_state, cbs_topk, dropout)
+             for _ in range(n_layers)
+         ])
+         self.norm = RMSNorm(d_model)
+         self.head = nn.Linear(d_model, vocab_size, bias=False)
+         self.max_seq_len = max_seq_len
+
+     def forward(self, input_ids: torch.Tensor):
+         B, T = input_ids.size()
+         assert T <= self.max_seq_len, "sequence too long"
+         pos = torch.arange(T, device=input_ids.device)
+         x = self.tok_emb(input_ids) + self.pos_emb(pos)[None, :, :]
+         for blk in self.layers:
+             x = blk(x)
+         x = self.norm(x)
+         return self.head(x)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 100,
+         temperature: float = 1.0,
+         top_p: float = 0.9,
+         top_k: int = 50,
+         repetition_penalty: float = 1.1,
+         eos_token_id: Optional[int] = None,
+     ):
+         self.eval()
+         for _ in range(max_new_tokens):
+             inp = input_ids[:, -self.max_seq_len:]
+             logits = self.forward(inp)[:, -1, :] / max(1e-6, temperature)
+
+             # repetition penalty: downweight seen tokens (divide positive logits, multiply negative ones)
+             if repetition_penalty and repetition_penalty > 1.0:
+                 for b in range(input_ids.size(0)):
+                     seen = torch.bincount(input_ids[b], minlength=logits.size(-1)).bool()
+                     logits[b, seen] = torch.where(logits[b, seen] > 0, logits[b, seen] / repetition_penalty, logits[b, seen] * repetition_penalty)
+
+             # top-k filter
+             if top_k and top_k > 0:
+                 k = min(top_k, logits.size(-1))
+                 topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
+                 mask = torch.full_like(logits, float("-inf"))
+                 logits = mask.scatter(1, topk_idx, topk_vals)
+
+             # nucleus (top-p) filter
+             if top_p < 1.0:
+                 sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+                 probs = torch.softmax(sorted_logits, dim=-1)
+                 cumsum = torch.cumsum(probs, dim=-1)
+                 cutoff = cumsum > top_p
+                 cutoff[..., 0] = False  # keep at least one
+                 sorted_logits[cutoff] = float("-inf")
+                 # unsort back
+                 inv_idx = torch.argsort(sorted_idx, dim=-1)
+                 logits = torch.gather(sorted_logits, 1, inv_idx)
+
+             probs = torch.softmax(logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             input_ids = torch.cat([input_ids, next_token], dim=1)
+
+             if eos_token_id is not None and (next_token == eos_token_id).all():
+                 break
+         return input_ids
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
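
For completeness, a hedged end-to-end sketch (not part of this commit). It assumes `best_model.pt` stores a plain `state_dict` for `SSWLM` and that `tokenizer.json` is a Hugging Face `tokenizers` file; if the checkpoint wraps the weights (e.g. under a `"model"` key), unwrap it first.

```python
# Hedged usage sketch (assumptions: best_model.pt holds a raw state_dict for SSWLM,
# and tokenizer.json loads with the `tokenizers` library).
import json
import torch
from tokenizers import Tokenizer
from ssw_model import SSWLM

cfg = json.load(open("model_cfg.json"))
cfg.pop("ffn_mult", None)                      # not a constructor argument in this version
model = SSWLM(**cfg)

state = torch.load("best_model.pt", map_location="cpu")
model.load_state_dict(state)                   # assumes a raw state_dict checkpoint
model.eval()

tok = Tokenizer.from_file("tokenizer.json")
ids = torch.tensor([tok.encode("Once upon a time").ids])
out = model.generate(ids, max_new_tokens=50, temperature=0.8, top_p=0.9, top_k=50)
print(tok.decode(out[0].tolist()))
```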