# Spaces:
# Runtime error
# Runtime error
# app.py | |
# Character-level text generator with PyTorch (RNN/LSTM/GRU) + Gradio UI | |
import io | |
import math | |
import time | |
import random | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
# ----------------------- | |
# Utilities | |
# ----------------------- | |
def set_seed(seed: int = 42):
    """Seed every RNG this app draws from so that runs are reproducible.

    Covers Python's ``random`` module, NumPy's global generator, and PyTorch's
    CPU generator plus the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# Seed once at import time so the app starts from a deterministic state.
set_seed(42)
def pick_device(force_cpu: bool = False):
    """Return the CUDA device when it is usable, otherwise the CPU.

    Args:
        force_cpu: when true, always return the CPU even if CUDA is present.
    """
    use_cuda = (not force_cpu) and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")
def _decode_bytes(raw: bytes) -> str:
    """Decode *raw* as UTF-8, falling back to Latin-1 for non-UTF-8 input."""
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        # Latin-1 maps every byte to a code point, so this fallback cannot fail.
        return raw.decode("latin-1", errors="ignore")


def text_from_file(f):
    """Return the text content of an uploaded file, or "" if it cannot be read.

    Accepted inputs:
      * ``None`` -- nothing uploaded; returns "".
      * a filesystem path (``str`` / ``os.PathLike``, including ``str``
        subclasses) -- the file at that path is read.
      * a file-like object -- its ``read()`` result is used; if reading fails
        (e.g. the temp file was already closed) or returns nothing, the file
        named by its ``name`` attribute is read instead.

    Bytes are decoded as UTF-8, falling back to Latin-1.
    """
    import os

    if f is None:
        return ""

    if isinstance(f, (str, os.PathLike)):
        # Path-only input. Previously a str fell through to "" (no .read(),
        # no .name), silently discarding the upload.
        path = f
    else:
        read = getattr(f, "read", None)
        if callable(read):
            try:
                raw = read()
            except (OSError, ValueError):
                # e.g. "I/O operation on closed file": fall back to its path.
                raw = None
            if isinstance(raw, bytes):
                return _decode_bytes(raw)
            if raw is not None:
                return str(raw)
        path = getattr(f, "name", None)

    if not path:
        return ""
    try:
        with open(path, "rb") as fh:
            return _decode_bytes(fh.read())
    except (OSError, TypeError, ValueError):
        # Unreadable / nonexistent / invalid path: treat as "no upload".
        return ""
def build_vocab(text: str):
    """Build character <-> index lookup tables for *text*.

    Characters are sorted, so the mapping is deterministic for a given corpus.

    Returns:
        ``(stoi, itos, vocab_size)``: char->index dict, index->char dict, and
        the number of distinct characters.
    """
    alphabet = sorted(set(text))
    itos = dict(enumerate(alphabet))
    stoi = {ch: idx for idx, ch in itos.items()}
    return stoi, itos, len(alphabet)
def encode(text: str, stoi: dict):
    """Map each character of *text* to its vocabulary index as a 1-D long tensor.

    Raises:
        KeyError: if *text* contains a character missing from *stoi*.
    """
    ids = list(map(stoi.__getitem__, text))
    return torch.tensor(ids, dtype=torch.long)
def decode(indices, itos: dict):
    """Turn a sequence of vocabulary indices (ints or 0-d tensors) back into text."""
    chars = (itos[int(idx)] for idx in indices)
    return "".join(chars)
def make_batches(data_ids: torch.Tensor, seq_len: int, batch_size: int):
    """Cut a 1-D id sequence into shuffled ``(input, target)`` mini-batches.

    The stream is split into ``(len(data_ids) - 1) // seq_len`` non-overlapping
    windows. Each target window is its input shifted one position to the right
    (next-character prediction). Windows are shuffled with ``torch.randperm``
    and grouped into batches of ``batch_size``.

    Only full batches are kept, so every batch has shape ``[batch_size, seq_len]``.
    The exception is data too small to fill even one batch: then all windows
    are returned as a single smaller batch instead of nothing. Otherwise short
    corpora (such as the UI's example text at default settings) could never
    train.

    Args:
        data_ids: 1-D tensor of token ids.
        seq_len: window length; must be >= 1.
        batch_size: windows per batch; must be >= 1.

    Returns:
        List of ``(x, y)`` tensor pairs; empty if not even one window fits.

    Raises:
        ValueError: if ``seq_len`` or ``batch_size`` is not positive.
    """
    if seq_len < 1 or batch_size < 1:
        raise ValueError("seq_len and batch_size must be positive integers.")

    total_len = data_ids.size(0)
    num_sequences = (total_len - 1) // seq_len
    if num_sequences <= 0:
        return []

    # Trim to whole windows; y is x shifted by one token.
    trimmed = num_sequences * seq_len
    x = data_ids[:trimmed].view(num_sequences, seq_len)
    y = data_ids[1:trimmed + 1].view(num_sequences, seq_len)

    # Shuffle window order (same perm for inputs and targets).
    perm = torch.randperm(num_sequences)
    x, y = x[perm], y[perm]

    if num_sequences < batch_size:
        # Too little data for one full batch: train on everything at once
        # rather than returning no batches at all.
        return [(x, y)]

    num_full = num_sequences // batch_size  # trailing partial batch is dropped
    return [
        (x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size])
        for i in range(num_full)
    ]
# ----------------------- | |
# Model | |
# ----------------------- | |
class CharRNN(nn.Module):
    """Character-level language model: embedding -> recurrent stack -> linear head.

    Args:
        vocab_size: number of distinct characters (embedding rows / output logits).
        embed_dim: size of each character embedding.
        hidden_size: hidden-state size of every recurrent layer.
        num_layers: number of stacked recurrent layers.
        cell_type: one of ``"RNN"``, ``"LSTM"``, ``"GRU"``.
        dropout: dropout between recurrent layers. It is only passed through
            when ``num_layers > 1``, because PyTorch applies it between layers.

    Raises:
        ValueError: if ``cell_type`` is not a supported cell.
    """

    # Recurrent layer class for each supported cell type.
    _CELLS = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}

    def __init__(self, vocab_size: int, embed_dim: int, hidden_size: int, num_layers: int, cell_type: str = "LSTM", dropout: float = 0.0):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if cell_type not in self._CELLS:
            raise ValueError(
                f"cell_type must be one of {sorted(self._CELLS)}, got {cell_type!r}"
            )
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.cell_type = cell_type
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = self._CELLS[cell_type](
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Run the model over a batch of id sequences.

        Args:
            x: ``[B, T]`` long tensor of character ids (batch-first).
            hidden: optional initial recurrent state (see ``init_hidden``);
                ``None`` lets the recurrent layer start from zeros.

        Returns:
            ``(logits, hidden)``: ``[B, T, vocab_size]`` logits and the final state.
        """
        emb = self.embed(x)                 # [B, T, E]
        out, hidden = self.rnn(emb, hidden)  # out: [B, T, H]
        logits = self.fc(out)               # [B, T, V]
        return logits, hidden

    def init_hidden(self, batch_size: int, device):
        """Return a zero initial state: ``(h0, c0)`` for LSTM, ``h0`` otherwise.

        Each tensor has shape ``[num_layers, batch_size, hidden_size]``.
        """
        shape = (self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(shape, device=device)
        if self.cell_type == "LSTM":
            return h0, torch.zeros(shape, device=device)
        return h0
# ----------------------- | |
# Training & Generation | |
# ----------------------- | |
def temperature_sample(logits: torch.Tensor, temperature: float):
    """Choose one index from 1-D ``logits``.

    A non-positive ``temperature`` means greedy decoding (argmax). Otherwise
    the logits are divided by ``temperature`` and one index is drawn from the
    resulting softmax distribution.
    """
    if temperature <= 0:
        picked = torch.argmax(logits)
    else:
        distribution = torch.softmax(logits / temperature, dim=-1)
        picked = torch.multinomial(distribution, num_samples=1)
    return int(picked.item())
def generate_text(model: CharRNN, stoi, itos, prime: str, length: int, temperature: float, device):
    """Sample ``length`` characters from ``model`` after a seed (``prime``) text.

    The prime is run through the model to warm up the hidden state. Characters
    are then sampled one at a time with ``temperature_sample``, and each one is
    fed back as the next input.

    Args:
        model: trained ``CharRNN``; it is switched to eval mode here.
        stoi: char -> index vocabulary table.
        itos: index -> char vocabulary table.
        prime: seed text. If empty, a random vocabulary character is used.
            Prime characters missing from the vocabulary are skipped when
            warming up the model, but are still echoed at the start of the output.
        length: number of characters to generate after the prime.
        temperature: sampling temperature; ``<= 0`` means greedy decoding.
        device: device the model's parameters live on.

    Returns:
        The prime followed by the generated characters.
    """
    if len(prime) == 0:
        # Start from a random character if no prime was given.
        prime = random.choice(list(stoi.keys()))

    # Unknown characters used to be mapped to id 0, which silently fed an
    # arbitrary real character to the model. Skip them instead. If nothing in
    # the prime is known, seed from a random vocabulary character.
    prime_ids = [stoi[ch] for ch in prime if ch in stoi]
    if not prime_ids:
        prime_ids = [stoi[random.choice(list(stoi.keys()))]]

    model.eval()
    generated = list(prime)
    with torch.no_grad():
        # Warm up the hidden state on the whole prime in one pass.
        input_ids = torch.tensor([prime_ids], dtype=torch.long, device=device)
        hidden = model.init_hidden(batch_size=1, device=device)
        logits, hidden = model(input_ids, hidden)
        # int(): UI sliders may deliver floats, which range() rejects.
        for _ in range(int(length)):
            next_id = temperature_sample(logits[0, -1, :], temperature)
            generated.append(itos[next_id])
            inp = torch.tensor([[next_id]], dtype=torch.long, device=device)
            logits, hidden = model(inp, hidden)
    return "".join(generated)
def _plot_losses(losses):
    """Return a standalone matplotlib Figure of mean training loss per epoch."""
    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure in a global registry until it is closed, so a long-running server
    # would otherwise leak one figure per training run (and pyplot's implicit
    # "current figure" state is shared across concurrent requests).
    from matplotlib.figure import Figure

    fig = Figure(figsize=(5, 3.2), dpi=140, tight_layout=True)
    ax = fig.subplots()
    ax.plot(np.arange(1, len(losses) + 1), losses, marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Training Loss")
    ax.set_title("Loss per Epoch")
    return fig


def train_one_run(
    raw_text: str,
    model_type: str = "LSTM",
    embed_dim: int = 64,
    hidden_size: int = 128,
    num_layers: int = 2,
    seq_len: int = 64,
    batch_size: int = 16,
    lr: float = 0.003,
    epochs: int = 3,
    dropout: float = 0.0,
    temperature: float = 0.8,
    generate_n_chars: int = 400,
    prime_text: str = "The ",
    force_cpu: bool = False,
):
    """Train a fresh ``CharRNN`` on ``raw_text``, sample from it, and plot the loss.

    If the stripped text is shorter than ``max(100, seq_len + 1)`` characters,
    a short built-in passage is appended so there is something to train on.

    Returns:
        ``(generated_text, loss_figure, hparams)``: the sampled text, a
        matplotlib Figure of mean loss per epoch, and a JSON-friendly dict
        summarising the configuration and the run.

    Raises:
        ValueError: if the text cannot be split into at least one batch.
    """
    t0 = time.time()

    # Sanity-check the text; pad very short input with a default passage.
    text = (raw_text or "").strip()
    if len(text) < max(100, seq_len + 1):
        default = (
            "To be, or not to be, that is the question:\n"
            "Whether 'tis nobler in the mind to suffer\n"
            "The slings and arrows of outrageous fortune,\n"
            "Or to take arms against a sea of troubles\n"
            "And by opposing end them."
        )
        text = (text + "\n" + default).strip()

    stoi, itos, vocab_size = build_vocab(text)
    data_ids = encode(text, stoi)
    device = pick_device(force_cpu=force_cpu)

    model = CharRNN(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        cell_type=model_type,
        dropout=dropout,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses = []
    num_batches = 0  # batches per epoch; recorded here instead of re-batching later
    model.train()
    for _ in range(epochs):
        # Re-batch every epoch so each epoch sees a fresh shuffle.
        batches = make_batches(data_ids, seq_len=seq_len, batch_size=batch_size)
        if not batches:
            raise ValueError("Not enough data to make at least one full batch. Increase text length or reduce seq_len/batch_size.")
        num_batches = len(batches)
        ep_loss = 0.0
        for xb, yb in batches:
            xb = xb.to(device)  # [B, T]
            yb = yb.to(device)  # [B, T]
            # Each batch starts from a zero state (windows are shuffled, so
            # state carried between batches would be meaningless).
            hidden = model.init_hidden(batch_size=xb.size(0), device=device)
            optimizer.zero_grad()
            logits, _ = model(xb, hidden)  # [B, T, V]
            # Flatten batch and time for the cross-entropy loss.
            loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            ep_loss += loss.item()
        losses.append(ep_loss / num_batches)
    train_time = time.time() - t0

    gen = generate_text(
        model=model,
        stoi=stoi,
        itos=itos,
        prime=prime_text,
        length=generate_n_chars,
        temperature=temperature,
        device=device,
    )

    fig = _plot_losses(losses)

    # Hyperparameter summary shown in the UI.
    hparams = {
        "model_type": model_type,
        "embed_dim": embed_dim,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "seq_len": seq_len,
        "batch_size": batch_size,
        "learning_rate": lr,
        "epochs": epochs,
        "dropout": dropout,
        "temperature": temperature,
        "generate_n_chars": generate_n_chars,
        "vocab_size": vocab_size,
        "device": str(device),
        "train_time_sec": round(train_time, 3),
        "num_batches_per_epoch": num_batches,
        "text_len": len(text),
    }
    return gen, fig, hparams
# ----------------------- | |
# Gradio Interface | |
# ----------------------- | |
# Sample corpus pre-filled in the "Training Text" box of the UI.
EXAMPLE_TEXT = " ".join(
    [
        "In the beginning God created the heaven and the earth.",
        "And the earth was without form, and void; and darkness was upon the face of the deep.",
        "And the Spirit of God moved upon the face of the waters.",
        "And God said, Let there be light: and there was light.",
    ]
)
def run_pipeline(
    corpus_text,
    uploaded_file,
    model_type,
    embed_dim,
    hidden_size,
    num_layers,
    seq_len,
    batch_size,
    lr,
    epochs,
    dropout,
    temperature,
    generate_n_chars,
    prime_text,
    force_cpu,
):
    """Gradio callback: merge the text inputs, train a model, and return the outputs.

    Gradio passes component values positionally, in the order of ``inputs=`` on
    the click handler, so these parameters mirror that list exactly. The
    previous ``**kwargs`` signature could never receive them.

    Numeric UI values are coerced to ``int``/``float`` before reaching PyTorch,
    which expects integer sizes (layer widths, ``view`` shapes, ``range``).

    Returns:
        ``(generated_text, loss_figure, hparams)`` for the three output components.

    Raises:
        gr.Error: wrapping any ``ValueError`` (e.g. not enough data) so the
            message is shown in the UI.
    """
    corpus_text = corpus_text or ""
    if uploaded_file:
        # Uploaded file text first, then whatever was pasted in the box.
        merged = (text_from_file(uploaded_file) + "\n" + corpus_text).strip()
    else:
        merged = corpus_text
    try:
        return train_one_run(
            raw_text=merged,
            model_type=model_type,
            embed_dim=int(embed_dim),
            hidden_size=int(hidden_size),
            num_layers=int(num_layers),
            seq_len=int(seq_len),
            batch_size=int(batch_size),
            # A cleared Number box yields None; fall back to the default rate.
            lr=float(lr) if lr is not None else 0.003,
            epochs=int(epochs),
            dropout=float(dropout),
            temperature=float(temperature),
            generate_n_chars=int(generate_n_chars),
            prime_text=prime_text or "",
            force_cpu=bool(force_cpu),
        )
    except ValueError as err:
        raise gr.Error(str(err)) from err


with gr.Blocks(title="PyTorch RNN/LSTM/GRU Text Generator") as demo:
    gr.Markdown(
        """# 🔠 PyTorch Character RNN / LSTM / GRU
Train a tiny character-level model on your text and generate new text.
Use the controls below to **show & tweak hyperparameters**, provide **input text**, and view **outputs** (generated text + loss curve).
"""
    )
    with gr.Row():
        # Left column: inputs and hyperparameters.
        with gr.Column():
            corpus = gr.Textbox(
                label="Training Text (paste here)",
                value=EXAMPLE_TEXT,
                lines=10,
                placeholder="Paste training text here… (longer is better)",
            )
            upload = gr.File(label="Or upload a .txt file (optional)", file_count="single")
            prime_text = gr.Textbox(label="Prime text (seed for generation)", value="The ", lines=1)
            with gr.Row():
                model_type = gr.Radio(choices=["LSTM", "GRU", "RNN"], value="LSTM", label="Model type")
                force_cpu = gr.Checkbox(value=False, label="Force CPU (uncheck to use GPU if available)")
            with gr.Accordion("Hyperparameters", open=True):
                with gr.Row():
                    embed_dim = gr.Slider(16, 256, value=64, step=1, label="Embedding Dim")
                    hidden_size = gr.Slider(32, 512, value=128, step=1, label="Hidden Size")
                with gr.Row():
                    num_layers = gr.Slider(1, 4, value=2, step=1, label="Layers")
                    dropout = gr.Slider(0.0, 0.6, value=0.0, step=0.05, label="Dropout")
                with gr.Row():
                    seq_len = gr.Slider(16, 256, value=64, step=1, label="Sequence Length")
                    batch_size = gr.Slider(4, 128, value=16, step=1, label="Batch Size")
                with gr.Row():
                    lr = gr.Number(value=0.003, precision=6, label="Learning Rate")
                    epochs = gr.Slider(1, 15, value=3, step=1, label="Epochs")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature (sampling)")
                    generate_n_chars = gr.Slider(50, 1000, value=400, step=10, label="Generate N Chars")
            run_btn = gr.Button("🚀 Train & Generate", variant="primary")
        # Right column: outputs.
        with gr.Column():
            gen_text = gr.Textbox(label="Generated Text (output)", lines=20)
            loss_plot = gr.Plot(label="Training Loss")
            hparams_json = gr.JSON(label="Hyperparameters (for your records)")

    # `inputs` must be component objects (not dicts), in the same order as
    # run_pipeline's parameters; the dict placeholders were the Space's
    # startup "Runtime error".
    run_btn.click(
        run_pipeline,
        inputs=[
            corpus,
            upload,
            model_type,
            embed_dim,
            hidden_size,
            num_layers,
            seq_len,
            batch_size,
            lr,
            epochs,
            dropout,
            temperature,
            generate_n_chars,
            prime_text,
            force_cpu,
        ],
        outputs=[gen_text, loss_plot, hparams_json],
    )

if __name__ == "__main__":
    # Route events through Gradio's queue: training can run for a long time,
    # and queued events are not subject to a single HTTP request's lifetime.
    demo.queue().launch()