FelixTheWhale committed
Commit 364f6c2 · verified · 1 Parent(s): dbef3ef

Upload train.py

Files changed (1): train.py (+348, -0)
train.py ADDED
# train.py
# YOUR dataset.py should provide a create_huggingface_dataset() function
# (a sketch of the expected interface follows the imports below):
# from dataset import create_huggingface_dataset
import os
import torch
from transformers import AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Dict
import json  # used by the dataset-loading example/sketch below

# Import the EmotionalLlamaModel and constants from emotional_gemma.py
from emotional_gemma import EmotionalLlamaModel, EMOTION_DIMENSIONS, EMOTION_DIMENSIONS_REFERENCE, MODEL_NAME


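# ---------------------------------------------------------------------------
# Illustrative sketch ONLY (not part of the original script): roughly what a
# matching create_huggingface_dataset() in dataset.py could look like. It
# assumes dataset.json holds a list of {"text": ..., "emotion_vectors": ...}
# records shaped like the dummy examples further below; your real dataset.py
# may differ.
# ---------------------------------------------------------------------------
def _sketch_create_huggingface_dataset(dataset_path, tokenizer, max_length):
    from datasets import Dataset  # local import so the sketch stays self-contained

    with open(dataset_path, "r", encoding="utf-8") as f:
        records = json.load(f)

    def _process(example):
        tokenized = tokenizer(example["text"], truncation=True, max_length=max_length)
        # One emotion vector (length EMOTION_DIMENSIONS) per token; trim or pad to match.
        vectors = example["emotion_vectors"][: len(tokenized["input_ids"])]
        vectors += [[0.0] * EMOTION_DIMENSIONS] * (len(tokenized["input_ids"]) - len(vectors))
        tokenized["emotion_vectors"] = vectors
        return tokenized

    return Dataset.from_list(records).map(_process)

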
# Define the DataCollator for handling padding and adding emotion vectors
@dataclass
class DataCollatorForEmotionalLlama:
    tokenizer: AutoTokenizer
    max_length: int
    emotion_dim: int = EMOTION_DIMENSIONS  # Use the constant from emotional_gemma

    def __call__(self, examples: list) -> Dict[str, torch.Tensor]:
        # Separate the components from the examples
        input_ids_list = [example.get("input_ids", []) for example in examples]
        attention_mask_list = [example.get("attention_mask", []) for example in examples]
        emotion_vectors_list = [example.get("emotion_vectors", []) for example in examples]

        # --- Find the token IDs for the start of the model's turn ---
        # These are used to mask out user input and padding from the labels.
        # Ensure your tokenizer and dataset preparation consistently include this sequence.
        try:
            # Tokenize the specific sequence marking the start of the model's turn.
            # add_special_tokens=False is crucial here to get just the tokens for the string.
            model_prompt_tokens = self.tokenizer(
                "<start_of_turn>model\n",
                add_special_tokens=False
            ).input_ids
            if not model_prompt_tokens:
                raise ValueError("Tokenizer produced an empty list for the model prompt sequence.")
            # print(f"DEBUG: Detected model prompt start tokens: {model_prompt_tokens} (decoded: '{self.tokenizer.decode(model_prompt_tokens)}')")
        except Exception as e:
            print(f"ERROR: Could not tokenize model prompt '<start_of_turn>model\\n'. Check tokenizer and template format. Error: {e}")
            raise ValueError("Cannot proceed without identifying model start tokens for label masking.") from e

        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        batch_emotion_vectors = []

        # Process each example in the batch
        for i in range(len(input_ids_list)):
            input_ids = input_ids_list[i]
            attention_mask = attention_mask_list[i]
            emotion_vectors = emotion_vectors_list[i]

            # --- Padding ---
            seq_len = len(input_ids)
            pad_len = self.max_length - seq_len

            # Truncate if the sequence is longer than max_length (ideally already handled in the dataset)
            if pad_len < 0:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]
                emotion_vectors = emotion_vectors[:self.max_length]
                seq_len = self.max_length
                pad_len = 0  # Recalculate pad_len after truncation

            # Pad input IDs, attention mask, and emotion vectors
            padded_input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
            padded_attention_mask = attention_mask + [0] * pad_len
            # Pad emotion vectors with zero vectors
            padded_emotion_vectors = emotion_vectors + [[0.0] * self.emotion_dim] * pad_len

            # --- Create Labels and Mask User/Padding Tokens ---
            labels = list(padded_input_ids)  # Start with a copy of input_ids for labels

            # Find the start index of the model's response to mask preceding tokens
            model_start_idx = -1
            # Search for the model prompt token sequence within the original input_ids
            for k in range(seq_len - len(model_prompt_tokens) + 1):
                if input_ids[k : k + len(model_prompt_tokens)] == model_prompt_tokens:
                    model_start_idx = k
                    break

            if model_start_idx != -1:
                # Mask everything before and including the model's prompt sequence
                for j in range(model_start_idx + len(model_prompt_tokens)):
                    labels[j] = -100
            else:
                print(f"Warning: Model prompt sequence not found in sample {i}. Masking all labels.")
                labels = [-100] * self.max_length  # Mask everything

            # Mask padding tokens regardless of whether the model prompt was found
            for j in range(seq_len, self.max_length):  # Only mask the padded part
                labels[j] = -100

            # Sanity check: ensure all lists have the correct length
            if len(padded_input_ids) != self.max_length or \
               len(padded_attention_mask) != self.max_length or \
               len(labels) != self.max_length or \
               len(padded_emotion_vectors) != self.max_length:
                raise ValueError(f"Length mismatch in collator for sample {i} after padding/truncation!")

            batch_input_ids.append(padded_input_ids)
            batch_attention_mask.append(padded_attention_mask)
            batch_labels.append(labels)
            batch_emotion_vectors.append(padded_emotion_vectors)

        # Convert lists to tensors
        batch = {
            "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
            "emotion_vector": torch.tensor(batch_emotion_vectors, dtype=torch.float),
        }

        return batch
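
# ---------------------------------------------------------------------------
# Illustrative sketch ONLY (not part of the original script): run one toy
# example through the collator to see the label masking in action. Everything
# up to and including "<start_of_turn>model\n", plus all padding, should end
# up set to -100.
# ---------------------------------------------------------------------------
def _sketch_collator_demo(tokenizer, max_length=32):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    collator = DataCollatorForEmotionalLlama(tokenizer=tokenizer, max_length=max_length)
    text = "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\nHi there!"
    tokenized = tokenizer(text, truncation=True, max_length=max_length)
    example = {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "emotion_vectors": [[0.0] * EMOTION_DIMENSIONS] * len(tokenized["input_ids"]),
    }
    batch = collator([example])
    n_masked = (batch["labels"][0] == -100).sum().item()
    print(f"labels shape: {tuple(batch['labels'].shape)}, masked label positions: {n_masked}")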


# Subclass Trainer to potentially customize dataloader behavior
class CustomTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        """
        Overrides the method to explicitly use the provided data collator.
        This is mostly for clarity or if the default Trainer behavior needs bypassing.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # Use the data_collator provided during Trainer initialization
        data_collator = self.data_collator

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,  # Important for training
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

def train_emotional_llama(
    model_name=MODEL_NAME,  # Use the default model name from emotional_gemma.py
    dataset_path="./dataset.json",  # Path to your dataset file
    output_dir="./emotional-gemma-output",  # Directory to save results
    max_length=128,  # Max sequence length for training
    learning_rate=1e-4,  # Base learning rate for LoRA
    emotion_proj_lr=2e-3,  # Higher learning rate for the emotion projection layer
    num_train_epochs=2,
    per_device_batch_size=12,
    gradient_accumulation_steps=1,
    use_lora=True  # Whether to use LoRA
):
    """
    Sets up and runs the training for the EmotionalLlamaModel.
    """
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Set pad_token to eos_token for Gemma if not already set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right-side padding is the usual choice for causal-LM training
    tokenizer.padding_side = "right"

    print(f"Loading base model: {model_name}")
    # Load the custom EmotionalLlamaModel
    model = EmotionalLlamaModel.from_pretrained(model_name)

    if use_lora:
        print("Applying LoRA configuration")
        # Define the LoRA configuration
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=32,  # LoRA rank
            lora_alpha=32,  # LoRA scaling factor
            # lora_dropout=0.05,  # Dropout for LoRA layers (optional)
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]  # Modules to apply LoRA to
        )
        # Wrap the base model to obtain the PEFT model
        model = get_peft_model(model, peft_config)
        # Print a summary of trainable parameters
        model.print_trainable_parameters()

    # Ensure the emotion projection layer is trainable.
    # When LoRA is applied, PEFT freezes all non-LoRA parameters by default,
    # so this custom layer must be re-enabled explicitly.
    print("Setting emotion_proj_embed requires_grad=True")
    for param in model.emotion_proj_embed.parameters():
        param.requires_grad = True

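    # Optional sanity check (illustrative addition, not part of the original script):
    # confirm the emotion projection parameters are trainable alongside the LoRA weights.
    # for name, param in model.named_parameters():
    #     if param.requires_grad and ("emotion_proj" in name or "lora_" in name):
    #         print(f"trainable: {name} {tuple(param.shape)}")
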
    # --- Load and Prepare Dataset ---
    print(f"Loading dataset from: {dataset_path}")
    # Import and use your dataset creation function
    try:
        from dataset import create_huggingface_dataset
        dataset = create_huggingface_dataset(dataset_path, tokenizer, max_length)
        print(f"Dataset loaded with {len(dataset)} examples.")
    except ImportError:
        print("Error: could not import 'create_huggingface_dataset' from dataset.py.")
        print("Please ensure dataset.py exists and contains the necessary function.")
        print("Falling back to a dummy dataset for demonstration:")
        # --- Placeholder: dummy dataset creation example ---
        # Used only when dataset.py is not available.
        # Replace this section with your actual dataset loading and processing logic.
        dummy_data = [
            {"text": "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\nHi there!", "emotion_vectors": [[0.1] * EMOTION_DIMENSIONS] * 20},
            {"text": "<start_of_turn>user\nHow are you?<end_of_turn>\n<start_of_turn>model\nI'm feeling good today.", "emotion_vectors": [[0.8] * EMOTION_DIMENSIONS] * 25},
        ]

        def dummy_process(example):
            # Simple tokenization for the dummy data
            tokenized = tokenizer(example["text"], truncation=True, max_length=max_length, padding="max_length")
            # Truncate/pad the emotion vectors to max_length as well
            tokenized["emotion_vectors"] = example["emotion_vectors"][:max_length]
            if len(tokenized["emotion_vectors"]) < max_length:
                tokenized["emotion_vectors"] += [[0.0] * EMOTION_DIMENSIONS] * (max_length - len(tokenized["emotion_vectors"]))
            return tokenized

        from datasets import Dataset
        dataset = Dataset.from_list(dummy_data).map(dummy_process)
        print("Created a dummy dataset. REPLACE THIS with your actual dataset loading!")
        # --- End Dummy Dataset Example ---

    # Initialize the data collator
    data_collator = DataCollatorForEmotionalLlama(tokenizer=tokenizer, max_length=max_length)

    # --- Training Arguments ---
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,  # Accumulate gradients over steps
        warmup_ratio=0.1,  # Linear warmup over the first 10% of steps
        weight_decay=0.01,  # L2 regularization for most parameters
        logging_steps=10,  # Log training progress every N steps
        save_steps=200,  # Save a checkpoint every N steps
        save_total_limit=2,  # Keep only the last N checkpoints
        report_to="none",  # Disable reporting to external platforms like W&B
        push_to_hub=False,  # Do not push to the Hugging Face Hub
        bf16=torch.cuda.is_bf16_supported(),  # Use bf16 if supported
        fp16=not torch.cuda.is_bf16_supported(),  # Otherwise fall back to fp16
        lr_scheduler_type="cosine",  # Cosine annealing learning rate scheduler
        optim="adamw_torch"  # PyTorch AdamW optimizer
    )

    # --- Optimizer Setup ---
    # Split parameters into groups for different learning rates and weight decay:
    # LoRA parameters and any other trainable model parameters
    main_params = [p for n, p in model.named_parameters() if p.requires_grad and "emotion_proj" not in n]
    # Emotion projection layer parameters
    emotion_params = [p for n, p in model.named_parameters() if "emotion_proj" in n and p.requires_grad]

    # Define parameter groups for the optimizer
    optimizer_grouped_parameters = [
        # Group for main parameters (LoRA, etc.) with weight decay
        {"params": main_params, "lr": training_args.learning_rate, "weight_decay": training_args.weight_decay},
        # Group for the emotion projection layer with a higher LR and NO weight decay
        {"params": emotion_params, "lr": emotion_proj_lr, "weight_decay": 0.0}
    ]

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
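    # Optional check (illustrative addition, not part of the original script): confirm the
    # optimizer ended up with two parameter groups carrying the intended learning rates.
    # for group in optimizer.param_groups:
    #     print(f"lr={group['lr']}, weight_decay={group['weight_decay']}, n_params={len(group['params'])}")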

    # --- Initialize Trainer ---
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        optimizers=(optimizer, None),  # Pass the custom optimizer; the Trainer builds the LR scheduler
    )

    # --- Optional: Debugging Prints for the Dataloader ---
    # print("\n--- Debugging Data Collator Output (First Batch) ---")
    # for step, batch in enumerate(trainer.get_train_dataloader()):
    #     print(f"  Step {step + 1}:")
    #     print(f"    input_ids shape: {batch['input_ids'].shape}")
    #     print(f"    attention_mask shape: {batch['attention_mask'].shape}")
    #     print(f"    emotion_vector shape: {batch['emotion_vector'].shape}")
    #     print(f"    labels shape: {batch['labels'].shape}")
    #     # Print slices or stats for verification
    #     # print(f"    input_ids (first row): {batch['input_ids'][0]}")
    #     # print(f"    labels (first row): {batch['labels'][0]}")
    #     # print(f"    emotion_vector (first row, few elements): {batch['emotion_vector'][0, :10, :2]}")
    #     print(f"    emotion_vector batch MIN: {batch['emotion_vector'].min()}")
    #     print(f"    emotion_vector batch MAX: {batch['emotion_vector'].max()}")
    #     print(f"    emotion_vector batch MEAN: {batch['emotion_vector'].mean()}")
    #     break  # Only inspect the first batch
    # print("--- End Debugging Data Collator Output ---\n")
    # --- End Debugging Prints ---

    # --- Start Training ---
    print("Starting training...")
    trainer.train()
    print("Training finished.")

    # --- Save the Model ---
    # Trainer.save_model saves the full model checkpoint by default.
    # With PEFT, model.save_pretrained() saves only the adapter weights.
    # We want to save BOTH the PEFT adapter and the custom layer weights.

    # Save the PEFT adapter weights if using LoRA
    if use_lora:
        print(f"Saving PEFT adapter model to {output_dir}")
        # This saves adapter_model.safetensors and adapter_config.json
        model.save_pretrained(output_dir)
    else:
        # If not using LoRA, save the full model checkpoint
        print(f"Saving full model checkpoint to {output_dir}")
        trainer.save_model(output_dir)

    # Manually save the custom layer weights (the emotion_proj_embed layer)
    print("Saving custom emotion_proj_embed weights...")
    # Access the custom layer, handling the case where the model is wrapped by PEFT
    if hasattr(model, "base_model"):  # PeftModel wraps the base model
        emotion_layer = model.base_model.emotion_proj_embed
    else:  # Without PEFT, the layer sits directly on the model
        emotion_layer = model.emotion_proj_embed

    # Get the state dictionary of the custom layer
    emotion_state_dict = emotion_layer.state_dict()
    # Define the save path within the output directory
    save_path_emotion = os.path.join(output_dir, "emotion_proj_weights.pth")
    # Save the state dictionary
    torch.save(emotion_state_dict, save_path_emotion)
    print(f"Custom emotion_proj_embed weights saved to: {save_path_emotion}")

    # Return the trained model and tokenizer
    return model, tokenizer

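# ---------------------------------------------------------------------------
# Illustrative sketch ONLY (not part of the original script): one way to
# reload everything this script saves — the LoRA adapter plus the custom
# emotion_proj_embed weights — for inference. The paths and PeftModel usage
# here are assumptions; adapt them to your setup.
# ---------------------------------------------------------------------------
def _sketch_load_trained_model(output_dir="./emotional-gemma-output", model_name=MODEL_NAME):
    from peft import PeftModel

    base_model = EmotionalLlamaModel.from_pretrained(model_name)
    # Attach the saved LoRA adapter (adapter_config.json / adapter_model.safetensors).
    model = PeftModel.from_pretrained(base_model, output_dir)
    # Restore the custom emotion projection layer saved by train_emotional_llama().
    emotion_state = torch.load(os.path.join(output_dir, "emotion_proj_weights.pth"), map_location="cpu")
    base_model.emotion_proj_embed.load_state_dict(emotion_state)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()
    return model, tokenizer
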
if __name__ == "__main__":
    # Make sure you have a dataset.py and dataset.json file, or rely on the dummy-dataset fallback above.
    # Replace dataset_path with the actual path to your dataset.
    train_emotional_llama(
        dataset_path="./dataset.json",  # Replace with your dataset path
        output_dir="./emotional-gemma-output",  # Output directory
        max_length=128,
        num_train_epochs=3,
        per_device_batch_size=4,  # Adjust based on your GPU memory
        gradient_accumulation_steps=8,  # Adjust based on the desired effective batch size
        learning_rate=2e-4,  # Base LR for LoRA
        emotion_proj_lr=5e-3,  # Higher LR for the emotion layer
        use_lora=True
    )