# app.py # Character-level text generator with PyTorch (RNN/LSTM/GRU) + Gradio UI import io import math import time import random import numpy as np import torch import torch.nn as nn import matplotlib.pyplot as plt import gradio as gr # ----------------------- # Utilities # ----------------------- def set_seed(seed: int = 42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) set_seed(42) def pick_device(force_cpu: bool = False): if not force_cpu and torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def text_from_file(f): if f is None: return "" name = getattr(f, "name", None) try: # Try to read as bytes then decode safely raw = f.read() if isinstance(raw, bytes): try: return raw.decode("utf-8") except UnicodeDecodeError: return raw.decode("latin-1", errors="ignore") return str(raw) except Exception: # Fallback: if it's a path-like if name: try: with open(name, "rb") as fh: raw = fh.read() try: return raw.decode("utf-8") except UnicodeDecodeError: return raw.decode("latin-1", errors="ignore") except Exception: return "" return "" def build_vocab(text: str): chars = sorted(list(set(text))) stoi = {ch: i for i, ch in enumerate(chars)} itos = {i: ch for ch, i in stoi.items()} return stoi, itos, len(chars) def encode(text: str, stoi: dict): return torch.tensor([stoi[c] for c in text], dtype=torch.long) def decode(indices, itos: dict): return "".join([itos[int(i)] for i in indices]) def make_batches(data_ids: torch.Tensor, seq_len: int, batch_size: int): # data_ids: shape [T] total_len = data_ids.size(0) num_sequences = (total_len - 1) // seq_len if num_sequences <= 0: return [] # Trim so we have full sequences trimmed = num_sequences * seq_len x = data_ids[:trimmed] y = data_ids[1:trimmed + 1] x = x.view(num_sequences, seq_len) y = y.view(num_sequences, seq_len) # Shuffle sequence order perm = torch.randperm(num_sequences) x = x[perm] y = y[perm] # Batch into mini-batches batches = [] for i in range(0, num_sequences, batch_size): xb = x[i:i + batch_size] yb = y[i:i + batch_size] if xb.size(0) == batch_size: # keep full batches only for simplicity batches.append((xb, yb)) return batches # ----------------------- # Model # ----------------------- class CharRNN(nn.Module): def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int, num_layers: int, cell_type: str = "LSTM", dropout: float = 0.0): super().__init__() assert cell_type in {"RNN", "LSTM", "GRU"} self.vocab_size = vocab_size self.embed = nn.Embedding(vocab_size, embed_dim) self.cell_type = cell_type self.hidden_size = hidden_size self.num_layers = num_layers if cell_type == "RNN": self.rnn = nn.RNN(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0) elif cell_type == "GRU": self.rnn = nn.GRU(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0) else: # LSTM self.rnn = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0) self.fc = nn.Linear(hidden_size, vocab_size) def forward(self, x, hidden=None): # x: [B, T] x = self.embed(x) # [B, T, E] out, hidden = self.rnn(x, hidden) # out: [B, T, H] logits = self.fc(out) # [B, T, V] return logits, hidden def init_hidden(self, batch_size: int, device): if self.cell_type == "LSTM": h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device) c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device) return (h0, c0) else: h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device) return h0 # ----------------------- # Training & Generation # ----------------------- def temperature_sample(logits: torch.Tensor, temperature: float): # logits: [V] if temperature <= 0: # argmax return int(torch.argmax(logits).item()) probs = torch.softmax(logits / temperature, dim=-1) return int(torch.multinomial(probs, num_samples=1).item()) def generate_text(model: CharRNN, stoi, itos, prime: str, length: int, temperature: float, device): if len(prime) == 0: # start from random char if prime is empty prime = random.choice(list(stoi.keys())) model.eval() # Feed prime to warm up hidden state input_ids = torch.tensor([[stoi.get(ch, 0) for ch in prime]], dtype=torch.long, device=device) hidden = model.init_hidden(batch_size=1, device=device) with torch.no_grad(): logits, hidden = model(input_ids, hidden) last_char_id = input_ids[0, -1] generated = [ch for ch in prime] for _ in range(length): with torch.no_grad(): # take last timestep logits last_logits = logits[0, -1, :] # [V] next_id = temperature_sample(last_logits, temperature) generated.append(itos[next_id]) # next step inp = torch.tensor([[next_id]], dtype=torch.long, device=device) logits, hidden = model(inp, hidden) return "".join(generated) def train_one_run( raw_text: str, model_type: str = "LSTM", embed_dim: int = 64, hidden_size: int = 128, num_layers: int = 2, seq_len: int = 64, batch_size: int = 16, lr: float = 0.003, epochs: int = 3, dropout: float = 0.0, temperature: float = 0.8, generate_n_chars: int = 400, prime_text: str = "The ", force_cpu: bool = False, ): t0 = time.time() # Sanity text text = raw_text.strip() if len(text) < max(100, seq_len + 1): # Provide a small default if not enough text default = ( "To be, or not to be, that is the question:\n" "Whether 'tis nobler in the mind to suffer\n" "The slings and arrows of outrageous fortune,\n" "Or to take arms against a sea of troubles\n" "And by opposing end them." ) text = (text + "\n" + default).strip() stoi, itos, vocab_size = build_vocab(text) data_ids = encode(text, stoi) device = pick_device(force_cpu=force_cpu) model = CharRNN( vocab_size=vocab_size, embed_dim=embed_dim, hidden_size=hidden_size, num_layers=num_layers, cell_type=model_type, dropout=dropout, ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() losses = [] model.train() for ep in range(1, epochs + 1): batches = make_batches(data_ids, seq_len=seq_len, batch_size=batch_size) if len(batches) == 0: raise ValueError("Not enough data to make at least one full batch. Increase text length or reduce seq_len/batch_size.") ep_loss = 0.0 count = 0 for xb, yb in batches: xb = xb.to(device) # [B, T] yb = yb.to(device) # [B, T] hidden = model.init_hidden(batch_size=xb.size(0), device=device) optimizer.zero_grad() logits, _ = model(xb, hidden) # [B, T, V] # Reshape for CE loss B, T, V = logits.shape loss = criterion(logits.view(B * T, V), yb.view(B * T)) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() ep_loss += loss.item() count += 1 mean_loss = ep_loss / max(count, 1) losses.append(mean_loss) train_time = time.time() - t0 # Generate gen = generate_text( model=model, stoi=stoi, itos=itos, prime=prime_text, length=generate_n_chars, temperature=temperature, device=device, ) # Plot loss curve fig = plt.figure(figsize=(5, 3.2), dpi=140) xs = np.arange(1, len(losses) + 1) plt.plot(xs, losses, marker="o") plt.xlabel("Epoch") plt.ylabel("Training Loss") plt.title("Loss per Epoch") plt.tight_layout() # Hyperparameters summary (to show clearly in UI) hparams = { "model_type": model_type, "embed_dim": embed_dim, "hidden_size": hidden_size, "num_layers": num_layers, "seq_len": seq_len, "batch_size": batch_size, "learning_rate": lr, "epochs": epochs, "dropout": dropout, "temperature": temperature, "generate_n_chars": generate_n_chars, "vocab_size": vocab_size, "device": str(device), "train_time_sec": round(train_time, 3), "num_batches_per_epoch": len(make_batches(data_ids, seq_len, batch_size)), "text_len": len(text), } return gen, fig, hparams # ----------------------- # Gradio Interface # ----------------------- EXAMPLE_TEXT = ( "In the beginning God created the heaven and the earth. " "And the earth was without form, and void; and darkness was upon the face of the deep. " "And the Spirit of God moved upon the face of the waters. " "And God said, Let there be light: and there was light." ) with gr.Blocks(title="PyTorch RNN/LSTM/GRU Text Generator") as demo: gr.Markdown( """# 🔠 PyTorch Character RNN / LSTM / GRU Train a tiny character-level model on your text and generate new text. Use the controls below to **show & tweak hyperparameters**, provide **input text**, and view **outputs** (generated text + loss curve). """ ) with gr.Row(): with gr.Column(): corpus = gr.Textbox( label="Training Text (paste here)", value=EXAMPLE_TEXT, lines=10, placeholder="Paste training text here… (longer is better)" ) upload = gr.File(label="Or upload a .txt file (optional)", file_count="single") prime_text = gr.Textbox(label="Prime text (seed for generation)", value="The ", lines=1) with gr.Row(): model_type = gr.Radio(choices=["LSTM", "GRU", "RNN"], value="LSTM", label="Model type") force_cpu = gr.Checkbox(value=False, label="Force CPU (uncheck to use GPU if available)") with gr.Accordion("Hyperparameters", open=True): with gr.Row(): embed_dim = gr.Slider(16, 256, value=64, step=1, label="Embedding Dim") hidden_size = gr.Slider(32, 512, value=128, step=1, label="Hidden Size") with gr.Row(): num_layers = gr.Slider(1, 4, value=2, step=1, label="Layers") dropout = gr.Slider(0.0, 0.6, value=0.0, step=0.05, label="Dropout") with gr.Row(): seq_len = gr.Slider(16, 256, value=64, step=1, label="Sequence Length") batch_size = gr.Slider(4, 128, value=16, step=1, label="Batch Size") with gr.Row(): lr = gr.Number(value=0.003, precision=6, label="Learning Rate") epochs = gr.Slider(1, 15, value=3, step=1, label="Epochs") with gr.Row(): temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature (sampling)") generate_n_chars = gr.Slider(50, 1000, value=400, step=10, label="Generate N Chars") run_btn = gr.Button("🚀 Train & Generate", variant="primary") with gr.Column(): gen_text = gr.Textbox(label="Generated Text (output)", lines=20) loss_plot = gr.Plot(label="Training Loss") hparams_json = gr.JSON(label="Hyperparameters (for your records)") def run_pipeline(corpus_text, uploaded_file, **kwargs): merged = (text_from_file(uploaded_file) + "\n" + (corpus_text or "")).strip() if uploaded_file else (corpus_text or "") out_text, fig, hp = train_one_run(raw_text=merged, **kwargs) return out_text, fig, hp run_btn.click( run_pipeline, inputs=[ corpus, upload, {"label": "model_type"}, {"label": "embed_dim"}, {"label": "hidden_size"}, {"label": "num_layers"}, {"label": "seq_len"}, {"label": "batch_size"}, {"label": "lr"}, {"label": "epochs"}, {"label": "dropout"}, {"label": "temperature"}, {"label": "generate_n_chars"}, prime_text, force_cpu ], outputs=[gen_text, loss_plot, hparams_json], queue=False ) if __name__ == "__main__": demo.launch()