# Hugging Face Space (status: Running)
import torch
import gradio as gr
from modeling_diffusion import DiffusionTextModel

# ---------------------------------------------------------------------
# Model and vocabulary setup
# ---------------------------------------------------------------------

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the pretrained weights from the Hub, move them to the chosen
# device and switch the model to inference mode.
model = DiffusionTextModel.from_pretrained("yasserrmd/diffusion-text-demo")
model.to(device)
model.eval()

# Special tokens and the token -> id table.
# NOTE(review): only the two special tokens are registered here; any other
# id has no entry in the reverse table below -- confirm where the full
# vocabulary is supposed to come from.
PAD_TOKEN = "[PAD]"
MASK_TOKEN = "[MASK]"
vocab = {PAD_TOKEN: 0, MASK_TOKEN: 1}

# id -> token table used when decoding.
id_to_word = {index: token for token, index in vocab.items()}

# Ids of the special tokens, looked up once for convenience.
pad_id = vocab[PAD_TOKEN]
mask_id = vocab[MASK_TOKEN]
# =====================
# Generation Function
# =====================
def generate_with_prompt(model, input_text, max_length=50, T=10):
    """Generate text by iteratively filling masked positions after a prompt.

    The prompt is split on whitespace and mapped to ids through the
    module-level ``vocab``; words missing from it are encoded as the mask
    id. Prompt positions are never resampled, so such words decode as the
    mask token. Every position after the prompt starts out masked. For each
    step from ``T`` down to 1 the model is called on the current sequence
    and each still-masked position after the prompt is replaced by a sample
    drawn from the softmax of its logits.

    Args:
        model: Callable invoked as ``model(token_ids, timestep)``. Its
            output is indexed as ``[batch, position, vocab]`` below.
        input_text: Prompt string.
        max_length: Total sequence length, prompt included. Values above
            99 are capped at 99; the prompt is truncated to this length.
        T: Number of denoising steps.

    Returns:
        The decoded tokens joined by single spaces, cut at the first pad
        token if one is present. Ids without an entry in ``id_to_word``
        are rendered as ``[id:<n>]`` instead of raising ``KeyError``.
    """
    # UI widgets may hand over floats; tensor shapes and range() need ints.
    max_length = min(int(max_length), 99)
    num_steps = int(T)
    model.eval()

    # A prompt longer than the sequence would make the slice assignment
    # below fail on a shape mismatch, so clip it first.
    prompt_ids = [vocab.get(tok, mask_id) for tok in input_text.split()]
    prompt_ids = prompt_ids[:max_length]
    prompt_len = len(prompt_ids)

    seq = torch.full((1, max_length), mask_id, dtype=torch.long, device=device)
    if prompt_ids:
        seq[0, :prompt_len] = torch.tensor(
            prompt_ids, dtype=torch.long, device=device
        )

    with torch.no_grad():
        for step in range(num_steps, 0, -1):
            # Positions still waiting for a token, never inside the prompt.
            to_fill = seq[0] == mask_id
            to_fill[:prompt_len] = False
            if not to_fill.any():
                # Nothing left to sample; further model calls cannot
                # change the sequence.
                break

            logits = model(seq, torch.tensor([step], device=device))
            probs = torch.softmax(logits, dim=-1)
            # One draw per masked position, done in a single call.
            sampled = torch.multinomial(probs[0, to_fill], 1).squeeze(-1)
            seq[0, to_fill] = sampled

    ids = seq[0].tolist()
    if pad_id in ids:
        ids = ids[:ids.index(pad_id)]
    return " ".join(id_to_word.get(i, f"[id:{i}]") for i in ids)
# =====================
# Gradio App
# =====================
def chat_fn(message, history, steps, max_len):
    """Handle one chat submission.

    Args:
        message: Text the user submitted.
        history: List of ``(user_message, response)`` pairs shown so far;
            ``None`` is treated as an empty conversation.
        steps: Number of diffusion steps, forwarded as ``T``.
        max_len: Sequence length, forwarded as ``max_length``.

    Returns:
        A pair ``("", history)``: the empty string clears the input box and
        ``history`` is the conversation with the new exchange appended.
    """
    if history is None:
        history = []
    # Slider values may arrive as floats; generation needs integers.
    response = generate_with_prompt(
        model, message, max_length=int(max_len), T=int(steps)
    )
    history.append((message, response))
    return "", history
# Chat UI: a chatbot pane, a prompt box, two generation sliders and a
# clear button.
with gr.Blocks() as demo:
    # NOTE(review): the title originally began with an emoji that was lost
    # to a text-encoding error; it cannot be recovered, so it is omitted.
    gr.Markdown("## DiffusionTextModel QA Chat Demo")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your question or prompt here...")
    steps = gr.Slider(1, 50, value=10, step=1, label="Diffusion Steps (T)")
    # The label's "≤" had been mis-decoded; restored here.
    max_len = gr.Slider(10, 99, value=50, step=1, label="Max Token Length (≤ 99)")
    clear = gr.Button("Clear")

    # Submitting the textbox runs generation, clears the box and updates
    # the chat pane.
    msg.submit(chat_fn, [msg, chatbot, steps, max_len], [msg, chatbot])
    # The clear button resets the chat pane without going through the queue.
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch only when executed as a script, so importing this module (for
# tests or reload tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()