import deepsparse
import gradio as gr
from typing import Tuple, List

deepsparse.cpu.print_hardware_capability()

MODEL_ID = "hf:neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds"
# MODEL_ID = "hf:neuralmagic/mpt-7b-chat-pruned50-quant"

DESCRIPTION = """
# Chat with an Efficient Sparse Llama 2 Model on CPU

This demo showcases a groundbreaking [sparse Llama 2 7B model](https://huggingface.co/neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds) that has been pruned to 70% sparsity, retrained on pretraining data, and then sparse-transferred for chat using the UltraChat 200k dataset. By leveraging the power of sparse transfer learning, this model delivers high-quality chat capabilities while significantly reducing computational cost and inference time.

### Under the Hood

- **Sparse Transfer Learning**: The model's pre-sparsified structure enables efficient fine-tuning on new tasks, minimizing the need for extensive hyperparameter tuning and reducing training time.
- **Accelerated Inference**: Powered by the [DeepSparse CPU inference runtime](https://github.com/neuralmagic/deepsparse), this model takes advantage of its inherent sparsity to provide lightning-fast token generation on CPUs.
- **Quantization**: 8-bit weight and activation quantization further optimizes the model's performance and memory footprint without compromising quality.

By combining state-of-the-art sparsity techniques with the robustness of the Llama 2 architecture, this model pushes the boundaries of efficient generation. Experience the future of AI-powered chat, where cutting-edge sparse models deliver exceptional performance on everyday hardware.
"""

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200

# Set up the DeepSparse engine
from deepsparse.legacy import Pipeline

pipe = Pipeline.create(
    task="text-generation",
    model_path=MODEL_ID,
    sequence_length=MAX_MAX_NEW_TOKENS,
    prompt_sequence_length=8,
    num_cores=8,
)


def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Clear the textbox and stash the submitted message in `saved_input`."""
    return "", message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Append the user message to the chat history with an empty bot reply."""
    history.append((message, ""))
    return history


def delete_prev_fn(
    history: List[Tuple[str, str]]
) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the last exchange from the history and return its user message."""
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button(
                "Submit", variant="primary", scale=1, min_width=0
            )

    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩ī¸ Undo", variant="secondary")
        clear_button = gr.Button("🗑ī¸ Clear", variant="secondary")

    saved_input = gr.State()

    gr.Examples(
        examples=[
            "Write a story about sparse neurons.",
            "Write a story about a summer camp.",
            "Make a recipe for banana bread.",
            "Write a cookbook for gluten-free snacks.",
            "Write about the role of animation in video games.",
        ],
        inputs=[textbox],
    )

    max_new_tokens = gr.Slider(
        label="Max new tokens",
        value=DEFAULT_MAX_NEW_TOKENS,
        minimum=0,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        interactive=True,
        info="The maximum number of new tokens to generate",
    )
    temperature = gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    )
    top_p = gr.Slider(
        label="Top-p (nucleus) sampling",
        value=0.40,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    )
    top_k = gr.Slider(
        label="Top-k sampling",
        value=20,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        info="Sample from the top_k most likely tokens",
    )
    repetition_penalty = gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )

    # Generation inference
    def generate(
        message,
        history,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float,
    ):
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }

        # Format the latest user message with the model's chat template
        conversation = [{"role": "user", "content": message}]
        formatted_conversation = pipe.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

        # Stream tokens from the DeepSparse pipeline into the chat history
        inference = pipe(
            sequences=formatted_conversation,
            generation_config=generation_config,
            streaming=True,
        )
        for token in inference:
            history[-1][1] += token.generations[0].text
            yield history

        # Log per-request timing from the DeepSparse pipeline
        print(pipe.timer_manager)

    # Hooking up all the buttons
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue().launch(share=True)
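
# A minimal sketch of calling the pipeline directly, without the Gradio UI.
# This commented-out example is illustrative and not part of the original demo:
# it reuses the `pipe` object created above and assumes the non-streaming output
# exposes the same `generations[0].text` field used in `generate()` above.
#
#   conversation = [{"role": "user", "content": "Make a recipe for banana bread."}]
#   prompt = pipe.tokenizer.apply_chat_template(
#       conversation, tokenize=False, add_generation_prompt=True
#   )
#   output = pipe(
#       sequences=prompt,
#       generation_config={"max_new_tokens": 64, "do_sample": False},
#   )
#   print(output.generations[0].text)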