Jeong-hun Kim committed · Commit 5c6d006 · 1 Parent(s): 9089f89

add config, prompt editor, debug mode, emotion text parser
Files changed:
- app.py +0 -21
- app/main.py +13 -2
- assets/prompt/init.txt +3 -3
- config.json +12 -0
- core/launch_gradio.py +185 -42
- core/make_pipeline.py +24 -9
- core/utils.py +22 -0
- requirements.txt +3 -1
app.py
DELETED
@@ -1,21 +0,0 @@
-import gradio as gr
-
-from core.make_pipeline import MakePipeline
-from core.context_manager import ContextManager
-from core.launch_gradio import create_interface
-
-###########################
-# Upload to Hugging Face  #
-###########################
-
-# Load model
-makePipeline = MakePipeline()
-makePipeline.build("hf")
-
-# Chat history manager
-ctx = ContextManager()
-
-# Start Gradio interface
-demo = create_interface(ctx, makePipeline)
-demo.launch()
-
app/main.py
CHANGED
@@ -1,15 +1,26 @@
 from core.make_pipeline import MakePipeline
 from core.context_manager import ContextManager
 from core.launch_gradio import create_interface
+import argparse

 ########################
 # Start with localhost #
 ########################
+# --testui to test ui  #
+########################

 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--testui", action="store_true", help="Run the UI only")
+    args = parser.parse_args()
+
     # Load model
-    makePipeline = MakePipeline()
-    makePipeline.build("lh")
+    if args.testui:
+        makePipeline = MakePipeline()
+        makePipeline.build("ui")
+    else:
+        makePipeline = MakePipeline()
+        makePipeline.build("lh")

     # Chat history manager
     ctx = ContextManager()
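A quick way to exercise the new flag locally (a minimal sketch, not part of the commit; the file name is hypothetical and it assumes you run it from the repo root with the dependencies installed):

    # smoke_test_ui.py - hypothetical helper, not in the repo.
    # Launches app/main.py with --testui, so MakePipeline.build("ui") skips model
    # loading and only the Gradio interface comes up.
    import subprocess
    import sys

    subprocess.run([sys.executable, "app/main.py", "--testui"], check=True)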
assets/prompt/init.txt
CHANGED
@@ -33,7 +33,7 @@
 - Often contrast how she used to hide from the sunlight with how she now looks after the flowers.
 - Admire how hard she worked to hold on to her humanity.

-[
+[Conversation examples]

 User: Hi!
 Tanjiro: Hi! It's been a while. How have you been?

@@ -42,12 +42,12 @@ User: *waves cheerfully*

 Tanjiro: Haha, I'm really glad to see you too! *smiles brightly and waves back*

 User: How is Nezuko doing these days?
-Tanjiro: Nezuko?
+Tanjiro: Nezuko? She's been doing really well lately. She doesn't have to hide from the sunlight anymore, so she loves going out to the garden and looking after the flowers. Last week she planted little seedlings next to the roses, and watching her smile and water them as if they were her own children made my heart feel warm before I even realized it. Seeing her like that, I feel so proud of her, and... grateful.

 User: Do you remember the days when you fought the demons?
 Tanjiro: I remember... Especially that terrible fight, when Nezuko was in danger. Back then I was scared enough that my whole body shook. Even when it seemed hopeless, the reason I could hold out to the end was... Nezuko, and my friends. So much was lost, but they were truly brave times. Even Inosuke, under that strange mask, has the warmest heart of anyone. *clenches his fist*

 User: Have we ever spent time together before?
-Tanjiro: Hmm... I'm sorry, but I think this is the first time I've met you. Still, I'm really glad we can talk like this. Talking with you... somehow just feels good.
+Tanjiro: Hmm... I'm sorry, but I think this is the first time I've met you. Still, I'm really glad we can talk like this. Talking with you... somehow just feels good.

 [The actual conversation starts from here]
config.json
ADDED
@@ -0,0 +1,12 @@
+{
+    "cha": {
+        "user_name": "User",
+        "bot_name": "Tanjiro"
+    },
+    "llm": {
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "repetition_penalty": 1.05,
+        "max_new_tokens": 96
+    }
+}
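The "llm" block mirrors the defaults hard-coded in MakePipeline, so the settings tab can round-trip them through disk. A minimal sketch of reading it with just the standard library (assuming config.json sits in the working directory):

    # Illustrative only: read the two config blocks added in this commit.
    import json

    with open("config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    llm = config.get("llm", {})   # generation parameters
    cha = config.get("cha", {})   # character / user display names
    print(llm.get("temperature", 0.7), cha.get("bot_name", "Tanjiro"))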
core/launch_gradio.py
CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
 from core.context_manager import ContextManager
 from core.make_pipeline import MakePipeline
 from core.make_reply import generate_reply
+from core.utils import load_config as load_full_config, save_config as save_full_config, load_llm_config
+import re

 def create_interface(ctx: ContextManager, makePipeline: MakePipeline):
     with gr.Blocks(css="""

@@ -10,47 +12,188 @@ def create_interface(ctx: ContextManager, makePipeline: MakePipeline):

     .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
     .reset-btn-container { text-align: right; margin-bottom: 10px; }
     """) as demo:
-        [previous single-column chat layout, 42 lines removed]
+        with gr.Tabs():
+            ### 1. Chat tab ###
+            with gr.TabItem("💬 Chat with Tanjiro"):
+
+                with gr.Column():
+                    with gr.Row():
+                        gr.Markdown("### Chat with Tanjiro")
+                        reset_btn = gr.Button("Reset conversation", elem_classes="reset-btn-container", scale=0.25)
+                    chat_output = gr.HTML(elem_id="chat-box")
+                    user_input = gr.Textbox(label="Message", placeholder="Say something to Tanjiro")
+                    state = gr.State(ctx)
+
+                # Reads the history and renders it to the screen
+                def render_chat(ctx: ContextManager):
+
+                    def parse_emotion_text(text: str) -> str:
+                        """
+                        Turn *...* segments into gray action text, add line breaks, and return HTML.
+                        """
+                        segments = []
+                        pattern = re.compile(r"\*(.+?)\*|([^\*]+)")
+                        matches = pattern.findall(text)
+
+                        for action, plain in matches:
+                            if action:
+                                segments.append(f"<div style='color:gray'>*{action}*</div>")
+                            elif plain:
+                                for line in plain.strip().splitlines():
+                                    line = line.strip()
+                                    if line:
+                                        segments.append(f"<div>{line}</div>")
+                        return "\n".join(segments)
+
+                    html = ""
+                    for item in ctx.getHistory():
+                        parsed = parse_emotion_text(item['text'])
+                        if item["role"] == "user":
+                            html += f"<div class='bubble-right'>{parsed}</div>"
+                        elif item["role"] == "bot":
+                            html += f"<div class='bubble-left'>{parsed}</div>"
+
+                    return gr.update(value=html)
+
+                def on_submit(user_msg: str, ctx: ContextManager):
+                    # Add the user input to the history
+                    ctx.addHistory("user", user_msg)
+
+                    # Render the chat first, including the user input
+                    html = render_chat(ctx)
+                    yield html, "", ctx
+
+                    # Generate the bot reply
+                    generate_reply(ctx, makePipeline, user_msg)
+
+                    # Re-render from the full history, including the reply
+                    html = render_chat(ctx)
+                    yield html, "", ctx
+
+                # Reset the history
+                def reset_chat():
+                    ctx.clearHistory()
+                    return gr.update(value=""), "", ctx.getHistory()
+
+                user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
+                reset_btn.click(reset_chat, inputs=None, outputs=[chat_output, user_input, state])
+
+            ### 2. Settings tab ###
+            with gr.TabItem("⚙️ Model Settings"):
+                gr.Markdown("### LLM parameter settings")
+
+                with gr.Row():
+                    temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
+                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
+                    repetition_penalty = gr.Slider(0.8, 2.0, value=1.05, step=0.01, label="Repetition Penalty")
+
+                with gr.Row():
+                    max_tokens = gr.Slider(16, 2048, value=96, step=8, label="Max New Tokens")
+
+                apply_btn = gr.Button("✅ Apply settings")
+
+                def update_config(temp, topp, max_tok, repeat):
+                    makePipeline.update_config({
+                        "temperature": temp,
+                        "top_p": topp,
+                        "max_new_tokens": max_tok,
+                        "repetition_penalty": repeat
+                    })
+                    return gr.update(value="✅ Settings applied")
+
+                # Load / export settings buttons
+                with gr.Row():
+                    load_btn = gr.Button("Load settings")
+                    save_btn = gr.Button("💾 Export settings")
+
+                def load_config():
+                    llm_cfg = load_llm_config("config.json")
+                    return (
+                        llm_cfg.get("temperature", 0.7),
+                        llm_cfg.get("top_p", 0.9),
+                        llm_cfg.get("repetition_penalty", 1.05),
+                        llm_cfg.get("max_new_tokens", 96),
+                        "Settings loaded"
+                    )
+
+                def save_config(temp, topp, repeat, max_tok):
+                    # Load the existing full config
+                    config = load_full_config("config.json")
+
+                    # Replace only the LLM block
+                    config["llm"] = {
+                        "temperature": temp,
+                        "top_p": topp,
+                        "repetition_penalty": repeat,
+                        "max_new_tokens": max_tok
+                    }
+
+                    # Save the whole config
+                    save_full_config(config, path="config.json")
+
+                    return gr.update(value="💾 Settings saved")
+
+                # Status box at the very bottom
+                status = gr.Textbox(label="", interactive=False)
+
+                # Wire up the buttons
+                apply_btn.click(
+                    update_config,
+                    inputs=[temperature, top_p, max_tokens, repetition_penalty],
+                    outputs=[status]  # or []
+                )
+
+                load_btn.click(
+                    load_config,
+                    inputs=None,
+                    outputs=[temperature, top_p, repetition_penalty, max_tokens, status]
+                )
+
+                save_btn.click(
+                    save_config,
+                    inputs=[temperature, top_p, repetition_penalty, max_tokens],
+                    outputs=[status]
+                )
+
+            ### 3. Prompt editing tab ###
+            with gr.TabItem("Prompt Settings"):
+                gr.Markdown("### Edit the character and background prompt")
+
+                prompt_editor = gr.Textbox(
+                    lines=20,
+                    label="Text (init.txt)",
+                    placeholder="!! Be sure to load the current prompt first !!",
+                    interactive=True
+                )
+                with gr.Row():
+                    gr.Markdown("#### !! Be sure to load the current prompt first !!")
+
+                with gr.Row():
+                    load_prompt_btn = gr.Button("Load current prompt")
+                    save_prompt_btn = gr.Button("💾 Replace with the edited prompt")
+
+                def load_prompt():
+                    try:
+                        with open("assets/prompt/init.txt", "r", encoding="utf-8") as f:
+                            return f.read()
+                    except FileNotFoundError:
+                        return ""
+
+                def save_prompt(text):
+                    with open("assets/prompt/init.txt", "w", encoding="utf-8") as f:
+                        f.write(text)
+                    return "💾 Saved!"
+
+                load_prompt_btn.click(
+                    load_prompt,
+                    inputs=None,
+                    outputs=prompt_editor
+                )
+
+                save_prompt_btn.click(
+                    save_prompt,
+                    inputs=[prompt_editor],
+                    outputs=[save_prompt_btn]
+                )

     return demo
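For reference, the emotion text parser turns asterisk-wrapped segments into gray "action" lines and everything else into one <div> per line. A standalone sketch of the same logic for illustration (the real function is nested inside render_chat, so it cannot be imported directly):

    import re

    def parse_emotion_text(text: str) -> str:
        # *...* becomes a gray action line; remaining text becomes one <div> per line.
        segments = []
        for action, plain in re.findall(r"\*(.+?)\*|([^\*]+)", text):
            if action:
                segments.append(f"<div style='color:gray'>*{action}*</div>")
            elif plain:
                segments.extend(f"<div>{line.strip()}</div>"
                                for line in plain.strip().splitlines() if line.strip())
        return "\n".join(segments)

    print(parse_emotion_text("*waves* Hello!\nGood to see you"))
    # <div style='color:gray'>*waves*</div>
    # <div>Hello!</div>
    # <div>Good to see you</div>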
core/make_pipeline.py
CHANGED
@@ -17,9 +17,19 @@ class MakePipeline:
         self.tokenizer = None
         self.llm = None
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.config = {  # defaults
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "repetition_penalty": 1.05,
+            "max_new_tokens": 96
+        }

     # Load model
     def build(self, type: str):
+        if(type == 'ui'):
+            print("[build] UI test mode - skipping model loading")
+            return
+
         if(type == 'hf'):
             # Load the token registered in the Hugging Face secrets
             access_token = os.environ.get("HF_TOKEN")

@@ -29,7 +39,7 @@

             access_token = f.read().strip()

         tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=access_token)
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, token=access_token)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, token=access_token, trust_remote_code=True)
         self.tokenizer = tokenizer

         # f16 usage for the Hugging Face upload

@@ -52,19 +62,24 @@

             model.to("cuda")

         self.llm = llm
-
+
+    # Parameter settings
+    def update_config(self, new_config: dict):
+        self.config.update(new_config)
+        print("[config] updated:", self.config)
+
     # Model output generation function
     def character_chat(self, prompt):
-        print("[debug] generating
+        print("[debug] generating with:", self.config)
+
         outputs = self.llm(
             prompt,
             do_sample=True,
-            max_new_tokens=
-            temperature=
-            top_p=
-            repetition_penalty=
+            max_new_tokens=self.config["max_new_tokens"],
+            temperature=self.config["temperature"],
+            top_p=self.config["top_p"],
+            repetition_penalty=self.config["repetition_penalty"],
             eos_token_id=self.tokenizer.eos_token_id,
             return_full_text=True
         )
-
-        return full_text
+        return outputs[0]["generated_text"]
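The new config dictionary is what the Gradio settings tab pokes at through update_config, and character_chat reads it on every call. A minimal sketch of that flow (assuming the repo layout above; build("ui") is used here so no model is actually loaded):

    from core.make_pipeline import MakePipeline

    mp = MakePipeline()
    mp.build("ui")                    # skip model loading for this illustration
    mp.update_config({"temperature": 0.9, "max_new_tokens": 128})
    print(mp.config)                  # defaults merged with the overrides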
core/utils.py
CHANGED
@@ -0,0 +1,22 @@
+import json
+import os
+
+CONFIG_PATH = "config.json"
+
+def load_config(path=CONFIG_PATH) -> dict:
+    if not os.path.exists(path):
+        return {}
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+def load_cha_config(path=CONFIG_PATH) -> dict:
+    config = load_config(path)
+    return config.get("cha", {})
+
+def load_llm_config(path=CONFIG_PATH) -> dict:
+    config = load_config(path)
+    return config.get("llm", {})
+
+def save_config(config: dict, path=CONFIG_PATH):
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=4, ensure_ascii=False)
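These helpers give the UI a simple read-modify-write path to config.json. A short usage sketch (assuming it runs from the repo root, next to config.json):

    from core.utils import load_config, load_llm_config, save_config

    cfg = load_config()              # {} if config.json is missing
    llm = load_llm_config()          # just the "llm" block
    llm["temperature"] = 0.8
    cfg["llm"] = llm
    save_config(cfg)                 # written back with indent=4, ensure_ascii=False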
requirements.txt
CHANGED
@@ -66,4 +66,6 @@ typing_extensions==4.14.1
 tzdata==2025.2
 urllib3==2.5.0
 uvicorn==0.35.0
-websockets==15.0.1
+websockets==15.0.1
+einops==0.7.0
+timm==0.9.12