import logging

import gradio as gr
import torch
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)
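
# Note: MODEL_CONFIG_CLASSES / MODEL_TYPES are not used below; they appear to
# be carried over from the sdlm training scripts.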
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def main():
    model_args, data_args, training_args, diffusion_args = get_args()
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
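
    # Build the diffusion sampling pipeline. Classifier-free guidance uses an
    # empty-token unconditional input, and tokens are drawn with top-p sampling.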
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )
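
    # generate() closes over `pipeline`, `tokenizer`, and `device`, so Gradio
    # can call it with only the UI-exposed parameters.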
    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
        generated_sequence_length=256,
        progress=gr.Progress(),
    ):
"""
Gradio-friendly generation function. Adjusts the pipeline's parameters
(simplex_value, top_p, etc.) as requested, then runs generation.
"""
        with torch.inference_mode():
            # Push the user-provided parameters into the pipeline and scheduler.
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature
            # Wrap the prompt in the Tulu chat template.
            inputs = "<|user|>\n" + inputs + "<|assistant|>\n"
            # Tokenize without special tokens; BOS is prepended manually below.
            tokenized_input = tokenizer(
                [inputs], add_special_tokens=False, return_tensors="pt"
            ).input_ids
            tokenized_input_len = tokenized_input.shape[1]
            # Concatenate BOS + prompt + a pad-token region for the model to fill.
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long)
                    * tokenizer.pad_token_id,
                ],
                dim=-1,
            )
            # Span mask: False over the BOS + prompt prefix (kept fixed),
            # True over the region to be generated (denoised).
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )
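            # input_ids and span_mask now both have shape
            # (1, 1 + tokenized_input_len + generated_sequence_length).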
            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }
            # Run sampling; the pipeline yields its state after each step, so
            # we can stream progressively less-noisy decodes to the UI.
            pipe = pipeline(
                batch=batch,
                seq_length=generated_sequence_length,
                guidance_scale=guidance_scale,
            )
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                yield tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    # Quick non-Gradio smoke test (uncomment to run). generate() is a
    # generator, so consume it and keep the final, fully denoised decode:
    # print("Test generation:", list(generate("The best things in life are"))[-1])
    demo = gr.Interface(
        fn=generate,
        inputs=[
            gr.Textbox(lines=5, label="Input Prompt"),
            gr.Number(value=5.0, label="Simplex value"),
            gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p"),
            gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature"),
            gr.Number(value=100, precision=0, label="Diffusion steps"),
            gr.Dropdown(
                choices=[
                    "linear",
                    "scaled_linear",
                    "squaredcos_cap_v2",
                    "squaredcos_improved_ddpm",
                ],
                value="squaredcos_improved_ddpm",
                label="Beta schedule",
            ),
            gr.Checkbox(value=False, label="Clip sample?"),
            gr.Number(value=1.0, label="Guidance scale"),
            # precision=0 keeps the length an int (tensor shapes require ints).
            gr.Number(value=256, precision=0, label="Generation length (tokens)"),
        ],
        outputs="text",
        title="Simplex Diffusion LM",
        description="Generate text using a simplex-based diffusion model.",
    )
    demo.queue().launch(server_name="0.0.0.0", server_port=8888, share=True)


if __name__ == "__main__":
    main()
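
# Example launch command (an assumption: the flag names depend on what
# sdlm.arguments.get_args actually parses; the checkpoint path is illustrative):
#   python app.py --model_name_or_path /path/to/checkpoint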