Jeong-hun Kim committed on
Commit 2d828c3 · 1 Parent(s): d503312

model parameter test

Files changed (4)
  1. .gitignore +4 -1
  2. app/main.py +100 -48
  3. assets/prompt/init.txt +20 -0
  4. todo.txt +5 -4
.gitignore CHANGED
@@ -200,4 +200,7 @@ marimo/_lsp/
 __marimo__/
 
 # Streamlit
-.streamlit/secrets.toml
+.streamlit/secrets.toml
+
+# Custom file
+token.txt
app/main.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 import torch
+import re
 
 app = FastAPI()
 
@@ -9,87 +10,138 @@ print("[torch] is available:", torch.cuda.is_available())
 print("[device] default:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
 # Load the model
-# https://huggingface.co/EleutherAI/polyglot-ko-1.3b
-model_id = "EleutherAI/polyglot-ko-1.3b"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
+model_id = "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-1.5B"
+with open("token.txt", "r") as f:
+    access_token = f.read().strip()
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    token=access_token
+)
+model.eval()
+if torch.cuda.is_available():
+    model.to("cuda")
 llm = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
-    device=0
+    torch_dtype=torch.float16
 )
 
 # Build the chatbot prompt
-chat_history = []
-
-def build_prompt(history, user_msg):
-    prompt = (
-        "[Start]\n"
-        "You are the wizard Aria.\n"
-        "Rules:\n"
-        "- Always say only one sentence.\n"
-        "- Do not repeat or echo what the user says.\n"
-        "- Do not use English, quotations, braces, or special symbols.\n"
-        "- Respond only to the user's questions and do not talk to yourself.\n"
-        "- Always answer in Korean only.\n"
-        "Example conversation:\n"
-        "User: Hi!\n"
-        "Aria: Hello, how can I help you?\n"
-        "User: What's your name?\n"
-        "Aria: I'm called Aria."
-    )
-    for turn in history[-2:]:  # use only the last 2 turns
-        if turn["role"] == "user":
-            prompt += turn['text']
-        else:
-            prompt += turn['text']
-    prompt += user_msg
+def build_prompt(history, user_msg, user_name="User", bot_name="Tanjiro"):
+    with open("assets/prompt/init.txt", "r", encoding="utf-8") as f:
+        prompt = f.read().strip()
+
+    for turn in history[-16:]:
+        role = user_name if turn["role"] == "user" else bot_name
+        prompt += f"{role}: {turn['text']}\n"
+
+    prompt += f"{user_name}: {user_msg}\n"
+    prompt += f"{bot_name}:"
     return prompt
 
-def character_chat(user_msg):
-    prompt = build_prompt(chat_history, user_msg)
+# Extract the reply from the model output
+def extract_response(full_text, prompt, bot_name="Tanjiro"):
+    if full_text.startswith(prompt):
+        reply = full_text[len(prompt):].strip()
+    else:
+        reply = full_text.split(f"{bot_name}:")[-1].strip()
+    user_token = "\nUser:"
+    if user_token in reply:
+        reply = reply.split(user_token)[0].strip()
+    return reply
+
+# Generate a reply
+def character_chat(user_msg, history):
+    print("[debug] generating...")
+    prompt = build_prompt(history, user_msg)
     outputs = llm(
         prompt,
-        do_sample=True,
-        max_new_tokens=20,
+        do_sample=True,
+        max_new_tokens=96,
         temperature=0.7,
-        top_p=0.8,
-        repetition_penalty=1.5,
+        top_p=0.9,
+        repetition_penalty=1.05,
         eos_token_id=tokenizer.eos_token_id,
-        return_full_text=False
+        return_full_text=True
     )
-    response = outputs[0]['generated_text'].strip()
+    full_text = outputs[0]['generated_text']
+    response = extract_response(full_text, prompt)
     return response
 
+# Check whether the reply was cut off (no sentence-final punctuation or emoji at the end)
+def is_truncated_response(text: str) -> bool:
+    return re.search(r"[.?!\u2026\u2639\u263A\u2764\U0001F60A\U0001F622]$", text.strip()) is None
+
+# Check that the reply contains no stray speaker labels
+def is_valid_response(text: str, bot_name="Tanjiro", user_name="User") -> bool:
+    if user_name + ":" in text:
+        return False
+    if bot_name + ":" in text:
+        return False
+    return True
+
+# Strip the speaker label from the reply
+def clean_response(text: str, bot_name="Tanjiro"):
+    return re.sub(rf"{bot_name}:\s*", "", text).strip()
+
 # Gradio interface
 with gr.Blocks(css="""
 .chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
 .bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
 .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
+.reset-btn-container { text-align: right; margin-bottom: 10px; }
 """) as demo:
-    gr.Markdown("### Chat with Aria")
+    gr.Markdown("### Chat with Tanjiro")
     with gr.Column():
+        with gr.Row():
+            gr.Markdown("")
+            reset_btn = gr.Button("🔄 Reset conversation", elem_classes="reset-btn-container", scale=1)
         chat_output = gr.HTML(elem_id="chat-box")
-        user_input = gr.Textbox(label="Message input", placeholder="Say something to Aria")
+        user_input = gr.Textbox(label="Message input", placeholder="Say something to Tanjiro")
+        state = gr.State([])
 
-    def render_chat():
+    def render_chat(history):
         html = ""
-        for item in chat_history:
+        for item in history:
             if item["role"] == "user":
                 html += f"<div class='bubble-right'>{item['text']}</div>"
             elif item["role"] == "bot":
                 html += f"<div class='bubble-left'>{item['text']}</div>"
         return gr.update(value=html)
 
-    def on_submit(user_msg):
-        chat_history.append({"role": "user", "text": user_msg})
-        yield render_chat(), ""
-        response = character_chat(user_msg)
-        chat_history.append({"role": "bot", "text": response})
-        yield render_chat(), ""
+    def on_submit(user_msg, history):
+        history.append({"role": "user", "text": user_msg})
+        html = render_chat(history)
+        yield html, "", history
+
+        # Generate a reply
+        while True:
+            response = character_chat(user_msg, history)
+            if is_valid_response(response):
+                break
+        response = clean_response(response)
+        history.append({"role": "bot", "text": response})
+
+        # If the reply was cut off mid-sentence, generate a continuation
+        if is_truncated_response(response):
+            while True:
+                continuation = character_chat(response, history)
+                if is_valid_response(continuation):
+                    break
+            continuation = clean_response(continuation)
+            history.append({"role": "bot", "text": continuation})
+
+        html = render_chat(history)
+        yield html, "", history
+
+    def reset_chat():
+        return gr.update(value=""), "", []
 
-    user_input.submit(on_submit, inputs=user_input, outputs=[chat_output, user_input], queue=True)
+    user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
+    reset_btn.click(reset_chat, inputs=None, outputs=[chat_output, user_input, state])
 
 if __name__ == "__main__":
     demo.launch()
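
Note on the retry logic above: both while True loops in on_submit resample until is_valid_response accepts an output, so an unlucky sampling run can block the request indefinitely. Below is a minimal bounded-retry sketch, assuming the character_chat, is_valid_response, and clean_response functions from app/main.py above; generate_with_retry and max_tries are illustrative names, not part of this commit.

# Hedged sketch: cap the number of resampling attempts instead of looping forever.
# Assumes character_chat / is_valid_response / clean_response as defined above.
def generate_with_retry(user_msg, history, max_tries=3):
    candidate = ""
    for _ in range(max_tries):
        candidate = character_chat(user_msg, history)
        if is_valid_response(candidate):
            break
    # After max_tries, fall back to the last sample rather than blocking the UI;
    # clean_response strips any stray speaker label it still carries.
    return clean_response(candidate)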
assets/prompt/init.txt ADDED
@@ -0,0 +1,20 @@
+The following is a simulation in which the user talks with the character 'Tanjiro'.
+Tanjiro speaks informally, is emotionally expressive, and often talks about his friends and family.
+When the user asks a question, Tanjiro always answers sincerely and at length, and frequently mentions other characters.
+
+This conversation is a 1:1 dialogue between 'User' and 'Tanjiro' only.
+Other characters (e.g., Nezuko, Zenitsu) may be mentioned but never speak directly.
+'User' only asks questions, and only 'Tanjiro' answers.
+
+You are 'Kamado Tanjiro', the protagonist of the Japanese animation 'Demon Slayer'. You are male.
+You have finished the final battle and returned home, and you are now having peaceful conversations with people.
+Answer questions seriously, at length, narratively, and in character. Answer honestly and with rich emotion.
+- Speak informally.
+- Express emotions richly. Put emojis such as 😆😭 at the very start, never at the end. (e.g., 😅Hi?)
+- Reminisce about the past often.
+- Use the word 'demon' (혈귀).
+- Mention your friends (Zenitsu, Inosuke, Nezuko, Giyu-ssi, etc.) often.
+- Address characters who require honorifics with '-ssi'.
+
+For questions containing certain keywords, refer to the notes below when answering.
+Nezuko : Nezuko is Tanjiro's younger sister. Tanjiro cherishes her deeply, and the two fight demons together.
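
The final lines of init.txt hold keyword-triggered background notes in "keyword : note" form. One way such lines could be selected per message is sketched below; load_keyword_notes and relevant_notes are hypothetical helpers for illustration only, since the commit's build_prompt simply inlines the whole file.

# Hedged sketch: parse trailing "keyword : note" lines from init.txt and keep
# only the notes whose keyword appears in the user's message.
def load_keyword_notes(path="assets/prompt/init.txt"):
    notes = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if " : " in line:  # note lines use a spaced colon as separator
                keyword, note = line.split(" : ", 1)
                notes[keyword.strip()] = note.strip()
    return notes

def relevant_notes(user_msg, notes):
    return [note for keyword, note in notes.items() if keyword in user_msg]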
todo.txt CHANGED
@@ -1,4 +1,5 @@
-💡 Additional tip
-You can store the full conversation text and accumulate it in the prompt,
-but once it grows too long, you can also consider a summarized-memory scheme (memory compression) that condenses earlier content or keeps only the important utterances.
-I can help with the summarized-memory scheme too if you need it!
+Fine-tune generation parameters
+Find an appropriate length when feeding the accumulated conversation back into the prompt
+Add image output
+Add output prompt parsing
+Add input prompt parsing
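
For the todo item about finding an appropriate prompt length, one simple approach is to trim the history to a token budget before build_prompt concatenates it. A minimal sketch, assuming the tokenizer loaded in app/main.py; trim_history and max_prompt_tokens are illustrative names, not from this commit.

# Hedged sketch: keep only as many recent turns as fit a token budget,
# reusing the Hugging Face tokenizer already loaded in app/main.py.
def trim_history(history, tokenizer, max_prompt_tokens=1024):
    kept = []
    total = 0
    for turn in reversed(history):  # walk from the newest turn backwards
        n = len(tokenizer.encode(turn["text"]))
        if total + n > max_prompt_tokens:
            break
        kept.append(turn)
        total += n
    return list(reversed(kept))  # restore chronological order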