import os
import json
from dataclasses import dataclass
from typing import Dict

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType

from emotional_gemma import EmotionalLlamaModel, EMOTION_DIMENSIONS, EMOTION_DIMENSIONS_REFERENCE, MODEL_NAME
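
# EmotionalLlamaModel is the custom model defined in emotional_gemma.py: it adds an
# emotion_proj_embed projection that consumes per-token emotion vectors of size
# EMOTION_DIMENSIONS. MODEL_NAME names the base checkpoint to fine-tune.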


@dataclass
class DataCollatorForEmotionalLlama:
    """
    Pads or truncates input_ids, attention_mask, and per-token emotion vectors to
    max_length, and builds labels masked with -100 everywhere except the model's
    reply tokens.
    """

    tokenizer: AutoTokenizer
    max_length: int
    emotion_dim: int = EMOTION_DIMENSIONS

    def __call__(self, examples: list) -> Dict[str, torch.Tensor]:
        input_ids_list = [example.get("input_ids", []) for example in examples]
        attention_mask_list = [example.get("attention_mask", []) for example in examples]
        emotion_vectors_list = [example.get("emotion_vectors", []) for example in examples]

        # Tokenize the model-turn marker once; it is used below to find where the
        # assistant reply starts so that prompt tokens can be masked out of the loss.
        try:
            model_prompt_tokens = self.tokenizer(
                "<start_of_turn>model\n",
                add_special_tokens=False
            ).input_ids
            if not model_prompt_tokens:
                raise ValueError("Tokenizer produced an empty token list for the model prompt sequence.")
        except Exception as e:
            print(f"ERROR: Could not tokenize model prompt '<start_of_turn>model\\n'. Check the tokenizer and template format. Error: {e}")
            raise ValueError("Cannot proceed without identifying model start tokens for label masking.") from e

        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        batch_emotion_vectors = []

        for i in range(len(input_ids_list)):
            input_ids = input_ids_list[i]
            attention_mask = attention_mask_list[i]
            emotion_vectors = emotion_vectors_list[i]

            seq_len = len(input_ids)
            pad_len = self.max_length - seq_len

            # Truncate sequences that exceed max_length.
            if pad_len < 0:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]
                emotion_vectors = emotion_vectors[:self.max_length]
                seq_len = self.max_length
                pad_len = 0

            # Right-pad token ids, attention mask, and emotion vectors to max_length.
            padded_input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
            padded_attention_mask = attention_mask + [0] * pad_len
            padded_emotion_vectors = emotion_vectors + [[0.0] * self.emotion_dim] * pad_len

            # Labels start as a copy of the padded input ids; prompt and padding
            # positions are set to -100 below so the loss ignores them.
            labels = list(padded_input_ids)

            # Locate the start of the model turn and mask everything up to and
            # including the "<start_of_turn>model\n" marker.
            model_start_idx = -1
            for k in range(seq_len - len(model_prompt_tokens) + 1):
                if input_ids[k : k + len(model_prompt_tokens)] == model_prompt_tokens:
                    model_start_idx = k
                    break

            if model_start_idx != -1:
                for j in range(model_start_idx + len(model_prompt_tokens)):
                    labels[j] = -100
            else:
                print(f"Warning: model prompt sequence not found in sample {i}; masking all labels.")
                labels = [-100] * self.max_length

            # Mask the padding positions as well.
            for j in range(seq_len, self.max_length):
                labels[j] = -100
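
            # Resulting label layout for an example like
            # "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\nHi there!":
            #   [-100 ... -100 | token ids of "Hi there!" | -100 ... -100]
            #    prompt+marker        model reply              padding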

            # Sanity check: every padded field must be exactly max_length long.
            if len(padded_input_ids) != self.max_length or \
               len(padded_attention_mask) != self.max_length or \
               len(labels) != self.max_length or \
               len(padded_emotion_vectors) != self.max_length:
                raise ValueError(f"Length mismatch in collator for sample {i} after padding/truncation!")

            batch_input_ids.append(padded_input_ids)
            batch_attention_mask.append(padded_attention_mask)
            batch_labels.append(labels)
            batch_emotion_vectors.append(padded_emotion_vectors)

        # Note the key "emotion_vector" (singular) here, as opposed to the
        # "emotion_vectors" field carried by the dataset examples.
        batch = {
            "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
            "emotion_vector": torch.tensor(batch_emotion_vectors, dtype=torch.float),
        }

        return batch
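
# Example (sketch): invoking the collator by hand, assuming the processed dataset
# rows already carry "input_ids", "attention_mask", and "emotion_vectors":
#
#   collator = DataCollatorForEmotionalLlama(tokenizer=tokenizer, max_length=128)
#   batch = collator([dataset[0], dataset[1]])
#   batch["input_ids"].shape       # torch.Size([2, 128])
#   batch["emotion_vector"].shape  # torch.Size([2, 128, EMOTION_DIMENSIONS])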


class CustomTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        """
        Overrides the default method to build the training DataLoader directly with
        the provided data collator. This is mostly for clarity, or for cases where
        the default Trainer behavior needs bypassing; note that this plain DataLoader
        does not set up a distributed sampler.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        data_collator = self.data_collator

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )


def train_emotional_llama(
    model_name=MODEL_NAME,
    dataset_path="./dataset.json",
    output_dir="./emotional-gemma-output",
    max_length=128,
    learning_rate=1e-4,
    emotion_proj_lr=2e-3,
    num_train_epochs=2,
    per_device_batch_size=12,
    gradient_accumulation_steps=1,
    use_lora=True,
):
    """
    Sets up and runs training for the EmotionalLlamaModel: loads the tokenizer and
    base model, optionally applies LoRA, prepares the dataset, builds an optimizer
    with a separate learning rate for the emotion projection, trains, and saves.
    """
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Make sure a pad token exists and pad on the right, matching the collator.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print(f"Loading base model: {model_name}")
    model = EmotionalLlamaModel.from_pretrained(model_name)

    if use_lora:
        print("Applying LoRA configuration")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=32,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # get_peft_model freezes everything outside the LoRA adapters, so the custom
    # emotion projection has to be unfrozen explicitly.
    print("Setting emotion_proj_embed requires_grad=True")
    for param in model.emotion_proj_embed.parameters():
        param.requires_grad = True

    print(f"Loading dataset from: {dataset_path}")
    try:
        from dataset import create_huggingface_dataset
        dataset = create_huggingface_dataset(dataset_path, tokenizer, max_length)
        print(f"Dataset loaded with {len(dataset)} examples.")
    except ImportError:
        print("Error: could not import 'create_huggingface_dataset' from dataset.py.")
        print("Please ensure dataset.py exists and contains the necessary function.")
        print("Falling back to a dummy dataset for demonstration:")

        # Minimal stand-in examples so the pipeline can run end to end.
        dummy_data = [
            {"text": "<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\nHi there!",
             "emotion_vectors": [[0.1] * EMOTION_DIMENSIONS] * 20},
            {"text": "<start_of_turn>user\nHow are you?<end_of_turn>\n<start_of_turn>model\nI'm feeling good today.",
             "emotion_vectors": [[0.8] * EMOTION_DIMENSIONS] * 25},
        ]

        def dummy_process(example):
            # Tokenize to max_length and pad the emotion vectors to the same length.
            tokenized = tokenizer(example["text"], truncation=True, max_length=max_length, padding="max_length")
            tokenized["emotion_vectors"] = example["emotion_vectors"][:max_length]
            if len(tokenized["emotion_vectors"]) < max_length:
                tokenized["emotion_vectors"] += [[0.0] * EMOTION_DIMENSIONS] * (max_length - len(tokenized["emotion_vectors"]))
            return tokenized

        from datasets import Dataset
        dataset = Dataset.from_list(dummy_data).map(dummy_process)
        print("Created a dummy dataset. REPLACE THIS with your actual dataset loading!")

    data_collator = DataCollatorForEmotionalLlama(tokenizer=tokenizer, max_length=max_length)

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=0.1,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        report_to="none",
        push_to_hub=False,
        # Prefer bf16 where the GPU supports it; otherwise fall back to fp16.
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        lr_scheduler_type="cosine",
        optim="adamw_torch",
    )
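
    # Two parameter groups: LoRA / base parameters at the main learning rate, and
    # the emotion projection at its own, higher learning rate with no weight decay.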
    main_params = [p for n, p in model.named_parameters() if p.requires_grad and "emotion_proj" not in n]
    emotion_params = [p for n, p in model.named_parameters() if "emotion_proj" in n and p.requires_grad]

    optimizer_grouped_parameters = [
        {"params": main_params, "lr": training_args.learning_rate, "weight_decay": training_args.weight_decay},
        {"params": emotion_params, "lr": emotion_proj_lr, "weight_decay": 0.0},
    ]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        optimizers=(optimizer, None),
    )
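
    # Passing optimizers=(optimizer, None) supplies the custom optimizer while
    # letting the Trainer build the learning-rate scheduler itself from
    # lr_scheduler_type ("cosine") and warmup_ratio.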

    print("Starting training...")
    trainer.train()
    print("Training finished.")

    # Save the trained weights: the PEFT adapter (or the full model) plus the custom
    # emotion projection, which is stored separately because the adapter checkpoint
    # does not include it.
    if use_lora:
        print(f"Saving PEFT adapter model to {output_dir}")
        model.save_pretrained(output_dir)
    else:
        print(f"Saving full model checkpoint to {output_dir}")
        trainer.save_model(output_dir)

    print("Saving custom emotion_proj_embed weights...")
    # When wrapped by PEFT, the custom layer is reached through the base model.
    if hasattr(model, "base_model"):
        emotion_layer = model.base_model.emotion_proj_embed
    else:
        emotion_layer = model.emotion_proj_embed

    emotion_state_dict = emotion_layer.state_dict()
    save_path_emotion = os.path.join(output_dir, "emotion_proj_weights.pth")
    torch.save(emotion_state_dict, save_path_emotion)
    print(f"Custom emotion_proj_embed weights saved to: {save_path_emotion}")

    return model, tokenizer
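
# Reloading for inference (sketch; the exact attribute path to emotion_proj_embed
# may differ depending on the PEFT version and how the model is wrapped):
#
#   from peft import PeftModel
#   base = EmotionalLlamaModel.from_pretrained(MODEL_NAME)
#   model = PeftModel.from_pretrained(base, "./emotional-gemma-output")
#   state = torch.load("./emotional-gemma-output/emotion_proj_weights.pth")
#   model.get_base_model().emotion_proj_embed.load_state_dict(state)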


if __name__ == "__main__":
    train_emotional_llama(
        dataset_path="./dataset.json",
        output_dir="./emotional-gemma-output",
        max_length=128,
        num_train_epochs=3,
        per_device_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        emotion_proj_lr=5e-3,
        use_lora=True,
    )