# tess-2-demo/app.py
import logging
import gradio as gr
import torch
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
)
from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler
logger = logging.getLogger(__name__)
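
# Model-type bookkeeping, as in the Hugging Face masked-LM example scripts;
# neither constant is referenced elsewhere in this demo.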
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def main():
model_args, data_args, training_args, diffusion_args = get_args()
tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
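    # Build the simplex diffusion sampling pipeline; scheduler hyperparameters fall
    # back to the defaults below whenever the parsed diffusion args omit them.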
pipeline = SimplexDDPMPipeline(
model=model.to(device),
scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
clip_sample=getattr(diffusion_args, "clip_sample", False),
device=device,
),
simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
top_p=getattr(diffusion_args, "top_p", 0.99),
sampling_type="top_p",
is_conditional_generation=True,
tokenizer=tokenizer,
classifier_free_uncond_input="empty_token",
temperature=getattr(diffusion_args, "temperature", 1.0),
guidance_softmax_combination=True,
)
def generate(
inputs,
simplex_value=5.0,
top_p=0.99,
temperature=1.0,
diffusion_steps=100,
beta_schedule="squaredcos_improved_ddpm",
clip_sample=False,
guidance_scale=1.0,
generated_sequence_length=256,
progress=gr.Progress(),
):
"""
        Gradio-friendly generation function. Overrides the pipeline's sampling
        parameters (simplex_value, top_p, etc.) with the user-provided values,
        then streams the decoded text back as diffusion sampling proceeds.
"""
with torch.inference_mode():
# Update pipeline scheduler with user-provided parameters:
pipeline.scheduler.num_train_timesteps = diffusion_steps
pipeline.scheduler.beta_schedule = beta_schedule
pipeline.scheduler.simplex_value = simplex_value
pipeline.scheduler.clip_sample = clip_sample
pipeline.simplex_value = simplex_value
pipeline.top_p = top_p
pipeline.temperature = temperature
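            # NOTE: these attribute overrides assume the scheduler and pipeline read
            # them afresh on each call rather than caching values at construction.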
            # Wrap the raw prompt in the Tulu chat template.
            inputs = "<|user|>\n" + inputs + "\n<|assistant|>\n"
# Tokenize and prepare input for diffusion
tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids
tokenized_input_len = tokenized_input.shape[1]
# Concatenate BOS + input + blank space for generation
tokenized_input = torch.cat(
[
torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
tokenized_input,
torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
],
dim=-1,
)
# Create a mask over the generation region
span_mask = torch.cat(
[
torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
torch.ones((1, generated_sequence_length), dtype=torch.bool),
],
dim=-1,
)
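            # True marks positions the diffusion process may rewrite; the BOS token
            # and the prompt stay fixed as conditioning.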
batch = {
"input_ids": tokenized_input.to(device),
"span_mask": span_mask.to(device),
}
# Run sampling
pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale)
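            # The pipeline call returns an iterable of intermediate denoising states;
            # decoding the argmax tokens after each step lets Gradio stream the output.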
for out in pipe:
output_ids = out.logits.argmax(dim=-1)
generated_tokens = output_ids[:, tokenized_input_len + 1 :]
yield tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # Quick smoke test (uncomment to run outside Gradio). Note that generate()
        # is a generator, so it has to be consumed:
        # for text in generate("The best things in life are"):
        #     print("Test generation:", text)
demo = gr.Interface(
fn=generate,
inputs=[
gr.Textbox(lines=5, label="Input Prompt"),
gr.Number(value=5.0, label="Simplex value"),
gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p"),
gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature"),
gr.Number(value=100, precision=0, label="Diffusion steps"),
gr.Dropdown(
choices=["linear", "scaled_linear", "squaredcos_cap_v2", "squaredcos_improved_ddpm"],
value="squaredcos_improved_ddpm",
label="Beta schedule",
),
gr.Checkbox(value=False, label="Clip sample?"),
gr.Number(value=1.0, label="Guidance scale"),
            gr.Number(value=256, precision=0, label="Generation length (tokens)"),
],
outputs="text",
title="Simplex Diffusion LM",
description="Generate text using a simplex-based diffusion model.",
)
demo.queue().launch(server_name="0.0.0.0", server_port=8888, share=True)
if __name__ == "__main__":
main()
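
# Example launch (hypothetical flags; the exact names are defined by the argument
# dataclasses that sdlm.arguments.get_args parses):
#   python app.py --model_name_or_path <path-to-tess-checkpoint>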