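"""Gradio chat demo for Meta's MobileLLM-R1-950M.

Loads the model once at startup through a transformers text-generation
pipeline and serves it with gr.ChatInterface. The spaces.GPU decorator on
respond() requests a GPU per call when the app runs on a Hugging Face
ZeroGPU Space.
"""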
import gradio as gr
from transformers import pipeline
import torch
import spaces

# Initialize the model pipeline
model_id = "facebook/MobileLLM-R1-950M"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Request a ZeroGPU GPU slice for up to 120 seconds per call (Hugging Face Spaces)
@spaces.GPU(duration=120)
def respond(message, history):
    # Build prompt from history
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    
    # Add current message
    prompt += f"User: {message}\nAssistant: "
    
    # Tokenize the prompt and move it to the model's device
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)
    
    with torch.no_grad():
        outputs = pipe.model.generate(
            **inputs,
            max_new_tokens=10000,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
    
    # Decode the generated tokens, skipping the input tokens
    generated_tokens = outputs[0][inputs['input_ids'].shape[-1]:]
    
    # Simulate streaming in the UI by yielding progressively longer prefixes.
    # Decoding the growing prefix each time (rather than one token at a time)
    # keeps multi-token characters and whitespace intact.
    for i in range(1, len(generated_tokens) + 1):
        yield pipe.tokenizer.decode(generated_tokens[:i], skip_special_tokens=True)
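
    # Note: generate() above finishes completely before anything is yielded;
    # the loop only replays the result incrementally. For true token-by-token
    # streaming, one option (a sketch, untested here) is transformers'
    # TextIteratorStreamer run alongside generate() in a background thread:
    #
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #
    #   streamer = TextIteratorStreamer(
    #       pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    #   )
    #   Thread(
    #       target=pipe.model.generate,
    #       kwargs=dict(**inputs, max_new_tokens=10000, streamer=streamer),
    #   ).start()
    #   partial = ""
    #   for new_text in streamer:
    #       partial += new_text
    #       yield partial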

# Create the chat interface
demo = gr.ChatInterface(
    fn=respond,
    title="MobileLLM Chat",
    description="Chat with Meta MobileLLM-R1-950M",
    examples=[
        "Write a Python function that returns the square of a number.",
        "Compute: 1-2+3-4+5- ... +99-100.",
        "Write a C++ program that prints 'Hello, World!'.",
        "Explain how recursion works in programming.",
        "What is the difference between a list and a tuple in Python?",
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch(share=True)