import argparse
import os
from queue import SimpleQueue
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from gradio import Chatbot
from image_utils import ImageStitcher
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          TextIteratorStreamer)

from StreamDiffusionIO import LatentConsistencyModelStreamIO

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Kanji-Streaming Chat

🌍 This Space is adapted from the [Llama-2-7b-chat](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat) Space and demonstrates how to "chat" with an LLM through [Kanji-Streaming](https://github.com/AgainstEntropy/kanji).

🔨 The technique behind Kanji-Streaming is [StreamDiffusionIO](https://github.com/AgainstEntropy/StreamDiffusionIO), which is based on [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion) *but, in particular, makes it possible to render text streams into image streams*.

🔎 For more details about Kanji-Streaming, take a look at the [GitHub repository](https://github.com/AgainstEntropy/kanji).
"""

LICENSE = """
---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""

parser = argparse.ArgumentParser(description="Gradio launcher for Kanji-Streaming chat.")
parser.add_argument(
    "--llama_model_id_or_path",
    type=str,
    default="meta-llama/Llama-2-7b-chat-hf",
    required=False,
    help="Path to a downloaded llama-chat-hf model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--sd_model_id_or_path",
    type=str,
    default="runwayml/stable-diffusion-v1-5",
    required=False,
    help="Path to a downloaded sd-1-5 model or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default="AgainstEntropy/kanji-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--lcm_lora_path",
    type=str,
    default="AgainstEntropy/kanji-lcm-lora-sd-v1-5",
    required=False,
    help="Path to downloaded LCM-LoRA weights or a model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_res",
    type=int,
    default=64,
    required=False,
    help="Image resolution for displaying Kanji characters in the Chatbot.",
)
parser.add_argument(
    "--img_per_line",
    type=int,
    default=16,
    required=False,
    help="Number of Kanji characters to display on a single line.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="./tmp",
    required=False,
    help="Path for saving temporary images generated by StreamDiffusionIO.",
)
args = parser.parse_args()

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    DESCRIPTION += "\nRunning on CPU 🥶 This demo works best on GPU."
" DESCRIPTION += "\nThis demo will get the best kanji streaming experience in localhost (or SSH forward), instead of shared link generated by Gradio.
" model = AutoModelForCausalLM.from_pretrained(args.llama_model_id_or_path, torch_dtype=torch.float16, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.llama_model_id_or_path) tokenizer.use_default_system_prompt = False streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) lcm_stream = LatentConsistencyModelStreamIO( model_id_or_path=args.sd_model_id_or_path, lcm_lora_path=args.lcm_lora_path, lora_dict={args.lora_path: 1}, resolution=128, device=device, use_xformers=True, verbose=True, ) tmp_dir_template = f"{args.tmp_dir}/%d" response_num = 0 stitcher = ImageStitcher( tmp_dir=tmp_dir_template % response_num, img_res=args.img_res, img_per_line=args.img_per_line, verbose=True, ) @spaces.GPU def generate( message: str, chat_history: list[tuple[str, str]], show_original_response: bool, seed: int, system_prompt: str = '', max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2, ) -> Iterator[str]: conversation = [] if system_prompt: conversation.append({"role": "system", "content": system_prompt}) for user, assistant in chat_history: if isinstance(assistant, tuple): assistant = assistant[1] else: assistant = str(assistant) conversation.extend([ {"role": "user", "content": user}, {"role": "assistant", "content": assistant}, ]) conversation.append({"role": "user", "content": message}) print(conversation) input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, num_beams=1, repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] prompt_queue = SimpleQueue() lcm_stream.reset(seed) stitcher.reset() global response_num response_num += 1 stitcher.update_tmp_dir(tmp_dir_template % response_num) def append_to_queue(): for text in streamer: outputs.append(text) prompt = text.strip() if prompt: if prompt.endswith("."): prompt = prompt[:-1] prompt_queue.put(prompt) prompt_queue.put(None) append_thread = Thread(target=append_to_queue) append_thread.start() def show_image(prompt: str = None): image, text = lcm_stream(prompt) img_path = None if image is not None: img_path = stitcher.add(image, text) return img_path while True: prompt = prompt_queue.get() if prompt is None: break img_path = show_image(prompt) if img_path is not None: yield (img_path, ) # Continue to display the remaining images while True: img_path = show_image() if img_path is not None: yield (img_path, ''.join(outputs)) if lcm_stream.stop(): break print(outputs) if show_original_response: yield ''.join(outputs) chat_interface = gr.ChatInterface( fn=generate, chatbot=Chatbot(height=400), additional_inputs=[ gr.Checkbox( label="Show original response", value=False, ), gr.Number( label="Seed", info="Random Seed for Kanji Generation (maybe some kind of accent 🤔)", step=1, value=1026, ), gr.Textbox(label="System prompt", lines=4), gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, ), gr.Slider( label="Temperature", minimum=0.1, 
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", share=False)
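
# Example invocation (a sketch only; the script filename "app.py" is an assumption,
# and the flags simply override the argparse defaults defined above):
#
#   python app.py \
#       --llama_model_id_or_path meta-llama/Llama-2-7b-chat-hf \
#       --sd_model_id_or_path runwayml/stable-diffusion-v1-5 \
#       --img_res 64 --img_per_line 16 \
#       --tmp_dir ./tmp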