Spaces: Running on Zero
| # dreamon_app.py | |
| ## this app is built based on https://huggingface.co/spaces/multimodalart/Dream/blob/main/app.py | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import spaces # Ensure spaces is installed if needed for GPU decorator | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel, AutoConfig | |
| import time | |
| import re | |
| from typing import List, Dict, Tuple, Optional, Any, Iterable # Added Any | |
| import torch.distributions as dists # Added import | |
| import traceback # For better error printing | |
| import random | |
| import gzip | |
| import json | |
| import subprocess | |
| import time | |
| from multiprocessing import Process, Queue | |
| import io | |
| import sys | |
def unsafe_execute(prompt, completion, suffix, test_case, timeout=3):
    """Run ``prompt + completion + suffix + "\\n\\n" + test_case`` in a fresh Python process.

    The assembled program is fed to ``sys.executable`` on stdin and killed after
    ``timeout`` seconds. As a result, an infinite loop, ``sys.exit()`` or global-state
    mutation in the executed code cannot hang or alter the app process.

    NOTE(review): the program text comes from UI text boxes and model output, i.e. it is
    untrusted. A subprocess isolates it from the app but is NOT a security sandbox.

    Args:
        prompt: code before the generated span.
        completion: the generated middle span.
        suffix: code after the generated span.
        test_case: test code appended after two blank lines.
        timeout: wall-clock limit in seconds for the child process.

    Returns:
        - the stripped stdout, if the program exited with status 0 and wrote nothing to stderr;
        - ``'Pass all test cases!'``, if it exited cleanly and printed nothing;
        - ``'Error: <last stderr line>'`` otherwise (for an uncaught exception this is the
          ``ExceptionType: message`` line, so message-less assertion failures are reported too);
        - ``'Error: Timed out after <timeout> seconds'``, if the time limit is hit.
    """
    check_program = prompt + completion + suffix + "\n\n" + test_case
    try:
        proc = subprocess.run(
            [sys.executable, "-"],  # "-" makes the interpreter read the script from stdin
            input=check_program,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f'Error: Timed out after {timeout} seconds'

    output = proc.stdout.strip()
    error_output = proc.stderr.strip()
    if proc.returncode == 0 and not error_output:
        return output if output else 'Pass all test cases!'
    error_lines = error_output.splitlines()
    if error_lines:
        # The last traceback line carries the exception type and message.
        return f'Error: {error_lines[-1]}'
    return f'Error: process exited with code {proc.returncode}'
def read_problems(benchmark_file: str = "HumanEval-SingleLineInfilling.jsonl.gz") -> Dict[str, Dict]:
    """Load benchmark problems keyed by their ``task_id``.

    Args:
        benchmark_file: path to a ``.jsonl`` or ``.jsonl.gz`` file, one problem per line.
            Defaults to the HumanEval single-line infilling benchmark, resolved relative
            to the current working directory.

    Returns:
        Dict mapping each record's ``task_id`` to the full record. A later record with a
        duplicate ``task_id`` overwrites an earlier one.
    """
    return {task["task_id"]: task for task in stream_jsonl(benchmark_file)}
def stream_jsonl(filename: str) -> Iterable[Dict]:
    """Parse each non-blank line of a JSON Lines file and yield it as a dictionary.

    Files whose name ends in ``.gz`` are decompressed transparently. Text is always decoded
    as UTF-8 (the encoding JSON Lines specifies) instead of the platform locale default.
    Lines consisting only of whitespace are skipped.
    """
    opener = gzip.open if filename.endswith(".gz") else open
    with opener(filename, "rt", encoding="utf-8") as fp:
        for line in fp:
            if line.strip():
                yield json.loads(line)
# Benchmark tasks (task_id -> problem record), loaded once when the module is imported.
problems: Dict[str, Dict] = read_problems()
class HFTokenizerWrapper:
    """Thin adapter around a loaded Hugging Face tokenizer.

    Exposes the BOS / EOS / mask token ids as attributes and adds optional BOS/EOS
    wrapping to ``encode``.

    Args:
        hf_tokenizer: a tokenizer object (e.g. the result of ``AutoTokenizer.from_pretrained``),
            not a model path.
    """

    def __init__(self, hf_tokenizer: Any) -> None:
        self.tokenizer = hf_tokenizer
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.mask_id = self.tokenizer.mask_token_id

    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False, **kwargs) -> List[int]:
        """Tokenize ``s``, optionally prepending BOS and/or appending EOS.

        Extra keyword arguments (e.g. ``add_special_tokens=False``) are forwarded to the
        underlying tokenizer's ``encode``.
        """
        return (
            [self.bos_id] * int(add_bos)
            + self.tokenizer.encode(s, **kwargs)
            + [self.eos_id] * int(add_eos)
        )

    def decode(self, tokens: List[int], **kwargs) -> str:
        """Decode token ids to text; keyword arguments go to the underlying tokenizer."""
        return self.tokenizer.decode(tokens, **kwargs)

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Optional[Tuple[List[str], List[int]]]:
        """Return the offsets of the tokens in the original text. Only used for evaluation.

        Not implemented for this wrapper: always returns None.
        """
        return None

    def convert_tokens_to_ids(self, tokens):
        """Delegate token-string -> id conversion to the underlying tokenizer."""
        return self.tokenizer.convert_tokens_to_ids(tokens)
# --- START: Copied Helper functions from generation_utils.py ---
# (top_p_logits, top_k_logits and sample_tokens are copies; keep them in sync with the originals.)
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering along the last dimension.

    Keeps the smallest set of highest-probability entries whose cumulative probability
    exceeds ``top_p`` and sets every other logit to the dtype's minimum value. The single
    most likely entry is always kept.

    Args:
        logits: tensor of logits; filtering is applied along dim -1.
        top_p: cumulative-probability threshold. ``None`` disables filtering and returns
            ``logits`` unchanged (previously this raised ``TypeError``).
    """
    if top_p is None:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Map the removal mask from sorted order back to the original vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Every other logit is set to the dtype's minimum value. Ties with the k-th value are kept.

    Args:
        logits: tensor of logits; filtering is applied along dim -1.
        top_k: number of entries to keep. ``None`` or a value <= 0 disables filtering and
            returns ``logits`` unchanged (previously ``None`` raised ``TypeError`` and 0 raised
            ``IndexError``). Values larger than the last dimension are clamped.
    """
    if top_k is None or top_k <= 0:
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a logit below the k-th largest one.
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per row of ``logits`` and score how confident that choice is.

    Args:
        logits: tensor of logits over the vocabulary (last dimension).
        temperature: if > 0, logits are divided by it and a token is sampled from the
            resulting distribution; otherwise the argmax is taken.
        top_p: optional nucleus threshold, applied only when < 1.
        top_k: optional top-k cutoff, applied only when > 0.
        margin_confidence: if True, confidence is top-1 probability minus top-2 probability.
        neg_entropy: if True, confidence is sum(p * log p), i.e. the negative entropy.
            Takes precedence over ``margin_confidence`` when both are set.

    Returns:
        ``(confidence, x0)``: per-row confidence scores and the chosen token ids. By default
        the confidence is the probability of the chosen token.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None and top_k > 0:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Invalid distribution (e.g. NaN probabilities): fall back to greedy selection.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # "..." indexing works for any number of leading dimensions, not only 2-D input.
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]
    if neg_entropy:
        epsilon = 1e-10
        confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
    return confidence, x0
# --- END: Copied Helper functions ---
# --- Model Loading and Constants ---
model_path = "Dream-org/DreamOn-v0-7B"
# NOTE(review): `config` is not read anywhere else in this file; kept in case other code imports it.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = HFTokenizerWrapper(tokenizer)
print("Loading model...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
    attn_implementation="sdpa",  # Explicitly request SDPA
)
model = model.to(device).eval()
print("Model loaded.")

MASK_TOKEN = '<|mask|>'
MASK_ID = tokenizer.mask_id
EOS_ID = tokenizer.eos_id


def _special_token_id(token: str) -> Optional[int]:
    """Return the id of ``token``, or None if the tokenizer does not know it.

    Hugging Face tokenizers generally do not raise for unknown tokens: they return None or
    silently map the token to the unk id. Both cases are treated as "missing" here, so the
    checks below actually fire.
    """
    hf_tok = tokenizer.tokenizer
    try:
        token_id = tokenizer.convert_tokens_to_ids(token)
    except KeyError:
        return None
    if token_id is None:
        return None
    if token_id == hf_tok.unk_token_id and token != hf_tok.unk_token:
        return None
    return token_id


EXPAND_ID = _special_token_id('<|expand|>')
if EXPAND_ID is None:
    raise ValueError("Cannot determine EXPAND_ID. Check model's tokenizer configuration")
if MASK_ID is None:
    raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")

SPECIAL_TOKEN_IDS = {EOS_ID, MASK_ID}
IM_START_ID = _special_token_id("<|im_start|>")
IM_END_ID = _special_token_id("<|im_end|>")
if IM_START_ID is None or IM_END_ID is None:
    print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
    IM_START_ID = None
    IM_END_ID = None
else:
    SPECIAL_TOKEN_IDS.update((IM_START_ID, IM_END_ID))
# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    """Parse comma-separated ``"pos:word"`` constraints into ``{pos: token_ids}``.

    Words at positions > 0 are tokenized with a leading space. Entries without ``:``, with an
    empty word, or with a negative position are skipped. Entries with a non-integer position,
    or whose word fails to tokenize, are skipped with a printed warning.
    """
    constraints: Dict[int, List[int]] = {}
    if not constraints_text:
        return constraints
    for part in constraints_text.split(','):
        part = part.strip()
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
        except ValueError:
            print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
            continue
        word = word.strip()
        if not word:
            continue
        text_to_encode = f" {word}" if pos > 0 else word
        try:
            # Call the underlying HF tokenizer directly: the wrapper's encode() does not
            # necessarily accept add_special_tokens.
            token_ids = tokenizer.tokenizer.encode(text_to_encode, add_special_tokens=False)
        except Exception as e:
            print(f"Warning: Error processing constraint '{part}': {e}")
            continue
        if not token_ids:
            print(f"Warning: Could not tokenize constraint word '{word}'")
        elif pos >= 0:
            constraints[pos] = token_ids
    return constraints
def apply_constraints_to_state(
    x: torch.Tensor,
    prompt_length: int,
    total_length: int,
    parsed_constraints: Dict[int, List[int]],
    current_step: Optional[int] = None
) -> torch.Tensor:
    """Return a copy of ``x`` with constraint token ids written into row 0.

    Each constraint ``offset -> token_ids`` is placed at ``prompt_length + offset``.
    Constraints that would not fit completely within ``total_length`` are skipped.
    The input tensor is never modified. ``current_step`` is used only in warnings.
    """
    constrained = x.clone()
    for offset, token_ids in parsed_constraints.items():
        start = prompt_length + offset
        stop = start + len(token_ids)
        if not (start < total_length and stop <= total_length):
            continue
        try:
            replacement = torch.tensor(token_ids, dtype=torch.long, device=constrained.device)
            constrained[0, start:stop] = replacement
        except IndexError:
            print(f"Warning (Step {current_step}): Constraint at {offset} ('{tokenizer.decode(token_ids)}') goes out of bounds.")
        except Exception as e:
            print(f"Warning (Step {current_step}): Failed to apply constraint at {offset}: {e}")
    return constrained
# --- Core Generation Logic with Live Visualization ---
def infilling_dream(
    prefix: str,
    suffix: str,
    start_gen_len: int,
    max_gen_len: int,
    expand_budget: int,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float,
    delete_righthand_eos: bool,
    task_id: str
) -> Iterable[Tuple[str, str]]:
    """Infill the span between ``prefix`` and ``suffix``, yielding intermediate decodes.

    Generation starts from ``start_gen_len`` MASK tokens between prefix and suffix.
    Each step:
      1. unmasks exactly one position;
      2. replaces any generated EXPAND token with two MASKs (growing the span);
      3. deletes generated EOS tokens (shrinking the span).

    Args:
        prefix, suffix: text before / after the span to fill.
        start_gen_len: initial number of MASK tokens in the span.
        max_gen_len: extra canvas positions beyond the initial input. The loop also runs
            for at most ``4 * max_gen_len`` steps.
        expand_budget: number of EXPAND -> two-MASK expansions allowed. Once spent, or once
            the canvas is full, the EXPAND logit is suppressed.
        temperature, top_p, top_k: sampling settings forwarded to ``sample_tokens``.
            ``top_p <= 0`` is treated as disabled, matching the UI label.
        alg: confidence score used to choose which position to unmask: 'maskgit_plus'
            (probability of the sampled token), 'topk_margin' or 'entropy'.
        alg_temp: if None/0, unmask the highest-confidence position; otherwise sample the
            position from softmax(confidence / alg_temp).
        visualization_delay: seconds to sleep after each yield.
        delete_righthand_eos: if set, once an EOS appears, masked positions to its right
            are also filled with EOS.
        task_id: unused here.

    Yields:
        ``(middle_text, '')`` tuples. The second element clears the execution-result box.

    Raises:
        RuntimeError: for an unknown ``alg``.
    """
    # Slider values may arrive as floats; list repetition and range() need ints.
    start_gen_len = int(start_gen_len)
    max_gen_len = int(max_gen_len)
    expand_budget = int(expand_budget)
    # top_p_logits(logits, 0) would keep only the argmax token, not disable filtering.
    if top_p is not None and top_p <= 0:
        top_p = None

    # ------ 1. Prepare the input for infilling ------
    prefix_ids = tokenizer.encode(prefix, add_bos=True, add_eos=False)
    prefix_len = len(prefix_ids)
    suffix_ids = tokenizer.encode(suffix, add_bos=False, add_eos=True)
    input_ids = torch.LongTensor([prefix_ids + [MASK_ID] * start_gen_len + suffix_ids]).to(device)
    max_tokens = input_ids.shape[1] + max_gen_len
    num_generation_tokens = start_gen_len
    # Right-pad with MASK up to max_tokens so the span can grow through expansion.
    x = F.pad(input_ids, (0, max_tokens - input_ids.shape[1]), value=MASK_ID)

    initial_generated_tokens = input_ids[0, prefix_len: prefix_len + num_generation_tokens]
    yield tokenizer.decode(initial_generated_tokens.tolist()), ''
    time.sleep(visualization_delay)

    # ------ 2. Step-by-step infilling ------
    for _ in range(4 * max_gen_len):
        # Only the first cur_generation_window_length positions are attended to.
        cur_generation_window_length = input_ids.shape[1] - start_gen_len + num_generation_tokens
        attention_mask = torch.ones([input_ids.shape[0], cur_generation_window_length], dtype=torch.int16).to(input_ids.device)
        attention_mask = F.pad(attention_mask, (0, max_tokens - attention_mask.shape[1]), value=0)
        mask_index = (x == MASK_ID) & (attention_mask == 1)
        if torch.all(~mask_index[:, :cur_generation_window_length]):
            break
        tok_idx = attention_mask.long().cumsum(-1) - 1
        tok_idx.masked_fill_(attention_mask == 0, 1)
        attention_mask = torch.logical_and(
            attention_mask.unsqueeze(1).unsqueeze(-2),
            attention_mask.unsqueeze(1).unsqueeze(-1),
        )
        # Inference only: avoid building an autograd graph for the forward pass.
        with torch.no_grad():
            output = model(x, attention_mask, tok_idx)
        logits = output.logits
        # Shift right by one: position i uses the logits produced at position i-1.
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
        logits = logits[mask_index]
        # Block the EXPAND logit once the canvas is full or the expansion budget is used up.
        if cur_generation_window_length == max_tokens or expand_budget <= 0:
            logits[:, EXPAND_ID] -= 1e9

        if alg == 'maskgit_plus':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
        elif alg == 'topk_margin':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
        elif alg == 'entropy':
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
        else:
            raise RuntimeError(f"Unknown alg: {alg}")

        # Unmask exactly one position per step.
        number_transfer_tokens = 1
        if alg_temp is None or alg_temp == 0:
            _, transfer_index = torch.topk(confidence, number_transfer_tokens)
        else:
            position_probs = F.softmax(confidence / alg_temp, dim=-1)
            transfer_index = torch.multinomial(position_probs, num_samples=number_transfer_tokens)
        x0_ = torch.full_like(x0, MASK_ID)
        x0_[transfer_index] = x0[transfer_index].clone()
        x[mask_index] = x0_

        if delete_righthand_eos:
            if x.shape[0] != 1:
                raise NotImplementedError
            x_seq = x[0]  # view of the single batch row
            eos_positions = (x_seq == EOS_ID).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                first_eos_idx = eos_positions[0].item()
                position_mask = torch.arange(x_seq.size(0), device=x.device) >= first_eos_idx
                # Only positions that were masked at the start of this step are overwritten.
                x_seq.masked_fill_(position_mask & mask_index[0], EOS_ID)
                x = x_seq.unsqueeze(0)

        # Visualize the denoise step; EOS is shown as <|delete|> because it is removed below.
        cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
        cur_tokens = tokenizer.decode(cur_generated_tokens.tolist())
        yield cur_tokens.replace("<|endoftext|>", "<|delete|>"), ''
        time.sleep(visualization_delay)

        # Expansion step: replace every EXPAND token with two MASK tokens.
        expand_indices = (x[0] == EXPAND_ID).nonzero(as_tuple=False).squeeze(1)
        if expand_indices.numel() > 0:
            # Process right to left so the remaining indices stay valid as x grows.
            for idx in sorted(expand_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    torch.tensor([[MASK_ID, MASK_ID]], device=x.device),
                    x[:, idx + 1:],
                ), dim=1)
                num_generation_tokens += 1
                expand_budget -= 1
                # Truncate back to max_tokens if needed.
                if x.shape[1] > max_tokens:
                    x = x[:, :max_tokens]
            cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
            yield tokenizer.decode(cur_generated_tokens.tolist()), ''
            time.sleep(visualization_delay)

        # Deletion step: drop EOS tokens at positions that were masked this step, padding a
        # MASK on the right so x keeps its length.
        # NOTE(review): mask_index predates any expansion above, so positions right of an
        # expansion are off by one; harmless in practice but worth confirming.
        eos_indices = ((x[0] == EOS_ID) & (mask_index[0] == 1)).nonzero(as_tuple=False).squeeze(1)
        if eos_indices.numel() > 0:
            for idx in sorted(eos_indices.tolist(), reverse=True):
                x = torch.cat((
                    x[:, :idx],
                    x[:, idx + 1:],
                    torch.tensor([[MASK_ID]], device=x.device),
                ), dim=1)
                num_generation_tokens -= 1
            cur_generated_tokens = x[0, prefix_len: prefix_len + num_generation_tokens]
            yield tokenizer.decode(cur_generated_tokens.tolist()), ''
            time.sleep(visualization_delay)

    generated_code = tokenizer.decode(x[0, prefix_len: prefix_len + num_generation_tokens].tolist())
    yield generated_code, ''
def get_example_input(problem_pool: Optional[Dict[str, Dict]] = None):
    """Sample a random HumanEval single-line-infilling task for the UI.

    The task's ``test`` code is turned into a self-running script:
      - the ``METADATA = {...}`` dict is stripped;
      - ``def check(candidate):`` becomes ``def run_test():``;
      - ``candidate`` is replaced by the task's entry point;
      - a ``run_test()`` call is appended.

    Args:
        problem_pool: mapping ``task_id -> problem record`` (needs ``prompt``, ``suffix``,
            ``test`` and ``entry_point`` keys). Defaults to the module-level ``problems``.

    Returns:
        ``(prefix, '', suffix, test_case, task_id, '')``. The empty strings clear the
        generated-middle and result boxes.
    """
    if problem_pool is None:
        problem_pool = problems
    task_id = random.choice(list(problem_pool.keys()))
    problem = problem_pool[task_id]
    prefix, suffix = problem['prompt'], problem['suffix']
    # Previously this ran re.sub on an undefined name `code` (NameError on every call).
    test_case = re.sub(r'METADATA\s*=\s*\{.*?\}\n\n', '', problem['test'], flags=re.DOTALL)
    test_case = test_case.replace('def check(candidate):', 'def run_test():')
    test_case = test_case.replace('candidate', problem['entry_point'])
    test_case = test_case + '\n\n\nrun_test()'
    return prefix, '', suffix, test_case, task_id, ''
def check_result(prompt, completion, suffix, test_case):
    """Normalize the four UI fields to strings, log them, and run them via ``unsafe_execute``.

    ``None`` fields become empty strings. Returns ``unsafe_execute``'s status string.
    """
    fields = ["" if value is None else str(value) for value in (prompt, completion, suffix, test_case)]
    for label, text in zip(('prefix', 'middle', 'suffix', 'test'), fields):
        print(label, text)
    return unsafe_execute(*fields)
# --- Gradio UI ---
# Custom CSS for the Blocks app: hides elements with the "category-legend" class.
css = "\n.category-legend{display:none}\n"
def create_chatbot_demo():
    """Build and return the DreamOn infilling Gradio ``Blocks`` app (layout plus event wiring)."""
    intro_md = (
        "# DreamOn: Diffusion Language Models For Code Infilling Beyond Fixed-size Canvas\n"
        "Click **Example Prompt** to get a prefix and suffix, then click **Generate** to generate code. Have fun!"
    )
    links_md = (
        "[[Model Card](https://huggingface.co/Dream-org/DreamOn-v0-7B)] "
        "[[Blog](https://hkunlp.github.io/blog/2025/dreamon/)]"
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(intro_md)
        gr.Markdown(links_md)

        # Action buttons.
        with gr.Row():
            sample_btn = gr.Button("Example Prompt")
            generate_btn = gr.Button("Generate", variant="primary")
            check_btn = gr.Button("Run test case")
            clear_btn = gr.Button("Clear")

        with gr.Row():
            # Left: prefix, generated middle, suffix, and a hidden task id.
            with gr.Column():
                prefix_box = gr.Textbox(label="Prefix Text", placeholder="Enter the beginning of your text...", lines=2)
                middle_box = gr.Textbox(label="Generated Text (Middle)", lines=2)
                suffix_box = gr.Textbox(label="Suffix Text", placeholder="Enter the end of your text...", lines=2)
                task_id_box = gr.Textbox(label="Task ID", placeholder="Task ID will be stored here...", visible=False)
            # Right: test code and the result of running it.
            with gr.Column():
                test_box = gr.Textbox(label="Test Case", placeholder="Enter your test case here...", lines=2)
                result_box = gr.Textbox(label="Result of Execution", placeholder="Execution result will be shown here...", lines=2)

        with gr.Accordion("Generation Settings"):
            with gr.Row():
                start_len_slider = gr.Slider(minimum=4, maximum=64, value=4, step=4, label="Initial Generation Length")
                max_len_slider = gr.Slider(minimum=32, maximum=64, value=64, step=8, label="Maximum Generation Length")
            with gr.Row():
                budget_slider = gr.Slider(minimum=0, maximum=256, value=64, step=8, label="Expansion Budget")
                temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="Temperature")
            with gr.Row():
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (0 disables)")
            with gr.Row():
                top_k_slider = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (0 disables)")
            with gr.Row():
                alg_radio = gr.Radio(choices=['maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Generation Algorithm")
                alg_temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="Algorithm Temperature")
            with gr.Row():
                delay_slider = gr.Slider(minimum=0.1, maximum=3, value=0.2, step=0.1, label="Visualization Delay (s)")
                delete_eos_checkbox = gr.Checkbox(label="Delete all tokens on the righthand side of <|delete|>", value=True)

        # Order must match infilling_dream's positional parameters.
        generate_btn.click(
            fn=infilling_dream,
            inputs=[
                prefix_box, suffix_box, start_len_slider, max_len_slider, budget_slider,
                temperature_slider, top_p_slider, top_k_slider, alg_radio, alg_temp_slider,
                delay_slider, delete_eos_checkbox, task_id_box,
            ],
            outputs=[middle_box, result_box],
            show_progress="hidden",
        )
        clear_btn.click(
            lambda: ("",) * 6,  # blank every input and output
            inputs=[],
            outputs=[prefix_box, suffix_box, middle_box, test_box, result_box, task_id_box],
            queue=False,
        )
        check_btn.click(
            fn=check_result,
            inputs=[prefix_box, middle_box, suffix_box, test_box],
            outputs=[result_box],
            queue=False,
        )
        sample_btn.click(
            fn=get_example_input,
            outputs=[prefix_box, middle_box, suffix_box, test_box, task_id_box, result_box],
            queue=False,
        )
    return demo
# --- Launch ---
if __name__ == "__main__":
    # Enable the request queue (required for streaming generator outputs) and start the server.
    demo = create_chatbot_demo()
    demo.queue().launch(debug=True)