File size: 5,392 Bytes
de2c765
 
6f2fe4d
de2c765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f2fe4d
de2c765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f2fe4d
de2c765
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# app.py

import gradio as gr
import torch
from unsloth import FastLanguageModel
import langid

# 1. ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ €๋ฅผ ์ „์—ญ์ ์œผ๋กœ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
# Zero-GPU ํ™˜๊ฒฝ์— ๋งž๊ฒŒ 4๋น„ํŠธ๋กœ ๋ชจ๋ธ์„ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-0528-Qwen3-8B",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    # Zero-GPU(CPU) ํ™˜๊ฒฝ์ด๋ฏ€๋กœ vLLM ๋น„ํ™œ์„ฑํ™”
    fast_inference=False, 
    # LoRA ์–ด๋Œ‘ํ„ฐ๋ฅผ ๋กœ๋“œํ•˜๊ธฐ ์œ„ํ•ด ๋ฏธ๋ฆฌ ์ตœ๋Œ€ ๋žญํฌ๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
    max_lora_rank=32,
)

# PEFT ๋ชจ๋ธ์— LoRA ๋ชจ๋“ˆ์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
# ์ด ๋‹จ๊ณ„๋Š” ์ถ”ํ›„ model.load_lora()๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=64,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# 2. ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ ์ •์˜
# ๋…ธํŠธ๋ถ์—์„œ ์‚ฌ์šฉ๋œ ๊ฒƒ๊ณผ ๋™์ผํ•œ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.
system_prompt = (
    "You are given a problem.\n"
    "Think about the problem and provide your working out.\n"
    "You must think in Bahasa Indonesia."
)

# 3. Inference function.
def generate_response(user_prompt, use_lora):
    """Generate a model response for *user_prompt*.

    Args:
        user_prompt: The question entered by the user.
        use_lora: When True, load the 'grpo_lora' adapter bundled with
            the Space before generating.

    Returns:
        A ``(response_text, language_info)`` tuple. On failure the first
        element carries an error message and the second is "์˜ค๋ฅ˜".
    """
    # Guard against an empty prompt so we do not waste a generation pass.
    if not user_prompt or not user_prompt.strip():
        return "์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”.", "์˜ค๋ฅ˜"

    if use_lora:
        try:
            # Load the LoRA adapter uploaded with the Space; the folder
            # name must match 'grpo_lora' saved by the notebook.
            # NOTE(review): model.load_lora() returns a vLLM-style LoRA
            # request object, which model.generate() cannot consume when
            # fast_inference=False — the original code stored it in an
            # unused variable. We rely on the call's in-place effect;
            # TODO confirm the adapter weights are actually applied.
            model.load_lora("grpo_lora")
        except Exception as e:
            return f"LoRA ์–ด๋Œ‘ํ„ฐ๋ฅผ ๋กœ๋“œํ•˜๋Š” ๋ฐ ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {e}\n'grpo_lora' ํด๋”๋ฅผ Space์— ์—…๋กœ๋“œํ–ˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.", "์˜ค๋ฅ˜"

    # Build the conversation in the chat-template message format.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Render the messages with the tokenizer's chat template.
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Send inputs to wherever the model actually lives instead of
    # hard-coding "cpu" — the 4-bit model may be dispatched to a GPU.
    device = next(model.parameters()).device
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Inference only: disable autograd to cut memory use and time.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded string
    # by len(input_text) is fragile: batch_decode() re-inserts special
    # tokens that the raw template string may not contain, shifting the
    # cut point and leaking prompt text into the answer.
    prompt_len = inputs["input_ids"].shape[1]
    response_only = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )

    # Detect the language of the generated response.
    lang, score = langid.classify(response_only)
    lang_info = f"๊ฐ์ง€๋œ ์–ธ์–ด: {lang} (์‹ ๋ขฐ๋„: {score:.2f})"

    return response_only, lang_info


# 4. Build the Gradio interface.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ๐Ÿ‡ฎ๐Ÿ‡ฉ DeepSeek-R1-Qwen3-8B ๋ชจ๋ธ ์ถ”๋ก  (GRPO ํŠœ๋‹)
        ์ด ๋ชจ๋ธ์€ ์ˆ˜ํ•™ ๋ฌธ์ œ์— ๋Œ€ํ•ด ์ธ๋„๋„ค์‹œ์•„์–ด๋กœ ์ถ”๋ก  ๊ณผ์ •์„ ์„ค๋ช…ํ•˜๋„๋ก ๋ฏธ์„ธ ์กฐ์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
        - **'์ธ๋„๋„ค์‹œ์•„์–ด ์ถ”๋ก  LoRA ์ ์šฉ'** ์ฒดํฌ๋ฐ•์Šค๋ฅผ ํ™œ์„ฑํ™”ํ•˜๋ฉด, ํ•™์Šต๋œ LoRA ๊ฐ€์ค‘์น˜๊ฐ€ ์ ์šฉ๋˜์–ด ์ธ๋„๋„ค์‹œ์•„์–ด๋กœ ๋œ ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•˜๋„๋ก ์œ ๋„ํ•ฉ๋‹ˆ๋‹ค.
        - ์ฒดํฌ๋ฐ•์Šค๋ฅผ ๋น„ํ™œ์„ฑํ™”ํ•˜๋ฉด ์›๋ณธ ๋ชจ๋ธ์˜ ์ถ”๋ก  ๋Šฅ๋ ฅ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
        """
    )
    
    # Left column: user inputs; right column: model outputs.
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="์งˆ๋ฌธ ์ž…๋ ฅ", 
                placeholder="์˜ˆ: Solve (x + 2)^2 = 0"
            )
            # Toggles the fine-tuned 'grpo_lora' adapter on/off.
            lora_checkbox = gr.Checkbox(
                label="์ธ๋„๋„ค์‹œ์•„์–ด ์ถ”๋ก  LoRA ์ ์šฉ", 
                value=True
            )
            submit_button = gr.Button("์ƒ์„ฑํ•˜๊ธฐ", variant="primary")
        
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="๋ชจ๋ธ ์‘๋‹ต", 
                interactive=False
            )
            # Shows the langid classification of the generated answer.
            language_info = gr.Textbox(
                label="์–ธ์–ด ๊ฐ์ง€ ๊ฒฐ๊ณผ", 
                interactive=False
            )
    
    # Wire the button to the inference function defined above.
    submit_button.click(
        fn=generate_response,
        inputs=[prompt_input, lora_checkbox],
        outputs=[output_text, language_info]
    )
    
    # Clickable example prompts; cache_examples=False avoids running
    # the (slow) model at startup to pre-compute outputs.
    gr.Examples(
        [
            ["Solve (x + 2)^2 = 0", True],
            ["What is the square root of 101?", True],
            ["In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.", True]
        ],
        inputs=[prompt_input, lora_checkbox],
        outputs=[output_text, language_info],
        fn=generate_response,
        cache_examples=False,
    )

# Launch the Gradio app.
demo.launch()