# app.py
# Character-level text generator with PyTorch (RNN/LSTM/GRU) + Gradio UI
import time
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # headless backend; assumed appropriate for a server environment like Spaces
import matplotlib.pyplot as plt
import gradio as gr
# -----------------------
# Utilities
# -----------------------
def set_seed(seed: int = 42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when CUDA is unavailable
set_seed(42)
def pick_device(force_cpu: bool = False):
if not force_cpu and torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def text_from_file(f):
    """Read an uploaded file as text, tolerating both file-like objects and plain paths."""
    if f is None:
        return ""
    # Newer Gradio versions pass uploads as a filepath string; older ones pass a file-like object.
    path = f if isinstance(f, str) else getattr(f, "name", None)
    try:
        # Try to read as bytes, then decode safely
        raw = f.read()
        if isinstance(raw, bytes):
            try:
                return raw.decode("utf-8")
            except UnicodeDecodeError:
                return raw.decode("latin-1", errors="ignore")
        return str(raw)
    except Exception:
        # Fallback: treat it as a path-like
        if path:
            try:
                with open(path, "rb") as fh:
                    raw = fh.read()
                try:
                    return raw.decode("utf-8")
                except UnicodeDecodeError:
                    return raw.decode("latin-1", errors="ignore")
            except Exception:
                return ""
        return ""
def build_vocab(text: str):
chars = sorted(list(set(text)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
return stoi, itos, len(chars)
def encode(text: str, stoi: dict):
return torch.tensor([stoi[c] for c in text], dtype=torch.long)
def decode(indices, itos: dict):
return "".join([itos[int(i)] for i in indices])
def make_batches(data_ids: torch.Tensor, seq_len: int, batch_size: int):
# data_ids: shape [T]
total_len = data_ids.size(0)
num_sequences = (total_len - 1) // seq_len
if num_sequences <= 0:
return []
# Trim so we have full sequences
trimmed = num_sequences * seq_len
x = data_ids[:trimmed]
y = data_ids[1:trimmed + 1]
x = x.view(num_sequences, seq_len)
y = y.view(num_sequences, seq_len)
# Shuffle sequence order
perm = torch.randperm(num_sequences)
x = x[perm]
y = y[perm]
# Batch into mini-batches
batches = []
for i in range(0, num_sequences, batch_size):
xb = x[i:i + batch_size]
yb = y[i:i + batch_size]
if xb.size(0) == batch_size: # keep full batches only for simplicity
batches.append((xb, yb))
return batches
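# Shape sketch: with data_ids = torch.arange(10), seq_len=3, batch_size=2,
# num_sequences = (10 - 1) // 3 = 3, so x rows are [0,1,2], [3,4,5], [6,7,8] and
# y rows are [1,2,3], [4,5,6], [7,8,9] -- each target is the input shifted by one.
# Only one full [2, 3] batch survives; the leftover single row is dropped.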
# -----------------------
# Model
# -----------------------
class CharRNN(nn.Module):
def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int, num_layers: int, cell_type: str = "LSTM", dropout: float = 0.0):
super().__init__()
assert cell_type in {"RNN", "LSTM", "GRU"}
self.vocab_size = vocab_size
self.embed = nn.Embedding(vocab_size, embed_dim)
self.cell_type = cell_type
self.hidden_size = hidden_size
self.num_layers = num_layers
        # All three cell types share the same constructor signature here.
        rnn_cls = {"RNN": nn.RNN, "GRU": nn.GRU, "LSTM": nn.LSTM}[cell_type]
        self.rnn = rnn_cls(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # PyTorch warns if dropout is set on a single-layer RNN, so zero it out there.
            dropout=dropout if num_layers > 1 else 0,
        )
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden=None):
# x: [B, T]
x = self.embed(x) # [B, T, E]
out, hidden = self.rnn(x, hidden) # out: [B, T, H]
logits = self.fc(out) # [B, T, V]
return logits, hidden
def init_hidden(self, batch_size: int, device):
if self.cell_type == "LSTM":
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
return (h0, c0)
else:
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
return h0
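# Forward-pass shape sketch (illustrative values; not executed by the app):
#   model = CharRNN(vocab_size=50, embed_dim=8, hidden_size=16, num_layers=1)
#   logits, h = model(torch.zeros(2, 5, dtype=torch.long))  # hidden defaults to zeros
#   logits.shape == (2, 5, 50)  # one score per vocabulary entry at every timestep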
# -----------------------
# Training & Generation
# -----------------------
def temperature_sample(logits: torch.Tensor, temperature: float):
# logits: [V]
if temperature <= 0:
# argmax
return int(torch.argmax(logits).item())
probs = torch.softmax(logits / temperature, dim=-1)
return int(torch.multinomial(probs, num_samples=1).item())
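# Temperature rescales the softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T).
# T < 1 sharpens the distribution toward the argmax, T > 1 flattens it toward uniform,
# and T == 0 is treated above as pure greedy decoding.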
def generate_text(model: CharRNN, stoi, itos, prime: str, length: int, temperature: float, device):
if len(prime) == 0:
# start from random char if prime is empty
prime = random.choice(list(stoi.keys()))
model.eval()
# Feed prime to warm up hidden state
input_ids = torch.tensor([[stoi.get(ch, 0) for ch in prime]], dtype=torch.long, device=device)
hidden = model.init_hidden(batch_size=1, device=device)
with torch.no_grad():
logits, hidden = model(input_ids, hidden)
generated = [ch for ch in prime]
for _ in range(length):
with torch.no_grad():
# take last timestep logits
last_logits = logits[0, -1, :] # [V]
next_id = temperature_sample(last_logits, temperature)
generated.append(itos[next_id])
# next step
inp = torch.tensor([[next_id]], dtype=torch.long, device=device)
logits, hidden = model(inp, hidden)
return "".join(generated)
def train_one_run(
raw_text: str,
model_type: str = "LSTM",
embed_dim: int = 64,
hidden_size: int = 128,
num_layers: int = 2,
seq_len: int = 64,
batch_size: int = 16,
lr: float = 0.003,
epochs: int = 3,
dropout: float = 0.0,
temperature: float = 0.8,
generate_n_chars: int = 400,
prime_text: str = "The ",
force_cpu: bool = False,
):
t0 = time.time()
# Sanity text
text = raw_text.strip()
if len(text) < max(100, seq_len + 1):
# Provide a small default if not enough text
default = (
"To be, or not to be, that is the question:\n"
"Whether 'tis nobler in the mind to suffer\n"
"The slings and arrows of outrageous fortune,\n"
"Or to take arms against a sea of troubles\n"
"And by opposing end them."
)
text = (text + "\n" + default).strip()
stoi, itos, vocab_size = build_vocab(text)
data_ids = encode(text, stoi)
device = pick_device(force_cpu=force_cpu)
model = CharRNN(
vocab_size=vocab_size,
embed_dim=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
cell_type=model_type,
dropout=dropout,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
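    # Objective: next-character prediction. At each position t the model is scored on
    # how much probability it assigns to data_ids[t + 1] given the prefix so far.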
losses = []
model.train()
for ep in range(1, epochs + 1):
batches = make_batches(data_ids, seq_len=seq_len, batch_size=batch_size)
if len(batches) == 0:
raise ValueError("Not enough data to make at least one full batch. Increase text length or reduce seq_len/batch_size.")
ep_loss = 0.0
count = 0
for xb, yb in batches:
xb = xb.to(device) # [B, T]
yb = yb.to(device) # [B, T]
hidden = model.init_hidden(batch_size=xb.size(0), device=device)
optimizer.zero_grad()
logits, _ = model(xb, hidden) # [B, T, V]
            # CrossEntropyLoss expects [N, C] logits against [N] class targets, so flatten batch and time
            B, T, V = logits.shape
            loss = criterion(logits.view(B * T, V), yb.view(B * T))
loss.backward()
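            # Clip the global gradient norm at 1.0; character-level RNNs are prone to exploding gradients.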
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
ep_loss += loss.item()
count += 1
mean_loss = ep_loss / max(count, 1)
losses.append(mean_loss)
train_time = time.time() - t0
# Generate
gen = generate_text(
model=model,
stoi=stoi,
itos=itos,
prime=prime_text,
length=generate_n_chars,
temperature=temperature,
device=device,
)
    # Plot loss curve with the object-oriented Matplotlib API
    fig, ax = plt.subplots(figsize=(5, 3.2), dpi=140)
    xs = np.arange(1, len(losses) + 1)
    ax.plot(xs, losses, marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Training Loss")
    ax.set_title("Loss per Epoch")
    fig.tight_layout()
    plt.close(fig)  # release pyplot's reference; Gradio renders the returned figure object
# Hyperparameters summary (to show clearly in UI)
hparams = {
"model_type": model_type,
"embed_dim": embed_dim,
"hidden_size": hidden_size,
"num_layers": num_layers,
"seq_len": seq_len,
"batch_size": batch_size,
"learning_rate": lr,
"epochs": epochs,
"dropout": dropout,
"temperature": temperature,
"generate_n_chars": generate_n_chars,
"vocab_size": vocab_size,
"device": str(device),
"train_time_sec": round(train_time, 3),
"num_batches_per_epoch": len(make_batches(data_ids, seq_len, batch_size)),
"text_len": len(text),
}
return gen, fig, hparams
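# Programmatic usage sketch (bypassing the UI; file name and settings are illustrative):
#   gen, fig, hparams = train_one_run(raw_text=open("corpus.txt").read(), model_type="GRU", epochs=5)
#   print(hparams["train_time_sec"])
#   print(gen[:200])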
# -----------------------
# Gradio Interface
# -----------------------
EXAMPLE_TEXT = (
"In the beginning God created the heaven and the earth. "
"And the earth was without form, and void; and darkness was upon the face of the deep. "
"And the Spirit of God moved upon the face of the waters. "
"And God said, Let there be light: and there was light."
)
with gr.Blocks(title="PyTorch RNN/LSTM/GRU Text Generator") as demo:
gr.Markdown(
"""# 🔠 PyTorch Character RNN / LSTM / GRU
Train a tiny character-level model on your text and generate new text.
Use the controls below to **show & tweak hyperparameters**, provide **input text**, and view **outputs** (generated text + loss curve).
"""
)
with gr.Row():
with gr.Column():
corpus = gr.Textbox(
label="Training Text (paste here)",
value=EXAMPLE_TEXT,
lines=10,
placeholder="Paste training text here… (longer is better)"
)
upload = gr.File(label="Or upload a .txt file (optional)", file_count="single")
prime_text = gr.Textbox(label="Prime text (seed for generation)", value="The ", lines=1)
with gr.Row():
model_type = gr.Radio(choices=["LSTM", "GRU", "RNN"], value="LSTM", label="Model type")
force_cpu = gr.Checkbox(value=False, label="Force CPU (uncheck to use GPU if available)")
with gr.Accordion("Hyperparameters", open=True):
with gr.Row():
embed_dim = gr.Slider(16, 256, value=64, step=1, label="Embedding Dim")
hidden_size = gr.Slider(32, 512, value=128, step=1, label="Hidden Size")
with gr.Row():
num_layers = gr.Slider(1, 4, value=2, step=1, label="Layers")
dropout = gr.Slider(0.0, 0.6, value=0.0, step=0.05, label="Dropout")
with gr.Row():
seq_len = gr.Slider(16, 256, value=64, step=1, label="Sequence Length")
batch_size = gr.Slider(4, 128, value=16, step=1, label="Batch Size")
with gr.Row():
lr = gr.Number(value=0.003, precision=6, label="Learning Rate")
epochs = gr.Slider(1, 15, value=3, step=1, label="Epochs")
with gr.Row():
temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature (sampling)")
generate_n_chars = gr.Slider(50, 1000, value=400, step=10, label="Generate N Chars")
run_btn = gr.Button("🚀 Train & Generate", variant="primary")
with gr.Column():
gen_text = gr.Textbox(label="Generated Text (output)", lines=20)
loss_plot = gr.Plot(label="Training Loss")
hparams_json = gr.JSON(label="Hyperparameters (for your records)")
    def run_pipeline(
        corpus_text, uploaded_file,
        model_type, embed_dim, hidden_size, num_layers,
        seq_len, batch_size, lr, epochs, dropout,
        temperature, generate_n_chars, prime_text, force_cpu,
    ):
        # Prepend uploaded text (if any) to the pasted corpus
        merged = (text_from_file(uploaded_file) + "\n" + (corpus_text or "")).strip() if uploaded_file else (corpus_text or "")
        out_text, fig, hp = train_one_run(
            raw_text=merged,
            model_type=model_type,
            embed_dim=int(embed_dim),
            hidden_size=int(hidden_size),
            num_layers=int(num_layers),
            seq_len=int(seq_len),
            batch_size=int(batch_size),
            lr=float(lr),
            epochs=int(epochs),
            dropout=float(dropout),
            temperature=float(temperature),
            generate_n_chars=int(generate_n_chars),
            prime_text=prime_text,
            force_cpu=bool(force_cpu),
        )
        return out_text, fig, hp

    # Gradio passes component values positionally, in the order listed here,
    # so this list must match run_pipeline's signature.
    run_btn.click(
        run_pipeline,
        inputs=[
            corpus, upload,
            model_type, embed_dim, hidden_size, num_layers,
            seq_len, batch_size, lr, epochs, dropout,
            temperature, generate_n_chars, prime_text, force_cpu,
        ],
        outputs=[gen_text, loss_plot, hparams_json],
        queue=False,
    )
if __name__ == "__main__":
demo.launch()