Upload 2 files
class definition and inference scripts
- emotional_gemma.py +143 -0
- inference.py +179 -0
emotional_gemma.py
ADDED
@@ -0,0 +1,143 @@
# emotional_gemma.py
import torch
import torch.nn as nn
from transformers import Gemma3ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Union

# Constants
MODEL_NAME = "google/gemma-3-1b-it"
EMOTION_DIMENSIONS = 8
EMOTION_DIMENSIONS_REFERENCE = [
    "SADNESS_JOY", "FEAR_COURAGE", "DISGUST_ACCEPTANCE", "ANGER_CALMNESS",
    "SURPRISE_EXPECTATION", "DISTRUST_TRUST", "BOREDOM_INTEREST", "INDIFFERENCE_EMPATHY"
]

class EmotionalLlamaModel(Gemma3ForCausalLM):
    """Gemma3 Causal Language Model with emotion modulation."""
    def __init__(self, config):
        super().__init__(config)
        self.emotion_dim = EMOTION_DIMENSIONS

        # Emotion projection layer: an MLP that projects the emotion vector
        # to the hidden size of the model.
        intermediate_size = config.hidden_size // 2
        self.emotion_proj_embed = nn.Sequential(
            nn.Linear(self.emotion_dim, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, config.hidden_size),
        )

        # Initialization for the MLP weights
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
        self.emotion_proj_embed.apply(init_weights)

        # Post-initialization steps from the base class
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        emotion_vector: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithPast]:

        # 1. Prepare input embeddings:
        # get embeddings from input_ids or use the provided inputs_embeds.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        elif input_ids is not None:
            batch_size, seq_len = input_ids.shape
            inputs_embeds = self.model.embed_tokens(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_len = inputs_embeds.shape[:2]
        else:
            # If neither is provided, it's likely a generation step using only the cache.
            # The base model's forward handles this by looking up the single new token ID,
            # possibly receiving `inputs_embeds` via `kwargs`.
            pass  # standard generate() handles embedding lookup for subsequent tokens

        # 2. Apply emotion modulation to the embeddings:
        # if emotion_vector is provided and we have inputs_embeds, modulate the embeddings.
        if emotion_vector is not None and inputs_embeds is not None:
            if emotion_vector.shape[0] != batch_size:
                raise ValueError("Batch size mismatch between emotion_vector and input.")

            # Ensure emotion_vector shape is [batch, seq_len, emotion_dim].
            # This handles the case where a single emotion vector [batch, emotion_dim]
            # is provided for the entire sequence during inference.
            current_seq_len = inputs_embeds.shape[1]
            if emotion_vector.dim() == 2:
                emotion_vector = emotion_vector.unsqueeze(1).expand(-1, current_seq_len, -1)
            elif emotion_vector.shape[1] != current_seq_len:
                # This case can occur if the emotion vector is longer than the current
                # input chunk (e.g., during token-by-token generation after the prompt);
                # take the slice corresponding to the current input.
                emotion_vector = emotion_vector[:, :current_seq_len, :]

            # Project the emotion vector to the hidden size
            emotion_offset = self.emotion_proj_embed(emotion_vector)  # -> [batch, current_seq_len, hidden_size]

            # Add the projected emotion vector as an offset to the input embeddings.
            # The scaling factor (e.g., 3) can be adjusted during training.
            inputs_embeds = inputs_embeds + emotion_offset * 3

        # 3. Pass the (potentially modified) embeddings to the base model's core layers.
        # Crucially, pass inputs_embeds if they were created/modified, and input_ids only otherwise
        # (the base forward handles the input_ids -> inputs_embeds lookup itself).
        outputs = self.model(
            input_ids=input_ids if inputs_embeds is None else None,  # pass input_ids ONLY if inputs_embeds wasn't created/modified
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,  # always pass the potentially modified inputs_embeds
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # need the last hidden state for lm_head
            return_dict=True,
            **kwargs,
        )

        # 4. Compute logits from the final hidden state
        hidden_states = outputs.hidden_states[-1]

        # Apply the language model head to get logits
        logits = self.lm_head(hidden_states)

        # 5. Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Shift tokens for autoregressive training
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        # Return the CausalLMOutputWithPast object
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,  # optionally keep all hidden states
            attentions=outputs.attentions,        # optionally keep attentions
        )

# This file only contains the model definition and constants.
# Training and inference logic are handled in separate files.
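A minimal sanity check of the class interface, not part of this upload: assuming emotional_gemma.py is importable, a single forward pass with one emotion vector per batch element should yield logits of shape [batch, seq_len, vocab_size], since forward() broadcasts a 2-D emotion vector across the sequence.

# sanity_check.py -- minimal sketch, not part of this upload
import torch
from transformers import AutoTokenizer
from emotional_gemma import EmotionalLlamaModel, MODEL_NAME, EMOTION_DIMENSIONS

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Note: emotion_proj_embed is freshly initialized here; the trained weights
# are only loaded in inference.py via emotion_proj_weights.pth.
model = EmotionalLlamaModel.from_pretrained(MODEL_NAME)
model.eval()

inputs = tokenizer("Hello there", return_tensors="pt")
# One emotion vector per batch element; forward() expands it over the sequence.
emotion = torch.zeros(1, EMOTION_DIMENSIONS)
emotion[0, 0] = 0.8  # push the SADNESS_JOY axis toward the joy end

with torch.no_grad():
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        emotion_vector=emotion,
    )

print(out.logits.shape)  # torch.Size([1, seq_len, vocab_size])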
inference.py
ADDED
@@ -0,0 +1,179 @@
# inference.py
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from transformers import AutoTokenizer
from emotional_gemma import EmotionalLlamaModel, EMOTION_DIMENSIONS, EMOTION_DIMENSIONS_REFERENCE
from peft import PeftModel, PeftConfig

import torch.nn.functional as F


def generate_with_emotion(
    model,
    tokenizer,
    prompt: str,
    emotion_vector: list,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_k: int = 128,
    top_p: float = 0.95,
    do_sample: bool = True,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    seed: int = None,
):
    """
    Generates text using the standard model.generate() method with an emotion vector.
    """
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, do_sample={do_sample}")
    if len(emotion_vector) != EMOTION_DIMENSIONS:
        raise ValueError(f"Emotion vector must have {EMOTION_DIMENSIONS} dimensions.")

    if seed is not None:
        torch.manual_seed(seed)
        if device == "cuda":
            torch.cuda.manual_seed_all(seed)

    current_model = model
    current_model.eval()
    current_model.to(device)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    # The emotion vector needs to be a tensor on the correct device
    emotion_tensor = torch.tensor([emotion_vector], dtype=torch.float).to(device)  # shape [1, EMOTION_DIMENSIONS]

    with torch.no_grad():
        # Pass the emotion vector through to the generate method
        generated_outputs = current_model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
            emotion_vector=emotion_tensor,  # pass the [1, EMOTION_DIMENSIONS] tensor
        )

    generated_text = tokenizer.decode(generated_outputs[0], skip_special_tokens=True)
    return generated_text

# --- Main block ---
if __name__ == "__main__":
    # Directory where the adapter weights and custom layer weights were saved
    model_path = "./emotional-gemma-output-4"

    # --- Load configuration ---
    # The PEFT config should tell us the base model name
    try:
        config = PeftConfig.from_pretrained(model_path)
        model_name = config.base_model_name_or_path
        print(f"Inferred base model name from PEFT config: {model_name}")
    except Exception as e:
        print(f"Warning: Could not infer base model name from PeftConfig in {model_path}. Using default. Error: {e}")
        # Fallback if config loading fails
        model_name = "google/gemma-3-1b-it"

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # --- Load the base model ---
    # The base model needs to be the custom EmotionalLlamaModel
    print(f"Loading base model: {model_name}")
    base_model = EmotionalLlamaModel.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    print("Base model loaded.")

    # --- Load the PEFT model (adapter weights only) ---
    print(f"Loading PEFT adapter from: {model_path}")
    # This wraps the base_model with PEFT adapters
    model = PeftModel.from_pretrained(base_model, model_path)
    print(f"PEFT adapter loaded. Model type: {type(model)}")

    # --- Explicitly load custom layer weights ---
    # Load the state_dict for the custom layer from the saved file
    custom_weights_path = os.path.join(model_path, "emotion_proj_weights.pth")
    try:
        if os.path.exists(custom_weights_path):
            print(f"Loading custom emotion_proj_embed weights from: {custom_weights_path}")
            # Load the state dict; mapping to CPU first is safer before loading into the model
            emotion_state_dict = torch.load(custom_weights_path, map_location="cpu")

            # Access the layer within the PeftModel's base_model;
            # the custom layer lives directly on the base model instance
            emotion_layer = model.base_model.emotion_proj_embed
            load_result = emotion_layer.load_state_dict(emotion_state_dict)
            print(f"Custom weights loaded successfully: {load_result}")
        else:
            print(f"WARNING: Custom weights file not found at {custom_weights_path}. Layer 'emotion_proj_embed' will keep the base model's initial weights.")

    except Exception as e:
        print(f"ERROR loading custom emotion_proj_embed weights from {custom_weights_path}: {e}")

    # Determine and move the model to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Moving model to device: {device}")
    model.to(device)


    # --- Model weight checks (after loading) ---
    print("\n--- Model Weight Checks (After Loading) ---")
    is_peft_model = isinstance(model, PeftModel)
    print(f"Is PeftModel: {is_peft_model}")

    print(" emotion_proj Layer Check:")
    try:
        # Access the custom layer via the base_model attribute of the PeftModel
        emotion_proj_layer = model.base_model.emotion_proj_embed
        print(f" - emotion_proj_embed Sequential found: {emotion_proj_layer}")
        # The Sequential contains a Linear layer at index 0
        linear_layer = emotion_proj_layer[0]
        print(f" - Linear layer inside Sequential: {linear_layer}")
        if hasattr(linear_layer, 'weight'):
            print(f"   Weights exist, device: {linear_layer.weight.device}, dtype: {linear_layer.weight.dtype}")
            print(f"   Weights mean abs value: {linear_layer.weight.data.abs().mean().item()}")
        else:
            print("   Weights attribute not found.")
        if hasattr(linear_layer, 'bias') and linear_layer.bias is not None:
            print(f"   Bias exists, device: {linear_layer.bias.device}, dtype: {linear_layer.bias.dtype}")
            print(f"   Bias mean abs value: {linear_layer.bias.data.abs().mean().item()}")
        else:
            print("   Bias attribute not found or is None.")
    except Exception as e:
        print(f" - Error checking layer: {e}")

    # Check the device of one of the model parameters
    print(f"Model overall device: {next(model.parameters()).device}")

    # --- Generation ---
    # Prepare the prompt using the chat template
    prompt = tokenizer.apply_chat_template([
        {"role": "user", "content": "How are you feeling today?"},
    ], tokenize=False, add_generation_prompt=True)

    print(f"\nPrompt:\n{prompt}")

    # Define emotion vectors based on the reference dimensions
    # (EMOTION_DIMENSIONS_REFERENCE is defined in emotional_gemma.py).
    # Index mapping: 0=SADNESS_JOY, 1=FEAR_COURAGE, 2=DISGUST_ACCEPTANCE, 3=ANGER_CALMNESS,
    #                4=SURPRISE_EXPECTATION, 5=DISTRUST_TRUST, 6=BOREDOM_INTEREST, 7=INDIFFERENCE_EMPATHY
    joyful_emotion = [0.8, 0, 0, -0.5, 0, 0, 0, 0]   # High Joy, some Calmness
    sad_emotion = [-0.8, 0, 0, -0.5, 0, 0, 0, 0]     # High Sadness, some Calmness
    neutral_emotion = [0] * EMOTION_DIMENSIONS       # All dimensions at zero
    my_seed = 42  # Seed for reproducibility

    # Generate text with different emotions using the recommended method
    print("Generating with joyful emotion:")
    joyful_text = generate_with_emotion(model, tokenizer, prompt, joyful_emotion, seed=my_seed)
    print(joyful_text)

    print("\nGenerating with sad emotion:")
    sad_text = generate_with_emotion(model, tokenizer, prompt, sad_emotion, seed=my_seed)
    print(sad_text)

    print("\nGenerating with neutral emotion:")
    neutral_text = generate_with_emotion(model, tokenizer, prompt, neutral_emotion, seed=my_seed)
    print(neutral_text)
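inference.py expects emotion_proj_weights.pth to sit next to the saved PEFT adapter in model_path. The training script is not part of this upload, so the sketch below is only an assumption of how that file could be produced; the save_emotion_proj helper and output_dir name are hypothetical.

# Sketch (assumption): writing emotion_proj_weights.pth on the training side.
import os
import torch

def save_emotion_proj(peft_model, output_dir):
    # The custom layer lives on the wrapped base model, mirroring the access
    # pattern used in inference.py (model.base_model.emotion_proj_embed).
    emotion_layer = peft_model.base_model.emotion_proj_embed
    os.makedirs(output_dir, exist_ok=True)
    torch.save(emotion_layer.state_dict(),
               os.path.join(output_dir, "emotion_proj_weights.pth"))

# Hypothetical usage, after peft_model.save_pretrained(output_dir):
# save_emotion_proj(peft_model, output_dir)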