eaglelandsonce committed
Commit f9a9b57 · verified · 1 Parent(s): e62de68

Create app.py

Files changed (1)
  1. app.py +377 -0
app.py ADDED
@@ -0,0 +1,377 @@
# app.py
# Character-level text generator with PyTorch (RNN/LSTM/GRU) + Gradio UI
import time
import random

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import gradio as gr

# -----------------------
# Utilities
# -----------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

def pick_device(force_cpu: bool = False):
    if not force_cpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def text_from_file(f):
    if f is None:
        return ""
    # Newer Gradio versions pass uploads as a filepath string rather than a
    # file object, so treat a plain string as the path itself.
    name = f if isinstance(f, str) else getattr(f, "name", None)
    try:
        # Try to read as bytes, then decode safely
        raw = f.read()
        if isinstance(raw, bytes):
            try:
                return raw.decode("utf-8")
            except UnicodeDecodeError:
                return raw.decode("latin-1", errors="ignore")
        return str(raw)
    except Exception:
        # Fallback: open by path (also covers the string-path case above)
        if name:
            try:
                with open(name, "rb") as fh:
                    raw = fh.read()
                try:
                    return raw.decode("utf-8")
                except UnicodeDecodeError:
                    return raw.decode("latin-1", errors="ignore")
            except Exception:
                return ""
        return ""

def build_vocab(text: str):
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos, len(chars)

def encode(text: str, stoi: dict):
    return torch.tensor([stoi[c] for c in text], dtype=torch.long)

def decode(indices, itos: dict):
    return "".join(itos[int(i)] for i in indices)

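# e.g. build_vocab("abca") yields stoi == {'a': 0, 'b': 1, 'c': 2}, and the
# round trip decode(encode("cab", stoi), itos) == "cab".
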
def make_batches(data_ids: torch.Tensor, seq_len: int, batch_size: int):
    # data_ids: shape [T]
    total_len = data_ids.size(0)
    num_sequences = (total_len - 1) // seq_len
    if num_sequences <= 0:
        return []

    # Trim so we have full sequences
    trimmed = num_sequences * seq_len
    x = data_ids[:trimmed]
    y = data_ids[1:trimmed + 1]

    x = x.view(num_sequences, seq_len)
    y = y.view(num_sequences, seq_len)

    # Shuffle sequence order
    perm = torch.randperm(num_sequences)
    x = x[perm]
    y = y[perm]

    # Batch into mini-batches
    batches = []
    for i in range(0, num_sequences, batch_size):
        xb = x[i:i + batch_size]
        yb = y[i:i + batch_size]
        if xb.size(0) == batch_size:  # keep full batches only for simplicity
            batches.append((xb, yb))
    return batches

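# e.g. 1001 ids with seq_len=100 and batch_size=4 give 10 sequences and two
# full batches of shape [4, 100] (the trailing partial batch of 2 is dropped);
# y is x shifted one step ahead in the text, i.e. the next-char targets.
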
# -----------------------
# Model
# -----------------------
class CharRNN(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int, num_layers: int, cell_type: str = "LSTM", dropout: float = 0.0):
        super().__init__()
        assert cell_type in {"RNN", "LSTM", "GRU"}
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.RNN / nn.GRU / nn.LSTM share the same constructor signature
        rnn_cls = {"RNN": nn.RNN, "GRU": nn.GRU, "LSTM": nn.LSTM}[cell_type]
        self.rnn = rnn_cls(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: [B, T]
        x = self.embed(x)                  # [B, T, E]
        out, hidden = self.rnn(x, hidden)  # out: [B, T, H]
        logits = self.fc(out)              # [B, T, V]
        return logits, hidden

    def init_hidden(self, batch_size: int, device):
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        if self.cell_type == "LSTM":
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
            return (h0, c0)
        return h0

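# Note: nn.LSTM carries a (h, c) state pair while nn.RNN / nn.GRU carry a
# single tensor; each state tensor is shaped [num_layers, B, hidden_size].
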
# -----------------------
# Training & Generation
# -----------------------
def temperature_sample(logits: torch.Tensor, temperature: float):
    # logits: [V]
    if temperature <= 0:
        # greedy argmax decoding
        return int(torch.argmax(logits).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

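# Temperatures below 1 sharpen the softmax toward greedy decoding; temperatures
# above 1 flatten it toward uniform sampling, trading coherence for diversity.
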
def generate_text(model: CharRNN, stoi, itos, prime: str, length: int, temperature: float, device):
    if len(prime) == 0:
        # start from a random char if prime is empty
        prime = random.choice(list(stoi.keys()))
    model.eval()
    # Feed the prime to warm up the hidden state; unknown chars map to id 0
    input_ids = torch.tensor([[stoi.get(ch, 0) for ch in prime]], dtype=torch.long, device=device)
    hidden = model.init_hidden(batch_size=1, device=device)

    generated = list(prime)
    with torch.no_grad():
        logits, hidden = model(input_ids, hidden)
        for _ in range(length):
            # sample from the last timestep's logits
            last_logits = logits[0, -1, :]  # [V]
            next_id = temperature_sample(last_logits, temperature)
            generated.append(itos[next_id])

            # feed the sampled char back in for the next step
            inp = torch.tensor([[next_id]], dtype=torch.long, device=device)
            logits, hidden = model(inp, hidden)

    return "".join(generated)

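# e.g. generate_text(model, stoi, itos, prime="The ", length=100,
#                    temperature=0.8, device=device)
# returns the prime followed by 100 sampled characters.
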
def train_one_run(
    raw_text: str,
    model_type: str = "LSTM",
    embed_dim: int = 64,
    hidden_size: int = 128,
    num_layers: int = 2,
    seq_len: int = 64,
    batch_size: int = 16,
    lr: float = 0.003,
    epochs: int = 3,
    dropout: float = 0.0,
    temperature: float = 0.8,
    generate_n_chars: int = 400,
    prime_text: str = "The ",
    force_cpu: bool = False,
):
    t0 = time.time()

    # Sanity-check the text
    text = raw_text.strip()
    if len(text) < max(100, seq_len + 1):
        # Pad with a small default if there is not enough text
        default = (
            "To be, or not to be, that is the question:\n"
            "Whether 'tis nobler in the mind to suffer\n"
            "The slings and arrows of outrageous fortune,\n"
            "Or to take arms against a sea of troubles\n"
            "And by opposing end them."
        )
        text = (text + "\n" + default).strip()

    stoi, itos, vocab_size = build_vocab(text)
    data_ids = encode(text, stoi)

    device = pick_device(force_cpu=force_cpu)

    model = CharRNN(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        cell_type=model_type,
        dropout=dropout,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = []
    model.train()
    for ep in range(1, epochs + 1):
        # Rebuild batches each epoch so the sequence order is reshuffled
        batches = make_batches(data_ids, seq_len=seq_len, batch_size=batch_size)
        if len(batches) == 0:
            raise ValueError("Not enough data to make at least one full batch. Increase text length or reduce seq_len/batch_size.")
        ep_loss = 0.0
        count = 0

        for xb, yb in batches:
            xb = xb.to(device)  # [B, T]
            yb = yb.to(device)  # [B, T]
            hidden = model.init_hidden(batch_size=xb.size(0), device=device)

            optimizer.zero_grad()
            logits, _ = model(xb, hidden)  # [B, T, V]
            # Flatten batch and time so CE scores every position as one target
            B, T, V = logits.shape
            loss = criterion(logits.view(B * T, V), yb.view(B * T))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            ep_loss += loss.item()
            count += 1

        mean_loss = ep_loss / max(count, 1)
        losses.append(mean_loss)

    train_time = time.time() - t0

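    # For reference, a mean cross-entropy loss of L corresponds to a
    # character-level perplexity of exp(L).
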
    # Generate a sample from the trained model
    gen = generate_text(
        model=model,
        stoi=stoi,
        itos=itos,
        prime=prime_text,
        length=generate_n_chars,
        temperature=temperature,
        device=device,
    )

    # Plot loss curve
    fig = plt.figure(figsize=(5, 3.2), dpi=140)
    xs = np.arange(1, len(losses) + 1)
    plt.plot(xs, losses, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Training Loss")
    plt.title("Loss per Epoch")
    plt.tight_layout()

    # Hyperparameters summary (to show clearly in UI)
    hparams = {
        "model_type": model_type,
        "embed_dim": embed_dim,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "seq_len": seq_len,
        "batch_size": batch_size,
        "learning_rate": lr,
        "epochs": epochs,
        "dropout": dropout,
        "temperature": temperature,
        "generate_n_chars": generate_n_chars,
        "vocab_size": vocab_size,
        "device": str(device),
        "train_time_sec": round(train_time, 3),
        "num_batches_per_epoch": len(batches),  # same count every epoch
        "text_len": len(text),
    }

    return gen, fig, hparams

# -----------------------
# Gradio Interface
# -----------------------
EXAMPLE_TEXT = (
    "In the beginning God created the heaven and the earth. "
    "And the earth was without form, and void; and darkness was upon the face of the deep. "
    "And the Spirit of God moved upon the face of the waters. "
    "And God said, Let there be light: and there was light."
)

with gr.Blocks(title="PyTorch RNN/LSTM/GRU Text Generator") as demo:
    gr.Markdown(
        """# 🔠 PyTorch Character RNN / LSTM / GRU
Train a tiny character-level model on your text and generate new text.
Use the controls below to **show & tweak hyperparameters**, provide **input text**, and view **outputs** (generated text + loss curve).
"""
    )
    with gr.Row():
        with gr.Column():
            corpus = gr.Textbox(
                label="Training Text (paste here)",
                value=EXAMPLE_TEXT,
                lines=10,
                placeholder="Paste training text here… (longer is better)",
            )
            upload = gr.File(label="Or upload a .txt file (optional)", file_count="single")
            prime_text = gr.Textbox(label="Prime text (seed for generation)", value="The ", lines=1)

            with gr.Row():
                model_type = gr.Radio(choices=["LSTM", "GRU", "RNN"], value="LSTM", label="Model type")
                force_cpu = gr.Checkbox(value=False, label="Force CPU (uncheck to use GPU if available)")

            with gr.Accordion("Hyperparameters", open=True):
                with gr.Row():
                    embed_dim = gr.Slider(16, 256, value=64, step=1, label="Embedding Dim")
                    hidden_size = gr.Slider(32, 512, value=128, step=1, label="Hidden Size")
                with gr.Row():
                    num_layers = gr.Slider(1, 4, value=2, step=1, label="Layers")
                    dropout = gr.Slider(0.0, 0.6, value=0.0, step=0.05, label="Dropout")
                with gr.Row():
                    seq_len = gr.Slider(16, 256, value=64, step=1, label="Sequence Length")
                    batch_size = gr.Slider(4, 128, value=16, step=1, label="Batch Size")
                with gr.Row():
                    lr = gr.Number(value=0.003, precision=6, label="Learning Rate")
                    epochs = gr.Slider(1, 15, value=3, step=1, label="Epochs")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature (sampling)")
                    generate_n_chars = gr.Slider(50, 1000, value=400, step=10, label="Generate N Chars")

            run_btn = gr.Button("🚀 Train & Generate", variant="primary")

        with gr.Column():
            gen_text = gr.Textbox(label="Generated Text (output)", lines=20)
            loss_plot = gr.Plot(label="Training Loss")
            hparams_json = gr.JSON(label="Hyperparameters (for your records)")

    def run_pipeline(
        corpus_text, uploaded_file,
        model_type, embed_dim, hidden_size, num_layers,
        seq_len, batch_size, lr, epochs, dropout,
        temperature, generate_n_chars, prime_text, force_cpu,
    ):
        # Prepend uploaded file contents (if any) to the pasted corpus
        merged = (text_from_file(uploaded_file) + "\n" + (corpus_text or "")).strip() if uploaded_file else (corpus_text or "")
        # Sliders deliver floats, so cast the integer hyperparameters back
        out_text, fig, hp = train_one_run(
            raw_text=merged,
            model_type=model_type,
            embed_dim=int(embed_dim),
            hidden_size=int(hidden_size),
            num_layers=int(num_layers),
            seq_len=int(seq_len),
            batch_size=int(batch_size),
            lr=float(lr),
            epochs=int(epochs),
            dropout=float(dropout),
            temperature=float(temperature),
            generate_n_chars=int(generate_n_chars),
            prime_text=prime_text,
            force_cpu=bool(force_cpu),
        )
        return out_text, fig, hp

    # Gradio passes component values positionally, in the order listed here,
    # so this list must match run_pipeline's parameter order exactly.
    run_btn.click(
        run_pipeline,
        inputs=[
            corpus, upload,
            model_type, embed_dim, hidden_size, num_layers,
            seq_len, batch_size, lr, epochs, dropout,
            temperature, generate_n_chars, prime_text, force_cpu,
        ],
        outputs=[gen_text, loss_plot, hparams_json],
        queue=False,
    )

if __name__ == "__main__":
    demo.launch()
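
For a quick smoke test outside the UI, the training entry point can also be driven directly (a minimal sketch; corpus.txt is a hypothetical local file and the overrides shown are illustrative, not defaults the app requires):

from app import train_one_run

# Train a small GRU for two epochs on a local corpus, then inspect the outputs.
sample, loss_fig, hparams = train_one_run(
    raw_text=open("corpus.txt", encoding="utf-8").read(),  # hypothetical corpus
    model_type="GRU",
    epochs=2,
)
print(hparams["vocab_size"], hparams["train_time_sec"])
print(sample[:200])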