Create app.py
app.py ADDED
# app.py
# Character-level text generator with PyTorch (RNN/LSTM/GRU) + Gradio UI
import time
import random

import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; avoids display errors on headless servers
import matplotlib.pyplot as plt
import gradio as gr

# -----------------------
# Utilities
# -----------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

def pick_device(force_cpu: bool = False):
    if not force_cpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def _decode_bytes(raw: bytes) -> str:
    """Decode bytes as UTF-8, falling back to Latin-1."""
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        return raw.decode("latin-1", errors="ignore")

def text_from_file(f):
    """Read an uploaded file safely. Recent Gradio versions pass uploads as a
    filepath string; older ones pass a file-like object. Handle both."""
    if f is None:
        return ""
    if isinstance(f, str):  # filepath (the gr.File default in newer Gradio)
        try:
            with open(f, "rb") as fh:
                return _decode_bytes(fh.read())
        except Exception:
            return ""
    try:
        # File-like object: read as bytes, then decode safely
        raw = f.read()
        return _decode_bytes(raw) if isinstance(raw, bytes) else str(raw)
    except Exception:
        # Fallback: file-like objects usually expose the temp path as .name
        name = getattr(f, "name", None)
        if name:
            try:
                with open(name, "rb") as fh:
                    return _decode_bytes(fh.read())
            except Exception:
                return ""
        return ""

def build_vocab(text: str):
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos, len(chars)

def encode(text: str, stoi: dict):
    return torch.tensor([stoi[c] for c in text], dtype=torch.long)

def decode(indices, itos: dict):
    return "".join(itos[int(i)] for i in indices)

def make_batches(data_ids: torch.Tensor, seq_len: int, batch_size: int):
    # data_ids: shape [T]; targets are the inputs shifted by one character
    total_len = data_ids.size(0)
    num_sequences = (total_len - 1) // seq_len
    if num_sequences <= 0:
        return []

    # Trim so we have full sequences
    trimmed = num_sequences * seq_len
    x = data_ids[:trimmed]
    y = data_ids[1:trimmed + 1]

    x = x.view(num_sequences, seq_len)
    y = y.view(num_sequences, seq_len)

    # Shuffle sequence order
    perm = torch.randperm(num_sequences)
    x = x[perm]
    y = y[perm]

    # Batch into mini-batches
    batches = []
    for i in range(0, num_sequences, batch_size):
        xb = x[i:i + batch_size]
        yb = y[i:i + batch_size]
        if xb.size(0) == batch_size:  # keep full batches only for simplicity
            batches.append((xb, yb))
    return batches

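# Worked example for make_batches (illustrative values):
#   ids = [10, 11, 12, 13, 14, 15, 16], seq_len = 3
#   num_sequences = (7 - 1) // 3 = 2, trimmed = 6
#   x = [[10, 11, 12], [13, 14, 15]]
#   y = [[11, 12, 13], [14, 15, 16]]   # y[i][t] is the character after x[i][t]
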
# -----------------------
# Model
# -----------------------
class CharRNN(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int,
                 num_layers: int, cell_type: str = "LSTM", dropout: float = 0.0):
        super().__init__()
        assert cell_type in {"RNN", "LSTM", "GRU"}
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.RNN/GRU/LSTM only apply dropout between stacked layers
        rnn_cls = {"RNN": nn.RNN, "GRU": nn.GRU, "LSTM": nn.LSTM}[cell_type]
        self.rnn = rnn_cls(embed_dim, hidden_size, num_layers=num_layers,
                           batch_first=True, dropout=dropout if num_layers > 1 else 0.0)

        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: [B, T]
        x = self.embed(x)                   # [B, T, E]
        out, hidden = self.rnn(x, hidden)   # out: [B, T, H]
        logits = self.fc(out)               # [B, T, V]
        return logits, hidden

    def init_hidden(self, batch_size: int, device):
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        if self.cell_type == "LSTM":
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
            return (h0, c0)
        return h0

# -----------------------
# Training & Generation
# -----------------------
def temperature_sample(logits: torch.Tensor, temperature: float):
    # logits: [V]; temperature 0 means greedy (argmax) decoding
    if temperature <= 0:
        return int(torch.argmax(logits).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

def generate_text(model: CharRNN, stoi, itos, prime: str, length: int, temperature: float, device):
    if len(prime) == 0:
        # start from a random char if prime is empty
        prime = random.choice(list(stoi.keys()))
    model.eval()
    # Feed the prime to warm up the hidden state; unknown chars map to index 0
    input_ids = torch.tensor([[stoi.get(ch, 0) for ch in prime]], dtype=torch.long, device=device)
    hidden = model.init_hidden(batch_size=1, device=device)
    with torch.no_grad():
        logits, hidden = model(input_ids, hidden)

    generated = list(prime)
    for _ in range(length):
        with torch.no_grad():
            # sample from the last timestep's logits
            last_logits = logits[0, -1, :]  # [V]
            next_id = temperature_sample(last_logits, temperature)
            generated.append(itos[next_id])

            # feed the sampled char back in for the next step
            inp = torch.tensor([[next_id]], dtype=torch.long, device=device)
            logits, hidden = model(inp, hidden)

    return "".join(generated)

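# How the sampling temperature used below behaves, on example logits [2.0, 0.0]:
#   T = 0.5 -> softmax([4.00, 0.0]) ≈ [0.98, 0.02]  (close to greedy)
#   T = 1.0 -> softmax([2.00, 0.0]) ≈ [0.88, 0.12]
#   T = 1.5 -> softmax([1.33, 0.0]) ≈ [0.79, 0.21]  (more diverse output)
#   T = 0   -> plain argmax (handled explicitly in temperature_sample)
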
def train_one_run(
    raw_text: str,
    model_type: str = "LSTM",
    embed_dim: int = 64,
    hidden_size: int = 128,
    num_layers: int = 2,
    seq_len: int = 64,
    batch_size: int = 16,
    lr: float = 0.003,
    epochs: int = 3,
    dropout: float = 0.0,
    temperature: float = 0.8,
    generate_n_chars: int = 400,
    prime_text: str = "The ",
    force_cpu: bool = False,
):
    t0 = time.time()

    # Sanity-check the text; pad with a small default corpus if too short
    text = raw_text.strip()
    if len(text) < max(100, seq_len + 1):
        default = (
            "To be, or not to be, that is the question:\n"
            "Whether 'tis nobler in the mind to suffer\n"
            "The slings and arrows of outrageous fortune,\n"
            "Or to take arms against a sea of troubles\n"
            "And by opposing end them."
        )
        text = (text + "\n" + default).strip()

    stoi, itos, vocab_size = build_vocab(text)
    data_ids = encode(text, stoi)

    device = pick_device(force_cpu=force_cpu)

    model = CharRNN(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        cell_type=model_type,
        dropout=dropout,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = []
    num_batches = 0
    model.train()
    for ep in range(1, epochs + 1):
        batches = make_batches(data_ids, seq_len=seq_len, batch_size=batch_size)
        if len(batches) == 0:
            raise ValueError("Not enough data to make at least one full batch. Increase text length or reduce seq_len/batch_size.")
        num_batches = len(batches)
        ep_loss = 0.0
        count = 0

        for xb, yb in batches:
            xb = xb.to(device)  # [B, T]
            yb = yb.to(device)  # [B, T]
            hidden = model.init_hidden(batch_size=xb.size(0), device=device)

            optimizer.zero_grad()
            logits, _ = model(xb, hidden)  # [B, T, V]
            # Flatten batch and time dims for cross-entropy
            B, T, V = logits.shape
            loss = criterion(logits.view(B * T, V), yb.view(B * T))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            ep_loss += loss.item()
            count += 1

        mean_loss = ep_loss / max(count, 1)
        losses.append(mean_loss)

    train_time = time.time() - t0

    # Generate a sample from the trained model
    gen = generate_text(
        model=model,
        stoi=stoi,
        itos=itos,
        prime=prime_text,
        length=generate_n_chars,
        temperature=temperature,
        device=device,
    )

    # Plot the loss curve
    fig = plt.figure(figsize=(5, 3.2), dpi=140)
    xs = np.arange(1, len(losses) + 1)
    plt.plot(xs, losses, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Training Loss")
    plt.title("Loss per Epoch")
    plt.tight_layout()

    # Hyperparameter summary (shown in the UI)
    hparams = {
        "model_type": model_type,
        "embed_dim": embed_dim,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "seq_len": seq_len,
        "batch_size": batch_size,
        "learning_rate": lr,
        "epochs": epochs,
        "dropout": dropout,
        "temperature": temperature,
        "generate_n_chars": generate_n_chars,
        "vocab_size": vocab_size,
        "device": str(device),
        "train_time_sec": round(train_time, 3),
        "num_batches_per_epoch": num_batches,
        "text_len": len(text),
    }

    return gen, fig, hparams

# -----------------------
# Gradio Interface
# -----------------------
EXAMPLE_TEXT = (
    "In the beginning God created the heaven and the earth. "
    "And the earth was without form, and void; and darkness was upon the face of the deep. "
    "And the Spirit of God moved upon the face of the waters. "
    "And God said, Let there be light: and there was light."
)

with gr.Blocks(title="PyTorch RNN/LSTM/GRU Text Generator") as demo:
    gr.Markdown(
        """# 🔠 PyTorch Character RNN / LSTM / GRU
Train a tiny character-level model on your text and generate new text.
Use the controls below to **show & tweak hyperparameters**, provide **input text**, and view **outputs** (generated text + loss curve).
"""
    )
    with gr.Row():
        with gr.Column():
            corpus = gr.Textbox(
                label="Training Text (paste here)",
                value=EXAMPLE_TEXT,
                lines=10,
                placeholder="Paste training text here… (longer is better)",
            )
            upload = gr.File(label="Or upload a .txt file (optional)", file_count="single")
            prime_text = gr.Textbox(label="Prime text (seed for generation)", value="The ", lines=1)

            with gr.Row():
                model_type = gr.Radio(choices=["LSTM", "GRU", "RNN"], value="LSTM", label="Model type")
                force_cpu = gr.Checkbox(value=False, label="Force CPU (uncheck to use GPU if available)")

            with gr.Accordion("Hyperparameters", open=True):
                with gr.Row():
                    embed_dim = gr.Slider(16, 256, value=64, step=1, label="Embedding Dim")
                    hidden_size = gr.Slider(32, 512, value=128, step=1, label="Hidden Size")
                with gr.Row():
                    num_layers = gr.Slider(1, 4, value=2, step=1, label="Layers")
                    dropout = gr.Slider(0.0, 0.6, value=0.0, step=0.05, label="Dropout")
                with gr.Row():
                    seq_len = gr.Slider(16, 256, value=64, step=1, label="Sequence Length")
                    batch_size = gr.Slider(4, 128, value=16, step=1, label="Batch Size")
                with gr.Row():
                    lr = gr.Number(value=0.003, precision=6, label="Learning Rate")
                    epochs = gr.Slider(1, 15, value=3, step=1, label="Epochs")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature (sampling)")
                    generate_n_chars = gr.Slider(50, 1000, value=400, step=10, label="Generate N Chars")

            run_btn = gr.Button("🚀 Train & Generate", variant="primary")

        with gr.Column():
            gen_text = gr.Textbox(label="Generated Text (output)", lines=20)
            loss_plot = gr.Plot(label="Training Loss")
            hparams_json = gr.JSON(label="Hyperparameters (for your records)")

    def run_pipeline(corpus_text, uploaded_file, model_type, embed_dim, hidden_size,
                     num_layers, seq_len, batch_size, lr, epochs, dropout,
                     temperature, generate_n_chars, prime_text, force_cpu):
        # Gradio passes component values positionally, so this signature must
        # mirror the `inputs` list below; sliders deliver floats, so the
        # integer-valued hyperparameters are cast explicitly.
        merged = (text_from_file(uploaded_file) + "\n" + (corpus_text or "")).strip() if uploaded_file else (corpus_text or "")
        out_text, fig, hp = train_one_run(
            raw_text=merged,
            model_type=model_type,
            embed_dim=int(embed_dim),
            hidden_size=int(hidden_size),
            num_layers=int(num_layers),
            seq_len=int(seq_len),
            batch_size=int(batch_size),
            lr=float(lr),
            epochs=int(epochs),
            dropout=float(dropout),
            temperature=float(temperature),
            generate_n_chars=int(generate_n_chars),
            prime_text=prime_text,
            force_cpu=bool(force_cpu),
        )
        return out_text, fig, hp

    run_btn.click(
        run_pipeline,
        # Inputs must be the component objects themselves, in the same order
        # as run_pipeline's parameters.
        inputs=[
            corpus, upload,
            model_type, embed_dim, hidden_size, num_layers,
            seq_len, batch_size, lr, epochs, dropout,
            temperature, generate_n_chars,
            prime_text, force_cpu,
        ],
        outputs=[gen_text, loss_plot, hparams_json],
        queue=False,
    )

if __name__ == "__main__":
    demo.launch()
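To run locally: python app.py, then open the local URL that demo.launch() prints. Note that on a Space this file alone is likely not enough: torch, numpy, and matplotlib would need to be declared in a requirements.txt (an assumed companion file, not part of this commit; gradio itself comes preinstalled on Gradio-SDK Spaces).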