File size: 2,744 Bytes
182c7c5
 
4664dfb
182c7c5
 
13a2bfa
182c7c5
4664dfb
 
182c7c5
4664dfb
 
182c7c5
4664dfb
 
 
 
182c7c5
 
4664dfb
 
182c7c5
 
 
 
4664dfb
 
182c7c5
 
4664dfb
 
 
 
182c7c5
4664dfb
 
 
 
 
182c7c5
4664dfb
182c7c5
4664dfb
 
 
182c7c5
4664dfb
182c7c5
 
4664dfb
 
 
 
 
 
 
 
182c7c5
4664dfb
 
 
 
 
182c7c5
 
4664dfb
 
 
 
 
 
 
182c7c5
4664dfb
182c7c5
4664dfb
182c7c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "mistralai/Devstral-Small-2505"
# Load the tokenizer from the SAME checkpoint as the model. The previous code
# pulled the tokenizer from "mistralai/Mistral-7B-v0.1", which is a different
# repo: a mismatched tokenizer can produce token ids that do not line up with
# the model's vocabulary and silently degrades generation quality.
# NOTE(review): if Devstral's repo genuinely lacks a slow tokenizer, fall back
# to use_fast=True rather than a different checkpoint — confirm against the hub.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # fp16 halves GPU memory use; CPU stays fp32 for numeric stability.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def code_completion(prompt, max_new_tokens=128, temperature=0.2):
    """Generate a continuation for *prompt* and return only the new text.

    Args:
        prompt: Code snippet to complete. Blank/whitespace-only input
            short-circuits with a user-facing message.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values <= 0 fall back to greedy
            decoding (the UI slider keeps this >= 0.1, but the function no
            longer crashes if called directly with 0).

    Returns:
        The generated completion without the original prompt, as a string.
    """
    if not prompt.strip():
        return "Please enter some code to complete."
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    do_sample = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature if do_sample else None,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Strip the prompt at the *token* level. Decoding the full sequence and
    # slicing with `generated[len(prompt):]` (the old approach) breaks when
    # decode does not round-trip the prompt text exactly (whitespace or
    # special-token normalization shifts the character offsets).
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# CSS overrides served to the browser by gr.Blocks(css=...): neutral page
# background, Inter headings, rounded boxes, and the #main-title / #subtitle
# ids targeted by the HTML header rendered in the gr.Markdown call below.
custom_css = """
body {background: #f7f8fa;}
.gradio-container {background: #f7f8fa;}
h1, h2, h3, h4, h5, h6 {font-family: 'Inter', sans-serif;}
#main-title {
    text-align: center;
    font-weight: 800;
    font-size: 2.3em;
    margin-bottom: 0.2em;
    letter-spacing: -1px;
    color: #222;
}
#subtitle {
    text-align: center;
    color: #6c6f7a;
    font-size: 1.1em;
    margin-bottom: 2em;
}
.gr-box {border-radius: 16px;}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header rendered as raw HTML so the element ids match the selectors
    # defined in custom_css.
    gr.Markdown(
        """
        <h1 id="main-title">Devstral Code Autocomplete</h1>
        <div id="subtitle">Minimal, beautiful code completion powered by <b>Devstral</b></div>
        """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input side: the code prompt plus the two sampling knobs.
            prompt_input = gr.Textbox(
                label="Your code prompt",
                lines=10,
                placeholder="def quicksort(arr):\n    \"\"\"Sort the array using quicksort algorithm.\"\"\"\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    ",
                show_copy_button=True,
                autofocus=True,
            )
            with gr.Row():
                token_slider = gr.Slider(16, 256, value=128, step=8, label="Max new tokens")
                temp_slider = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Temperature")
            generate_btn = gr.Button("Generate Completion", elem_id="generate-btn")
        with gr.Column(scale=1):
            # Output side: read-only syntax-highlighted viewer.
            completion_view = gr.Code(
                label="Generated code",
                language="python",
                lines=12,
                interactive=False,
            )
    # Wire the button to the generation function.
    generate_btn.click(
        code_completion,
        inputs=[prompt_input, token_slider, temp_slider],
        outputs=completion_view,
    )

demo.launch()