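"""Gradio demo for conditional text generation with a simplex diffusion LM.

Loads a trained SDLM checkpoint, wraps it in a SimplexDDPMPipeline driven by a
token-wise simplex DDPM scheduler, and serves a streaming web UI on port 8888.
Model, data, training, and diffusion arguments are parsed from the command
line by sdlm.arguments.get_args; launch the script with the same arguments
used by the repo's other entry points (e.g. the path to a trained checkpoint).
"""
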
import logging

import gradio as gr
import torch
from transformers import (
    MODEL_FOR_MASKED_LM_MAPPING,
)

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
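# Note: MODEL_TYPES is standard HF masked-LM boilerplate and is not referenced
# elsewhere in this demo.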


def main():
    model_args, data_args, training_args, diffusion_args = get_args()
    tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

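    # Build the sampling pipeline. Scheduler and sampling settings fall back to
    # hardcoded defaults when absent from the parsed diffusion arguments;
    # classifier-free guidance uses an empty-token unconditional input, and
    # decoding uses top-p sampling.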
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )

    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
        generated_sequence_length=256,
        progress=gr.Progress(),
    ):
        """
        Gradio-friendly generation function. Adjusts the pipeline's parameters
        (simplex_value, top_p, etc.) as requested, then runs generation.
        """
        with torch.inference_mode():
            # Update pipeline scheduler with user-provided parameters:
            pipeline.scheduler.num_train_timesteps = diffusion_steps
            pipeline.scheduler.beta_schedule = beta_schedule
            pipeline.scheduler.simplex_value = simplex_value
            pipeline.scheduler.clip_sample = clip_sample
            pipeline.simplex_value = simplex_value
            pipeline.top_p = top_p
            pipeline.temperature = temperature
            # Wrap the prompt in the Tulu chat template.
            inputs = "<|user|>\n" + inputs + "<|assistant|>\n"

            # Tokenize and prepare input for diffusion
            tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids
            tokenized_input_len = tokenized_input.shape[1]

            # Concatenate BOS + prompt + PAD placeholders for the tokens to generate
            tokenized_input = torch.cat(
                [
                    torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id,
                    tokenized_input,
                    torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id,
                ],
                dim=-1,
            )

            # Create a boolean mask over the generation region: False for
            # BOS + prompt (kept fixed), True for positions to be denoised
            span_mask = torch.cat(
                [
                    torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool),
                    torch.ones((1, generated_sequence_length), dtype=torch.bool),
                ],
                dim=-1,
            )

            batch = {
                "input_ids": tokenized_input.to(device),
                "span_mask": span_mask.to(device),
            }
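            # Resulting layout (left to right):
            #   input_ids: [BOS][prompt tokens ...........][PAD x generated_sequence_length]
            #   span_mask: [False x (prompt_len + 1) .....][True x generated_sequence_length]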

            # Run sampling. The pipeline yields intermediate states, so the
            # decoded text can be streamed to the UI as it denoises.
            pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale)
            for out in pipe:
                output_ids = out.logits.argmax(dim=-1)
                generated_tokens = output_ids[:, tokenized_input_len + 1 :]
                yield tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    # Quick sanity check (uncomment for a quick, non-Gradio test). generate()
    # is a generator, so consume it to get the final decoded sample:
    # print("Test generation:", list(generate("The best things in life are"))[-1])

    demo = gr.Interface(
        fn=generate,
        inputs=[
            gr.Textbox(lines=5, label="Input Prompt"),
            gr.Number(value=5.0, label="Simplex value"),
            gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p"),
            gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature"),
            gr.Number(value=100, precision=0, label="Diffusion steps"),
            gr.Dropdown(
                choices=["linear", "scaled_linear", "squaredcos_cap_v2", "squaredcos_improved_ddpm"],
                value="squaredcos_improved_ddpm",
                label="Beta schedule",
            ),
            gr.Checkbox(value=False, label="Clip sample?"),
            gr.Number(value=1.0, label="Guidance scale"),
            gr.Number(value=256, precision=0, label="Generation length (tokens)"),
        ],
        outputs="text",
        title="Simplex Diffusion LM",
        description="Generate text using a simplex-based diffusion model.",
    )

    demo.queue().launch(server_name="0.0.0.0", server_port=8888, share=True)


if __name__ == "__main__":
    main()