"""Gradio chat UI for Meta's MobileLLM-R1-950M text-generation model."""

import gradio as gr
import spaces
import torch
from transformers import pipeline

# Initialize the model pipeline once at import time so the weights are
# loaded a single time and shared across chat requests.
model_id = "facebook/MobileLLM-R1-950M"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)


@spaces.GPU(duration=120)
def respond(message, history):
    """Yield a progressively longer assistant reply for a Gradio chat.

    Args:
        message: The user's latest chat message (str).
        history: Prior turns as (user_msg, assistant_msg) pairs, as supplied
            by ``gr.ChatInterface`` in tuple format. Either element may be
            falsy for an incomplete turn.

    Yields:
        str: The assistant reply decoded so far; each yield replaces the
        previous one in the UI (Gradio streaming convention).
    """
    # Build a plain-text prompt from the conversation history.
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"

    # Add current message
    prompt += f"User: {message}\nAssistant: "

    # Generate tokens
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)
    with torch.no_grad():
        outputs = pipe.model.generate(
            **inputs,
            max_new_tokens=10000,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens, skipping the prompt tokens.
    generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]

    # Stream the output by decoding the *growing prefix* of tokens.
    # Decoding each token in isolation and concatenating the strings
    # mangles sub-word merges (dropped spaces) and can split multi-byte
    # characters; re-decoding the prefix keeps the text correct.
    # NOTE(review): generation completes before any yield — for true
    # token-by-token streaming, use transformers.TextIteratorStreamer
    # with generate() running in a background thread.
    for end in range(1, len(generated_tokens) + 1):
        yield pipe.tokenizer.decode(
            generated_tokens[:end], skip_special_tokens=True
        )


# Create the chat interface
demo = gr.ChatInterface(
    fn=respond,
    title="MobileLLM Chat",
    description="Chat with Meta MobileLLM-R1-950M",
    examples=[
        "Write a Python function that returns the square of a number.",
        "Compute: 1-2+3-4+5- ... +99-100.",
        "Write a C++ program that prints 'Hello, World!'.",
        "Explain how recursion works in programming.",
        "What is the difference between a list and a tuple in Python?",
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch(share=True)