Upload train.py
train.py
ADDED
@@ -0,0 +1,348 @@
# train.py
# Your dataset.py should provide a create_huggingface_dataset() function, imported like this:
# from dataset import create_huggingface_dataset
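#
# Illustrative sketch of the interface this script assumes (the exact return format is an
# assumption, inferred from the collator below): create_huggingface_dataset(dataset_path,
# tokenizer, max_length) should return a datasets.Dataset whose examples look roughly like
#
#   {
#       "input_ids":       [...],   # unpadded token ids, length <= max_length
#       "attention_mask":  [...],   # 1 per real token, same length as input_ids
#       "emotion_vectors": [...],   # one EMOTION_DIMENSIONS-float vector per token
#   }
#
# The data collator below pads all three to max_length and builds the masked labels.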
import os
import torch
from transformers import AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Dict
import json  # Import json for dataset loading example

# Import the EmotionalLlamaModel and constants from the emotional_gemma.py file
from emotional_gemma import EmotionalLlamaModel, EMOTION_DIMENSIONS, EMOTION_DIMENSIONS_REFERENCE, MODEL_NAME


# Define the DataCollator for handling padding and adding emotion vectors
@dataclass
class DataCollatorForEmotionalLlama:
    tokenizer: AutoTokenizer
    max_length: int
    emotion_dim: int = EMOTION_DIMENSIONS  # Use the constant from emotional_gemma

    def __call__(self, examples: list) -> Dict[str, torch.Tensor]:
        # Separate the components from the examples
        input_ids_list = [example.get("input_ids", []) for example in examples]
        attention_mask_list = [example.get("attention_mask", []) for example in examples]
        emotion_vectors_list = [example.get("emotion_vectors", []) for example in examples]

        # --- Find the token IDs for the start of the model's turn ---
        # This is used to mask out user input and padding from the labels
        # Ensure your tokenizer and dataset preparation consistently include this sequence.
        try:
            # Tokenize the specific sequence marking the model's turn start.
            # add_special_tokens=False is crucial here to get just the tokens for the string.
            model_prompt_tokens = self.tokenizer(
                "<start_of_turn>model\n",
                add_special_tokens=False
            ).input_ids
            if not model_prompt_tokens:
                raise ValueError("Tokenizer produced empty list for model prompt sequence.")
            # print(f"DEBUG: Detected model prompt start tokens: {model_prompt_tokens} (decoded: '{self.tokenizer.decode(model_prompt_tokens)}')")
        except Exception as e:
            print(f"ERROR: Could not tokenize model prompt '<start_of_turn>model\\n'. Check tokenizer and template format. Error: {e}")
            raise ValueError("Cannot proceed without identifying model start tokens for label masking.") from e

        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        batch_emotion_vectors = []

        # Process each example in the batch
        for i in range(len(input_ids_list)):
            input_ids = input_ids_list[i]
            attention_mask = attention_mask_list[i]
            emotion_vectors = emotion_vectors_list[i]

            # --- Padding ---
            seq_len = len(input_ids)
            pad_len = self.max_length - seq_len

            # Truncate if sequence is longer than max_length (should ideally be handled in dataset)
            if pad_len < 0:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]
                emotion_vectors = emotion_vectors[:self.max_length]
                seq_len = self.max_length
                pad_len = 0  # Recalculate pad_len after truncation

            # Pad input IDs, attention mask, and emotion vectors
            padded_input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
            padded_attention_mask = attention_mask + [0] * pad_len
            # Pad emotion vectors with zero vectors
            padded_emotion_vectors = emotion_vectors + [[0.0] * self.emotion_dim] * pad_len

            # --- Create Labels and Mask User/Padding Tokens ---
            labels = list(padded_input_ids)  # Start with a copy of input_ids for labels

            # Find the start index of the model's response to mask previous tokens
            model_start_idx = -1
            # Search for the model prompt token sequence within the original input_ids
            for k in range(seq_len - len(model_prompt_tokens) + 1):
                if input_ids[k : k + len(model_prompt_tokens)] == model_prompt_tokens:
                    model_start_idx = k
                    break

            if model_start_idx != -1:
                # Mask everything before and including the model's prompt sequence
                for j in range(model_start_idx + len(model_prompt_tokens)):
                    labels[j] = -100
            else:
                print(f"Warning: Model prompt sequence not found in sample {i}. Masking all labels.")
                labels = [-100] * self.max_length  # Mask everything

            # Mask padding tokens regardless of model prompt finding
            for j in range(seq_len, self.max_length):  # Only mask the padded part
                labels[j] = -100

            # Sanity check: ensure all lists have the correct length
            if len(padded_input_ids) != self.max_length or \
               len(padded_attention_mask) != self.max_length or \
               len(labels) != self.max_length or \
               len(padded_emotion_vectors) != self.max_length:
                raise ValueError(f"Length mismatch in collator for sample {i} after padding/truncation!")

            batch_input_ids.append(padded_input_ids)
            batch_attention_mask.append(padded_attention_mask)
            batch_labels.append(labels)
            batch_emotion_vectors.append(padded_emotion_vectors)

        # Convert lists to tensors
        batch = {
            "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
            "emotion_vector": torch.tensor(batch_emotion_vectors, dtype=torch.float),
        }
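        # Shape note (illustrative): for a batch of B examples, "input_ids", "attention_mask"
        # and "labels" come out as (B, max_length) long tensors, and "emotion_vector" as a
        # (B, max_length, emotion_dim) float tensor aligned token-by-token with "input_ids".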

        return batch


# Subclass Trainer to potentially customize dataloader behavior
class CustomTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        """
        Overrides the method to explicitly use the provided data collator.
        This is mostly for clarity or if the default Trainer behavior needs bypassing.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # Use the data_collator provided during Trainer initialization
        data_collator = self.data_collator

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,  # Important for training
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )


def train_emotional_llama(
    model_name=MODEL_NAME,  # Use the default model name from emotional_gemma.py
    dataset_path="./dataset.json",  # Path to your dataset file
    output_dir="./emotional-gemma-output",  # Directory to save results
    max_length=128,  # Max sequence length for training
    learning_rate=1e-4,  # Base learning rate for LoRA
    emotion_proj_lr=2e-3,  # Higher learning rate for the emotion projection layer
    num_train_epochs=2,
    per_device_batch_size=12,
    gradient_accumulation_steps=1,
    use_lora=True  # Whether to use LoRA
):
    """
    Sets up and runs the training for the EmotionalLlamaModel.
    """
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Set pad_token to eos_token for Gemma if not already set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Set padding side to right for causal models
    tokenizer.padding_side = "right"

    print(f"Loading base model: {model_name}")
    # Load the custom EmotionalLlamaModel
    model = EmotionalLlamaModel.from_pretrained(model_name)

    if use_lora:
        print("Applying LoRA configuration")
        # Define LoRA configuration
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=32,  # LoRA rank
            lora_alpha=32,  # LoRA scaling factor
            # lora_dropout=0.05,  # Dropout for LoRA layers (optional)
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]  # Modules to apply LoRA to
        )
        # Get the PEFT model by wrapping the base model
        model = get_peft_model(model, peft_config)
        # Print trainable parameters summary
        model.print_trainable_parameters()

    # Ensure the emotion projection layer is trainable.
    # This is necessary when LoRA is applied, since get_peft_model freezes all non-LoRA parameters by default.
    print("Setting emotion_proj_embed requires_grad=True")
    for param in model.emotion_proj_embed.parameters():
        param.requires_grad = True

    # --- Load and Prepare Dataset ---
    print(f"Loading dataset from: {dataset_path}")
    # Import and use your dataset creation function
    try:
        from dataset import create_huggingface_dataset
        dataset = create_huggingface_dataset(dataset_path, tokenizer, max_length)
        print(f"Dataset loaded with {len(dataset)} examples.")
    except ImportError:
        print("Error: Could not import 'create_huggingface_dataset' from dataset.py")
        print("Please ensure dataset.py exists and contains the necessary function.")
        print("Example dummy dataset creation:")
        # --- Placeholder: dummy dataset creation, used only if dataset.py is not available ---
        # Replace this section with your actual dataset loading and processing logic.
        dummy_data = [
            {"text": "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\nHi there!", "emotion_vectors": [[0.1] * EMOTION_DIMENSIONS] * 20},
            {"text": "<start_of_turn>user\nHow are you?<end_of_turn>\n<start_of_turn>model\nI'm feeling good today.", "emotion_vectors": [[0.8] * EMOTION_DIMENSIONS] * 25},
        ]

        def dummy_process(example):
            # Simple tokenization for dummy data. Padding is left to the data collator so that
            # pad positions are masked out of the labels correctly.
            tokenized = tokenizer(example["text"], truncation=True, max_length=max_length)
            # Truncate/pad the emotion vectors to match the tokenized length
            num_tokens = len(tokenized["input_ids"])
            vectors = example["emotion_vectors"][:num_tokens]
            if len(vectors) < num_tokens:
                vectors += [[0.0] * EMOTION_DIMENSIONS] * (num_tokens - len(vectors))
            tokenized["emotion_vectors"] = vectors
            return tokenized

        from datasets import Dataset
        dataset = Dataset.from_list(dummy_data).map(dummy_process)
        print("Created a dummy dataset. REPLACE THIS with your actual dataset loading!")
        # --- End Dummy Dataset Example ---

    # Initialize the data collator
    data_collator = DataCollatorForEmotionalLlama(tokenizer=tokenizer, max_length=max_length)

    # --- Training Arguments ---
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,  # Accumulate gradients over steps
        warmup_ratio=0.1,  # Linear warmup over the first 10% of steps
        weight_decay=0.01,  # L2 regularization for most parameters
        logging_steps=10,  # Log training progress every N steps
        save_steps=200,  # Save checkpoint every N steps
        save_total_limit=2,  # Keep only the last N checkpoints
        report_to="none",  # Disable reporting to external platforms like W&B
        push_to_hub=False,  # Do not push to Hugging Face Hub
        bf16=torch.cuda.is_bf16_supported(),  # Use bf16 if supported
        fp16=not torch.cuda.is_bf16_supported(),  # Otherwise use fp16
        lr_scheduler_type="cosine",  # Cosine annealing learning rate scheduler
        optim="adamw_torch"  # PyTorch AdamW optimizer
    )

    # --- Optimizer Setup ---
    # Split parameters for different learning rates and weight decay
    # LoRA parameters and other model parameters (if any are trainable beyond LoRA)
    main_params = [p for n, p in model.named_parameters() if p.requires_grad and "emotion_proj" not in n]
    # Emotion projection layer parameters
    emotion_params = [p for n, p in model.named_parameters() if "emotion_proj" in n and p.requires_grad]

    # Define parameter groups for the optimizer
    optimizer_grouped_parameters = [
        # Group for main parameters (LoRA, etc.) with weight decay
        {"params": main_params, "lr": training_args.learning_rate, "weight_decay": training_args.weight_decay},
        # Group for emotion projection layer parameters with a higher LR and NO weight decay
        {"params": emotion_params, "lr": emotion_proj_lr, "weight_decay": 0.0}
    ]

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
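
    # Note (an aside, based on how the stock Hugging Face Trainer treats a user-supplied
    # optimizer): because this optimizer is handed to the Trainer below, the warmup + cosine
    # schedule from TrainingArguments scales each parameter group's own base lr, so the
    # emotion projection group keeps its higher emotion_proj_lr relative to the LoRA group,
    # and the optim="adamw_torch" setting above only matters if no optimizer is passed in.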

    # --- Initialize Trainer ---
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        optimizers=(optimizer, None),  # Pass the custom optimizer
    )

    # --- Optional: Debugging Prints for Dataloader ---
    # print("\n--- Debugging Data Collator Output (First Batch) ---")
    # for step, batch in enumerate(trainer.get_train_dataloader()):
    #     print(f"  Step {step + 1}:")
    #     print(f"    input_ids shape: {batch['input_ids'].shape}")
    #     print(f"    attention_mask shape: {batch['attention_mask'].shape}")
    #     print(f"    emotion_vector shape: {batch['emotion_vector'].shape}")
    #     print(f"    labels shape: {batch['labels'].shape}")
    #     # Print slices or stats for verification
    #     # print(f"    input_ids (first row): {batch['input_ids'][0]}")
    #     # print(f"    labels (first row): {batch['labels'][0]}")
    #     # print(f"    emotion_vector (first row, few elements): {batch['emotion_vector'][0, :10, :2]}")
    #     print(f"    emotion_vector batch MIN: {batch['emotion_vector'].min()}")
    #     print(f"    emotion_vector batch MAX: {batch['emotion_vector'].max()}")
    #     print(f"    emotion_vector batch MEAN: {batch['emotion_vector'].mean()}")
    #     break  # Only print the first batch for debug
    # print("--- End Debugging Data Collator Output ---\n")
    # --- End Debugging Prints ---


    # --- Start Training ---
    print("Starting training...")
    trainer.train()
    print("Training finished.")

    # --- Save the Model ---
    # Trainer.save_model saves the full model checkpoint by default.
    # If using PEFT, model.save_pretrained() saves only the adapter weights.
    # We want to save BOTH the PEFT adapter and the custom layer weights.

    # Save the PEFT adapter weights if using LoRA
    if use_lora:
        print(f"Saving PEFT adapter model to {output_dir}")
        # This saves adapter_model.safetensors and adapter_config.json
        model.save_pretrained(output_dir)
    else:
        # If not using LoRA, save the full model checkpoint
        print(f"Saving full model checkpoint to {output_dir}")
        trainer.save_model(output_dir)

    # Manually Save Custom Layer Weights (the emotion_proj_embed layer)
    print("Saving custom emotion_proj_embed weights...")
    # Access the custom layer, handling the case if the model is wrapped by PEFT
    if hasattr(model, "base_model"):  # Check if it's a PeftModel
        emotion_layer = model.base_model.emotion_proj_embed
    else:  # If not using PEFT, the layer is directly on the model
        emotion_layer = model.emotion_proj_embed

    # Get the state dictionary of the custom layer
    emotion_state_dict = emotion_layer.state_dict()
    # Define the save path within the output directory
    save_path_emotion = os.path.join(output_dir, "emotion_proj_weights.pth")
    # Save the state dictionary
    torch.save(emotion_state_dict, save_path_emotion)
    print(f"Custom emotion_proj_embed weights saved to: {save_path_emotion}")

    # Return the trained model and tokenizer
    return model, tokenizer


if __name__ == "__main__":
    # Make sure you have a dataset.py and dataset.json file or implement the dummy dataset creation above.
    # Replace the dataset_path with the actual path to your dataset.
    train_emotional_llama(
        dataset_path="./dataset.json",  # Replace with your dataset path
        output_dir="./emotional-gemma-output",  # Output directory
        max_length=128,
        num_train_epochs=3,
        per_device_batch_size=4,  # Adjust based on your GPU memory
        gradient_accumulation_steps=8,  # Adjust based on desired effective batch size
        learning_rate=2e-4,  # Base LR for LoRA
        emotion_proj_lr=5e-3,  # Higher LR for emotion layer
        use_lora=True
    )
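
# Note: with per_device_batch_size=4 and gradient_accumulation_steps=8 as configured above,
# the effective batch size per optimizer step is 4 * 8 = 32 per device.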