import os

# Install the required transformers fork at startup; this must run before
# the transformers import below so the BitNet architecture is available.
os.system("pip install git+https://github.com/shumingma/transformers.git")

import threading

import torch
import torch._dynamo

# Fall back to eager execution instead of raising when torch.compile fails.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

import gradio as gr
import spaces

model_id = "microsoft/bitnet-b1.58-2B-4T"

# Load the tokenizer and model once at startup; device_map="auto" places the
# weights on the available GPU (or the CPU if none is found).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model.device)


@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    Args:
        message: User's current message.
        history: List of (user, assistant) tuples from previous turns.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.
    """
    # Rebuild the full conversation in the chat-template message format.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream decoded tokens as they are produced, skipping the prompt and
    # any special tokens.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Run generation in a background thread so this function can consume the
    # streamer and yield partial responses to the UI as they arrive.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response


# Each example row matches the ChatInterface inputs: message, then the
# additional inputs (system message, max tokens, temperature, top-p).
demo = gr.ChatInterface(
    fn=respond,
    title="Bitnet-b1.58-2B-4T Chatbot",
    description=(
        "This chat application is powered by Microsoft's SOTA "
        "Bitnet-b1.58-2B-4T and designed for natural and fast conversations."
    ),
    examples=[
        [
            "Hello! How are you?",
            "You are a helpful AI assistant.",
            512,
            0.7,
            0.95,
        ],
        [
            "Can you code a snake game in Python?",
            "You are a helpful AI assistant.",
            2048,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant.",
            label="System message",
        ),
        gr.Slider(
            minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()