m97j commited on
Commit
0fc77f3
·
1 Parent(s): a6b8e8d

Initial Gradio app for HF Space

Browse files
Files changed (8) hide show
  1. app.py +45 -0
  2. config.py +17 -0
  3. flags.json +11 -0
  4. inference.py +59 -0
  5. model_loader.py +43 -0
  6. readme.md +257 -0
  7. requirements.txt +7 -0
  8. utils_prompt.py +81 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import run_inference, reload_model # reload_model은 모델 재로딩 함수
3
+ from utils_prompt import build_webtest_prompt
4
+
5
+ def gradio_infer(npc_id, npc_location, player_utt):
6
+ prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
7
+ result = run_inference(prompt)
8
+ return result["npc_output_text"], result["deltas"], result["flags_prob"]
9
+
10
+
11
+ # API 호출용
12
+ def api_infer(session_id, npc_id, prompt, max_tokens=200):
13
+ result = run_inference(prompt)
14
+ return {
15
+ "session_id": session_id,
16
+ "npc_id": npc_id,
17
+ "npc_response": result["npc_output_text"],
18
+ "deltas": result["deltas"],
19
+ "flags": result["flags_prob"],
20
+ "thresholds": result["flags_thr"]
21
+ }
22
+
23
+ # Colab에서 호출할 ping endpoint
24
+ def ping_reload():
25
+ reload_model(branch="latest") # latest 브랜치에서 재다운로드 & 로드
26
+ return {"status": "reloaded"}
27
+
28
+ with gr.Blocks() as demo:
29
+ gr.Markdown("## NPC Main Model Inference")
30
+
31
+ with gr.Tab("Web Test UI"):
32
+ npc_id = gr.Textbox(label="NPC ID")
33
+ npc_loc = gr.Textbox(label="NPC Location")
34
+ player_utt = gr.Textbox(label="Player Utterance")
35
+ npc_resp = gr.Textbox(label="NPC Response")
36
+ deltas = gr.JSON(label="Deltas")
37
+ flags = gr.JSON(label="Flags Probabilities")
38
+ btn = gr.Button("Run Inference")
39
+ btn.click(fn=gradio_infer, inputs=[npc_id, npc_loc, player_utt], outputs=[npc_resp, deltas, flags])
40
+
41
+ demo.add_api_route("/predict_main", api_infer, methods=["POST"], api_name="predict_main")
42
+ demo.add_api_route("/ping_reload", lambda: ping_reload(), methods=["POST"], api_name="ping_reload")
43
+
44
+ if __name__ == "__main__":
45
+ demo.launch(server_name="0.0.0.0", server_port=7860)
config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # 모델 경로
4
+ BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
5
+ ADAPTER_MODEL = "m97j/npc-LoRA-fps"
6
+
7
+ # 장치 설정
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # 토크나이저/모델 공통
11
+ MAX_LENGTH = 1024
12
+ NUM_FLAGS = 7 # flags.json 길이와 일치
13
+
14
+ # 생성 파라미터
15
+ GEN_MAX_NEW_TOKENS = 200
16
+ GEN_TEMPERATURE = 0.7
17
+ GEN_TOP_P = 0.9
flags.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ALL_FLAGS": [
3
+ "give_item",
4
+ "end_npc_main_story",
5
+ "quest_stage_change",
6
+ "change_game_state",
7
+ "change_player_state",
8
+ "npc_action",
9
+ "unlock_hidden_path"
10
+ ]
11
+ }
inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from config import DEVICE, MAX_LENGTH, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE, GEN_TOP_P
3
+ from model_loader import ModelWrapper
4
+
5
+ # 전역 로드 (서버 시작 시 1회)
6
+ wrapper = ModelWrapper()
7
+ tokenizer, model, flags_order = wrapper.get()
8
+
9
+ GEN_PARAMS = {
10
+ "max_new_tokens": GEN_MAX_NEW_TOKENS,
11
+ "temperature": GEN_TEMPERATURE,
12
+ "top_p": GEN_TOP_P,
13
+ "do_sample": True,
14
+ "repetition_penalty": 1.05,
15
+ }
16
+
17
+ def run_inference(prompt: str):
18
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH).to(DEVICE)
19
+
20
+ with torch.no_grad():
21
+ gen_ids = model.generate(**inputs, **GEN_PARAMS)
22
+ generated_text = tokenizer.decode(
23
+ gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
24
+ )
25
+
26
+ outputs = model(**inputs, output_hidden_states=True)
27
+ h = outputs.hidden_states[-1]
28
+
29
+ STATE_ID = tokenizer.convert_tokens_to_ids("<STATE>")
30
+ ids = inputs["input_ids"]
31
+ mask = (ids == STATE_ID).unsqueeze(-1)
32
+ if mask.any():
33
+ counts = mask.sum(dim=1).clamp_min(1)
34
+ pooled = (h * mask).sum(dim=1) / counts
35
+ else:
36
+ pooled = h[:, -1, :]
37
+
38
+ delta_pred = torch.tanh(model.delta_head(pooled))[0].cpu().tolist()
39
+ flag_prob = torch.sigmoid(model.flag_head(pooled))[0].cpu().tolist()
40
+ flag_thr = torch.sigmoid(model.flag_threshold_head(pooled))[0].cpu().tolist()
41
+
42
+ flags_prob_dict = {name: round(prob, 6) for name, prob in zip(flags_order, flag_prob)}
43
+ flags_thr_dict = {name: round(thr, 6) for name, thr in zip(flags_order, flag_thr)}
44
+
45
+ return {
46
+ "npc_output_text": generated_text.strip(),
47
+ "deltas": {
48
+ "trust": float(delta_pred[0]),
49
+ "relationship": float(delta_pred[1]),
50
+ },
51
+ "flags_prob": flags_prob_dict,
52
+ "flags_thr": flags_thr_dict,
53
+ }
54
+
55
+ def reload_model(branch="latest"):
56
+ global wrapper, tokenizer, model, flags_order
57
+ wrapper = ModelWrapper(branch=branch) # branch 인자로 latest 전달
58
+ tokenizer, model, flags_order = wrapper.get()
59
+ print(f"Model reloaded from branch: {branch}")
model_loader.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from peft import PeftModel
5
+ from config import BASE_MODEL, ADAPTER_MODEL, DEVICE
6
+
7
+ def get_current_branch():
8
+ if os.path.exists("current_branch.txt"):
9
+ with open("current_branch.txt", "r") as f:
10
+ return f.read().strip()
11
+ return "latest" # fallback
12
+
13
+ class ModelWrapper:
14
+ def __init__(self):
15
+ flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
16
+ self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
17
+ self.num_flags = len(self.flags_order)
18
+
19
+ self.tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL, use_fast=True)
20
+ if self.tokenizer.pad_token is None:
21
+ self.tokenizer.pad_token = self.tokenizer.eos_token
22
+ self.tokenizer.padding_side = "right"
23
+
24
+ branch = get_current_branch()
25
+ base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", trust_remote_code=True)
26
+ self.model = PeftModel.from_pretrained(base, ADAPTER_MODEL, revision=branch, device_map="auto")
27
+
28
+ hidden_size = self.model.config.hidden_size
29
+ self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
30
+ self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
31
+ self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
32
+
33
+ if os.path.exists("delta_head.pt"):
34
+ self.model.delta_head.load_state_dict(torch.load("delta_head.pt", map_location=DEVICE))
35
+ if os.path.exists("flag_head.pt"):
36
+ self.model.flag_head.load_state_dict(torch.load("flag_head.pt", map_location=DEVICE))
37
+ if os.path.exists("flag_threshold_head.pt"):
38
+ self.model.flag_threshold_head.load_state_dict(torch.load("flag_threshold_head.pt", map_location=DEVICE))
39
+
40
+ self.model.eval()
41
+
42
+ def get(self):
43
+ return self.tokenizer, self.model, self.flags_order
readme.md ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: NPC Main Model Inference Server
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ python_version: 3.10
9
+ app_file: app.py
10
+ ---
11
+
12
+ # NPC 메인 모델 추론 서버 (hf-serve)
13
+
14
+ 이 Space는 **NPC 대화 메인 모델**의 추론 API와 간단한 Gradio UI를 제공합니다.
15
+ Hugging Face Hub에 업로드된
16
+ [Base model](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct)과
17
+ [LoRA adapter model](https://huggingface.co/m97j/npc_LoRA-fps)을 로드하여,
18
+ 플레이어 발화와 게임 상태를 기반으로 NPC의 응답, 감정 변화량(delta),
19
+ 플래그 확률/임계값을 예측합니다.
20
+
21
+ ---
22
+
23
+ ## 🚀 주요 기능
24
+ - **API 엔드포인트** `/predict_main`
25
+ - JSON payload로 prompt를 받아 모델 추론 결과 반환
26
+ - **웹 UI** `/ui`
27
+ - NPC ID, 위치, 플레이어 발화를 입력해 실시간 응답 확인
28
+ - **커스텀 헤드 예측**
29
+ - `delta_head`: trust / relationship 변화량
30
+ - `flag_head`: 각 flag의 확률
31
+ - `flag_threshold_head`: 각 flag의 임계값
32
+ - **모델 실시간 업데이트**
33
+ - Colab 학습 후 `latest` 브랜치 업로드 → `/ping_reload` 호출 시 즉시 재로드
34
+
35
+ ---
36
+
37
+ ## 📂 디렉토리 구조
38
+ ```
39
+ hf-serve/
40
+ ├─ app.py # Gradio UI + API 라우팅
41
+ ├─ inference.py # 모델 추론 로직
42
+ ├─ model_loader.py # 모델/토크나이저 로드
43
+ ├─ utils_prompt.py # prompt 생성 함수
44
+ ├─ flags.json # flag index → name 매핑
45
+ ├─ requirements.txt # 의존성 패키지
46
+ └─ README.md # (현재 문서)
47
+ ```
48
+
49
+ ---
50
+
51
+ ## ⚙️ 추론 로직 개요
52
+
53
+ 이 서버의 핵심은 `run_inference()` 함수로,
54
+ NPC 메인 모델에 프롬프트를 입력하고 응답·상태 변화를 예측하는 전 과정을 담당합니다.
55
+
56
+ ### 처리 흐름
57
+ 1. **프롬프트 토크나이즈**
58
+ - 입력된 prompt를 토크나이저로 변환하여 텐서 형태로 준비
59
+ - 길이 제한(`MAX_LENGTH`)과 디바이스(`DEVICE`) 설정 적용
60
+
61
+ 2. **언어모델 응답 생성**
62
+ - 사전 정의된 추론 파라미터(`GEN_PARAMS`)로 `model.generate()` 실행
63
+ → NPC의 대사 텍스트 생성
64
+ - 생성된 토큰을 디코딩하여 최종 문자열로 변환
65
+
66
+ 3. **히든 상태 추출**
67
+ - `output_hidden_states=True`로 모델 실행
68
+ - 마지막 레이어의 hidden state를 가져옴
69
+
70
+ 4. **<STATE> 토큰 위치 풀링**
71
+ - `<STATE>` 토큰이 있는 위치의 hidden state를 평균(pooling)
72
+ → NPC 상태를 대표하는 벡터로 사용
73
+ - 없을 경우 마지막 토큰의 hidden state 사용
74
+
75
+ 5. **커스텀 헤드 예측**
76
+ - `delta_head`: trust / relationship 변화량 예측
77
+ - `flag_head`: 각 flag의 발생 확률 예측
78
+ - `flag_threshold_head`: 각 flag의 임계값 예측
79
+
80
+ 6. **index → name 매핑**
81
+ - `flags.json`의 순서(`flags_order`)를 기반으로
82
+ 예측 벡터를 `{flag_name: 값}` 형태의 딕셔너리로 변환
83
+
84
+ ### 반환 형식
85
+ ```json
86
+ {
87
+ "npc_output_text": "<NPC 응답>",
88
+ "deltas": { "trust": 0.xx, "relationship": 0.xx },
89
+ "flags_prob": { "flag_name": 확률, ... },
90
+ "flags_thr": { "flag_name": 임계값, ... }
91
+ }
92
+ ```
93
+
94
+ ---
95
+
96
+ ## 📜 Prompt 포맷
97
+ 모델은 학습 시 아래와 같은 구조의 prompt를 사용합니다.
98
+
99
+ ```
100
+ <SYS>
101
+ NPC_ID={npc_id}
102
+ NPC_LOCATION={npc_location}
103
+ TAGS:
104
+ quest_stage={quest_stage}
105
+ relationship={relationship}
106
+ trust={trust}
107
+ npc_mood={npc_mood}
108
+ player_reputation={player_reputation}
109
+ style={style}
110
+ </SYS>
111
+ <RAG>
112
+ LORE: ...
113
+ DESCRIPTION: ...
114
+ </RAG>
115
+ <PLAYER_STATE>
116
+ ...
117
+ </PLAYER_STATE>
118
+ <CTX>
119
+ ...
120
+ </CTX>
121
+ <PLAYER>...
122
+ <STATE>
123
+ <NPC>
124
+ ```
125
+ ---
126
+
127
+ ## 💡 **일반적인 LLM 추론과의 차이점**
128
+ 이 서버는 단순히 텍스트를 생성하는 것에 그치지 않고,
129
+ `<STATE>` 토큰 기반 상태 벡터를 추출하여 커스텀 헤드에서 **감정 변화량(delta)**과
130
+ **플래그 확률/임계값**을 동시에 예측합니다.
131
+ 이를 통해 대사 생성과 게임 상태 업데이트를 **한 번의 추론으로 처리**할 수 있습니다.
132
+
133
+ ---
134
+
135
+ ## 🎯 추론 파라미터
136
+
137
+ | 파라미터 | 의미 | 영향 |
138
+ |----------|------|------|
139
+ | `temperature` | 샘플링 온도 (0.0~1.0+) | 낮을수록 결정적(Deterministic), 높을수록 다양성 증가 |
140
+ | `do_sample` | 샘플링 여부 | `False`면 Greedy/Beam Search, `True`면 확률 기반 샘플링 |
141
+ | `max_new_tokens` | 새로 생성할 토큰 수 제한 | 응답 길이 제한 |
142
+ | `top_p` | nucleus sampling 확률 누적 컷오프 | 다양성 제어 (0.9면 상위 90% 확률만 사용) |
143
+ | `top_k` | 확률 상위 k개 토큰만 샘플링 | 다양성 제어 (50이면 상위 50개 후보만) |
144
+ | `repetition_penalty` | 반복 억제 계수 | 1.0보다 크면 반복 줄임 |
145
+ | `stop` / `eos_token_id` | 생성 중단 토큰 | 특정 문자열/토큰에서 멈춤 |
146
+ | `presence_penalty` / `frequency_penalty` | 특정 토큰 등장 빈도 제어 | OpenAI 계열에서 주로 사용 |
147
+ | `seed` | 난�� 시드 | 재현성 확보 |
148
+
149
+ 위 파라미터들은 **학습 시에는 사용되지 않고**,
150
+ 모델이 응답을 생성하는 **추론 시점**에만 적용됩니다.
151
+
152
+
153
+
154
+ ## 💡 사용 예시
155
+
156
+ - **결정적 분류/판정용**
157
+ (예: `_llm_trigger_check` YES/NO)
158
+ ```python
159
+ temperature = 0.0
160
+ do_sample = False
161
+ max_new_tokens = 2
162
+ ```
163
+ → 항상 같은 입력에 같은 출력, 짧고 확정적인 답변 [ai_server/의 local fallback model에 특정 조건을 지시할 때 사용]
164
+
165
+ - **자연스러운 대화/창작용**
166
+ (예: main/fallback 대사 생성)
167
+ ```python
168
+ temperature = 0.7
169
+ top_p = 0.9
170
+ do_sample = True
171
+ repetition_penalty = 1.05
172
+ max_new_tokens = 200
173
+ ```
174
+ → 다양성과 자연스러움 확보 [main model 추론시에 사용]
175
+
176
+ hf-serve에서는 자연스러운 대화/창작용의 파라미터 예를 그대로 사용했습니다.
177
+
178
+ ---
179
+
180
+ ## 🌐 API & UI 차이
181
+
182
+ | 경로 | 입력 형식 | 내부 처리 |
183
+ |------|-----------|-----------|
184
+ | `/predict_main` | 완성된 prompt 문자열 | 그대로 추론 |
185
+ | `/ui` | NPC ID, Location, Utterance | `build_webtest_prompt()`로 prompt 생성 후 추론 |
186
+
187
+ ---
188
+
189
+ ## 📌 API 사용 예시
190
+
191
+ ### 요청
192
+ ```json
193
+ POST /api/predict_main
194
+ {
195
+ "session_id": "abc123",
196
+ "npc_id": "mother_abandoned_factory",
197
+ "prompt": "<SYS>...<NPC>",
198
+ "max_tokens": 200
199
+ }
200
+ ```
201
+
202
+ ### 응답
203
+ ```json
204
+ {
205
+ "session_id": "abc123",
206
+ "npc_id": "mother_abandoned_factory",
207
+ "npc_response": "그건 정말 놀라운 이야기군요.",
208
+ "deltas": { "trust": 0.42, "relationship": -0.13 },
209
+ "flags": { "give_item": 0.87, "end_npc_main_story": 0.02 },
210
+ "thresholds": { "give_item": 0.65, "end_npc_main_story": 0.5 }
211
+ }
212
+ ```
213
+
214
+ ---
215
+
216
+ ## 🔄 모델 업데이트 흐름
217
+ 1. Colab에서 학습 완료
218
+ 2. Hugging Face Hub `latest` 브랜치에 업로드
219
+ 3. Colab에서 `/api/ping_reload` 호출
220
+ 4. Space가 최신 모델 재다운로드 & 로드
221
+
222
+ ---
223
+
224
+ ## 🛠 실행 방법
225
+
226
+ ### 로컬 실행
227
+ ```bash
228
+ git clone https://huggingface.co/spaces/m97j/PersonaChatEngine
229
+ cd PersonaChatEngine
230
+ pip install -r requirements.txt
231
+ python app.py
232
+ ```
233
+
234
+ ### Hugging Face Space에서 실행
235
+ - 웹 UI: `https://m97j-PersonaChatEngine.hf.space/ui`
236
+ - API: `POST https://m97j-PersonaChatEngine.hf.space/api/predict_main`
237
+
238
+ ---
239
+
240
+ ## 🛠 실행 환경
241
+ - Python 3.10
242
+ - FastAPI, Gradio, Transformers, PEFT, Torch
243
+ - GPU 지원 시 추론 속도 향상
244
+ ---
245
+
246
+ ## 💡 비용 최적화 팁
247
+ - Space Settings → Hardware에서 Free CPU로 전환 시 과금 없음
248
+ - GPU 사용 시 테스트 후 Stop 버튼으로 Space 중지
249
+ - 48시간 요청 없으면 자동 sleep
250
+
251
+ ---
252
+
253
+ ## 🔗 관련 리포지토리
254
+ - **전체 프로젝트 개요 & AI 서버 코드**: [GitHub - persona-chat-engine](https://github.com/m97j/persona-chat-engine)
255
+ - **모델 어댑터 파일(HF Hub)**: [Hugging Face Model Repo](https://huggingface.co/m97j/npc_LoRA-fps)
256
+
257
+ ---
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ transformers==4.43.3
3
+ torch==2.3.1
4
+ accelerate==0.33.0
5
+ peft==0.11.1
6
+ sentence-transformers==3.0.1
7
+ python-dotenv==1.0.1
utils_prompt.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+
3
+ def build_webtest_prompt(npc_id: str, npc_location: str, player_utt: str) -> str:
4
+ # 웹 테스트에서는 최소 필드만 채운 pre dict 생성
5
+ pre = {
6
+ "tags": {
7
+ "npc_id": npc_id,
8
+ "location": npc_location,
9
+ "quest_stage": "",
10
+ "relationship": "",
11
+ "trust": "",
12
+ "npc_mood": "",
13
+ "player_reputation": "",
14
+ "style": ""
15
+ },
16
+ "player_state": {},
17
+ "rag_main_docs": [], # 웹 테스트에서는 RAG 문서 없음
18
+ "context": [], # 대화 히스토리 없음
19
+ "player_utterance": player_utt
20
+ }
21
+ # session_id는 웹 테스트에서는 의미 없으니 빈 값
22
+ return build_main_prompt(pre, session_id="", npc_id=npc_id)
23
+
24
+
25
+ def build_main_prompt(pre: Dict[str, Any], session_id: str, npc_id: str) -> str:
26
+ tags = pre.get("tags", {})
27
+ ps = pre.get("player_state", {})
28
+ rag_docs = pre.get("rag_main_docs", [])
29
+
30
+ # RAG 문서 분리
31
+ lore_text = ""
32
+ desc_text = ""
33
+ for doc in rag_docs:
34
+ if "LORE:" in doc:
35
+ lore_text += doc + "\n"
36
+ elif "DESCRIPTION:" in doc:
37
+ desc_text += doc + "\n"
38
+ else:
39
+ # fallback: type 기반 분리 가능
40
+ if "lore" in doc.lower():
41
+ lore_text += doc + "\n"
42
+ elif "description" in doc.lower():
43
+ desc_text += doc + "\n"
44
+
45
+ prompt = [
46
+ "<SYS>",
47
+ f"NPC_ID={tags.get('npc_id','')}",
48
+ f"NPC_LOCATION={tags.get('location','')}",
49
+ "TAGS:",
50
+ f" quest_stage={tags.get('quest_stage','')}",
51
+ f" relationship={tags.get('relationship','')}",
52
+ f" trust={tags.get('trust','')}",
53
+ f" npc_mood={tags.get('npc_mood','')}",
54
+ f" player_reputation={tags.get('player_reputation','')}",
55
+ f" style={tags.get('style','')}",
56
+ "</SYS>",
57
+ "<RAG>",
58
+ f"LORE: {lore_text.strip() or '(없음)'}",
59
+ f"DESCRIPTION: {desc_text.strip() or '(없음)'}",
60
+ "</RAG>",
61
+ "<PLAYER_STATE>"
62
+ ]
63
+
64
+ if ps.get("items"):
65
+ prompt.append(f"items={','.join(ps['items'])}")
66
+ if ps.get("actions"):
67
+ prompt.append(f"actions={','.join(ps['actions'])}")
68
+ if ps.get("position"):
69
+ prompt.append(f"position={ps['position']}")
70
+ prompt.append("</PLAYER_STATE>")
71
+
72
+ prompt.append("<CTX>")
73
+ for h in pre.get("context", []):
74
+ prompt.append(f"{h['role']}: {h['text']}")
75
+ prompt.append("</CTX>")
76
+
77
+ prompt.append(f"<PLAYER>{pre.get('player_utterance','').rstrip()}")
78
+ prompt.append("<STATE>")
79
+ prompt.append("<NPC>")
80
+
81
+ return "\n".join(prompt)