FelixTheWhale committed on
Commit 9c873e6 · verified · 1 Parent(s): 725abee

Upload 2 files


class definition and inference scripts

Files changed (2)
  1. emotional_gemma.py +143 -0
  2. inference.py +179 -0
emotional_gemma.py ADDED
@@ -0,0 +1,143 @@
+ # emotional_gemma.py
+ import torch
+ import torch.nn as nn
+ from transformers import Gemma3ForCausalLM
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from typing import Optional, Union
+
+ # Constants
+ MODEL_NAME = "google/gemma-3-1b-it"
+ EMOTION_DIMENSIONS = 8
+ EMOTION_DIMENSIONS_REFERENCE = [
+     "SADNESS_JOY", "FEAR_COURAGE", "DISGUST_ACCEPTANCE", "ANGER_CALMNESS",
+     "SURPRISE_EXPECTATION", "DISTRUST_TRUST", "BOREDOM_INTEREST", "INDIFFERENCE_EMPATHY"
+ ]
+
+ class EmotionalLlamaModel(Gemma3ForCausalLM):
+     """Gemma3 Causal Language Model with emotion modulation."""
+     def __init__(self, config):
+         super().__init__(config)
+         self.emotion_dim = EMOTION_DIMENSIONS
+
+         # Emotion projection layer: MLP
+         # This layer projects the emotion vector to the hidden size of the model.
+         intermediate_size = config.hidden_size // 2
+         self.emotion_proj_embed = nn.Sequential(
+             nn.Linear(self.emotion_dim, intermediate_size),
+             nn.LayerNorm(intermediate_size),
+             nn.GELU(),
+             nn.Linear(intermediate_size, config.hidden_size),
+         )
+
+         # Initialization for the MLP weights
+         def init_weights(m):
+             if isinstance(m, nn.Linear):
+                 torch.nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None:
+                     torch.nn.init.zeros_(m.bias)
+         self.emotion_proj_embed.apply(init_weights)
+
+         # Post-initialization steps from the base class
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[list] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         emotion_vector: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ) -> Union[tuple, CausalLMOutputWithPast]:
+
+         # 1. Prepare Input Embeddings
+         # Get input embeddings from input_ids, or use the provided inputs_embeds.
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds")
+         elif input_ids is not None:
+             batch_size, seq_len = input_ids.shape
+             inputs_embeds = self.model.embed_tokens(input_ids)
+         elif inputs_embeds is not None:
+             batch_size, seq_len = inputs_embeds.shape[:2]
+         else:
+             # If neither is provided, this is likely a generation step using only the cache.
+             # The base model's forward handles this by looking up the single new token ID,
+             # so we rely on it (and on any `inputs_embeds` arriving via `kwargs`) here.
+             pass  # The standard generate loop handles embedding lookup for subsequent tokens.
+
+
+         # 2. Apply Emotion Modulation to Embeddings
+         # If emotion_vector is provided and we have inputs_embeds, modulate the embeddings
+         if emotion_vector is not None and inputs_embeds is not None:
+             if emotion_vector.shape[0] != batch_size:
+                 raise ValueError("Batch size mismatch between emotion_vector and input.")
+
+             # Ensure emotion_vector shape is [batch, seq_len, emotion_dim]
+             # This handles the case where a single emotion vector [batch, emotion_dim]
+             # is provided for the entire sequence during inference.
+             current_seq_len = inputs_embeds.shape[1]
+             if emotion_vector.dim() == 2:
+                 emotion_vector = emotion_vector.unsqueeze(1).expand(-1, current_seq_len, -1)
+             elif emotion_vector.shape[1] != current_seq_len:
+                 # This case might occur if the emotion vector is longer than the current
+                 # input chunk (e.g., during token-by-token generation after the prompt).
+                 # We take the slice corresponding to the current input.
+                 emotion_vector = emotion_vector[:, :current_seq_len, :]
+
+             # Project emotion vector to hidden size using the emotion projection layer
+             emotion_offset = self.emotion_proj_embed(emotion_vector)  # -> [batch, current_seq_len, hidden_size]
+
+             # Add the projected emotion vector as an offset to the input embeddings
+             # Scaling factor (e.g., 3) can be adjusted during training
+             inputs_embeds = inputs_embeds + emotion_offset * 3
+
+         # 3. Pass the (potentially modified) embeddings to the base model's core layers.
+         # Crucially, pass inputs_embeds if they were created/modified here; input_ids is passed
+         # only when no embeddings were built (the base forward then does the lookup itself).
+         outputs = self.model(
+             input_ids=input_ids if inputs_embeds is None else None,  # pass input_ids ONLY if inputs_embeds wasn't created/modified
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,  # always pass the potentially modified inputs_embeds
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,  # the last hidden state is needed for lm_head
+             return_dict=True,
+             **kwargs
+         )
+
+         # 4. Compute logits from the final hidden state
+         hidden_states = outputs.hidden_states[-1]
+
+         # Apply the language model head to get logits
+         logits = self.lm_head(hidden_states)
+
+         # 5. Compute loss if labels are provided
+         loss = None
+         if labels is not None:
+             # Shift tokens for autoregressive training
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+         # Return the CausalLMOutputWithPast object
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,  # optionally keep all hidden states
+             attentions=outputs.attentions,  # optionally keep attentions
+         )
+
+ # This file contains only the model definition and constants.
+ # Training and inference logic are handled in separate files.
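
A minimal sketch of exercising the class above with a single emotion-conditioned forward pass. This is an illustration only: it assumes the google/gemma-3-1b-it checkpoint can be downloaded, uses only the names defined in emotional_gemma.py, and the emotion_proj_embed layer stays randomly initialized unless trained weights are loaded separately.

import torch
from transformers import AutoTokenizer
from emotional_gemma import EmotionalLlamaModel, MODEL_NAME, EMOTION_DIMENSIONS

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = EmotionalLlamaModel.from_pretrained(MODEL_NAME)
model.eval()

inputs = tokenizer("Hello there!", return_tensors="pt")
# One emotion vector per batch item; forward() expands it across the sequence.
emotion = torch.zeros(1, EMOTION_DIMENSIONS)
emotion[0, 0] = 0.8  # lean toward the JOY end of the SADNESS_JOY axis

with torch.no_grad():
    out = model(input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                emotion_vector=emotion)
print(out.logits.shape)  # [1, seq_len, vocab_size]
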
inference.py ADDED
@@ -0,0 +1,179 @@
+ # inference.py
+ import os
+ os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+ import torch
+ from typing import Optional
+ from transformers import AutoTokenizer
+ from emotional_gemma import EmotionalLlamaModel, EMOTION_DIMENSIONS, EMOTION_DIMENSIONS_REFERENCE
+ from peft import PeftModel, PeftConfig
+
+
+
+ def generate_with_emotion(
+     model,
+     tokenizer,
+     prompt: str,
+     emotion_vector: list,
+     max_new_tokens: int = 128,
+     temperature: float = 0.7,
+     top_k: int = 128,
+     top_p: float = 0.95,
+     do_sample: bool = True,
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+     seed: Optional[int] = None,
+ ):
+     """
+     Generates text using the standard model.generate() method with an emotion vector.
+     """
+     print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, do_sample={do_sample}")
+     if len(emotion_vector) != EMOTION_DIMENSIONS:
+         raise ValueError(f"Emotion vector must have {EMOTION_DIMENSIONS} dimensions.")
+
+     if seed is not None:
+         torch.manual_seed(seed)
+         if device == "cuda":
+             torch.cuda.manual_seed_all(seed)
+
+     current_model = model
+     current_model.eval()
+     current_model.to(device)
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     input_ids = inputs["input_ids"]
+     # The emotion vector needs to be a tensor on the correct device.
+     emotion_tensor = torch.tensor([emotion_vector], dtype=torch.float).to(device)  # shape [1, EMOTION_DIMENSIONS]
+
+     with torch.no_grad():
+         # Pass the emotion vector to the generate method
+         generated_outputs = current_model.generate(
+             input_ids=input_ids,
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             do_sample=do_sample,
+             pad_token_id=tokenizer.eos_token_id,
+             emotion_vector=emotion_tensor,  # pass the [1, EMOTION_DIMENSIONS] tensor
+         )
+
+     generated_text = tokenizer.decode(generated_outputs[0], skip_special_tokens=True)
+     return generated_text
+
+ # --- Main block ---
+ if __name__ == "__main__":
+     # Directory where the adapter weights and custom layer weights were saved
+     model_path = "./emotional-gemma-output-4"
+
+     # --- Load configuration ---
+     # The PEFT config should tell us the base model name.
+     try:
+         config = PeftConfig.from_pretrained(model_path)
+         model_name = config.base_model_name_or_path
+         print(f"Inferred base model name from PEFT config: {model_name}")
+     except Exception as e:
+         print(f"Warning: Could not infer base model name from PeftConfig in {model_path}. Using default. Error: {e}")
+         # Fallback if config loading fails
+         model_name = "google/gemma-3-1b-it"
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+     # --- Load the base model ---
+     # The base model needs to be the custom EmotionalLlamaModel.
+     print(f"Loading base model: {model_name}")
+     base_model = EmotionalLlamaModel.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+     )
+     print("Base model loaded.")
+
+     # --- Load the PEFT model (adapter weights only) ---
+     print(f"Loading PEFT adapter from: {model_path}")
+     # This wraps the base_model with PEFT adapters.
+     model = PeftModel.from_pretrained(base_model, model_path)
+     print(f"PEFT adapter loaded. Model type: {type(model)}")
+
+     # --- Explicitly load the custom layer weights ---
+     # Load the state_dict for the custom layer from the saved file.
+     custom_weights_path = os.path.join(model_path, "emotion_proj_weights.pth")
+     try:
+         if os.path.exists(custom_weights_path):
+             print(f"Loading custom emotion_proj_embed weights from: {custom_weights_path}")
+             # Load the state dict; mapping to CPU first is safer before loading into the model.
+             emotion_state_dict = torch.load(custom_weights_path, map_location="cpu")
+
+             # Access the layer within the PeftModel's base_model.
+             # The custom layer lives directly on the base model instance.
+             emotion_layer = model.base_model.emotion_proj_embed
+             load_result = emotion_layer.load_state_dict(emotion_state_dict)
+             print(f"Custom weights loaded successfully: {load_result}")
+         else:
+             print(f"WARNING: Custom weights file not found at {custom_weights_path}. Layer 'emotion_proj_embed' will keep its randomly initialized weights.")
+
+     except Exception as e:
+         print(f"ERROR loading custom emotion_proj_embed weights from {custom_weights_path}: {e}")
+
+     # Determine the appropriate device and move the model to it.
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Moving model to device: {device}")
+     model.to(device)
+
+
+     # --- Model Weight Checks (After Loading) ---
+     print("\n--- Model Weight Checks (After Loading) ---")
+     is_peft_model = isinstance(model, PeftModel)
+     print(f"Is PeftModel: {is_peft_model}")
+
+     print(" emotion_proj Layer Check:")
+     try:
+         # Access the custom layer via the base_model attribute of the PeftModel.
+         emotion_proj_layer = model.base_model.emotion_proj_embed
+         print(f"  - emotion_proj_embed Sequential found: {emotion_proj_layer}")
+         # The Sequential contains a Linear layer at index 0.
+         linear_layer = emotion_proj_layer[0]
+         print(f"  - Linear layer inside Sequential: {linear_layer}")
+         if hasattr(linear_layer, 'weight'):
+             print(f"    Weights exist, device: {linear_layer.weight.device}, dtype: {linear_layer.weight.dtype}")
+             print(f"    Weights mean abs value: {linear_layer.weight.data.abs().mean().item()}")
+         else:
+             print("    Weights attribute not found.")
+         if hasattr(linear_layer, 'bias') and linear_layer.bias is not None:
+             print(f"    Bias exists, device: {linear_layer.bias.device}, dtype: {linear_layer.bias.dtype}")
+             print(f"    Bias mean abs value: {linear_layer.bias.data.abs().mean().item()}")
+         else:
+             print("    Bias attribute not found or is None.")
+     except Exception as e:
+         print(f"  - Error checking layer: {e}")
+
+     # Check the device of one of the model parameters
+     print(f"Model overall device: {next(model.parameters()).device}")
+
+     # --- Generation ---
+     # Prepare the prompt using the chat template
+     prompt = tokenizer.apply_chat_template([
+         {"role": "user", "content": "How are you feeling today?"},
+     ], tokenize=False, add_generation_prompt=True)
+
+     print(f"\nPrompt:\n{prompt}")
+
+     # Define emotion vectors based on the reference dimensions.
+     # EMOTION_DIMENSIONS_REFERENCE is defined in emotional_gemma.py.
+     # Index mapping: 0=SADNESS_JOY, 1=FEAR_COURAGE, 2=DISGUST_ACCEPTANCE, 3=ANGER_CALMNESS,
+     #                4=SURPRISE_EXPECTATION, 5=DISTRUST_TRUST, 6=BOREDOM_INTEREST, 7=INDIFFERENCE_EMPATHY
+     joyful_emotion = [0.8, 0, 0, -0.5, 0, 0, 0, 0]   # +0.8 on SADNESS_JOY (joy side), -0.5 on ANGER_CALMNESS
+     sad_emotion = [-0.8, 0, 0, -0.5, 0, 0, 0, 0]     # -0.8 on SADNESS_JOY (sadness side), -0.5 on ANGER_CALMNESS
+     neutral_emotion = [0] * EMOTION_DIMENSIONS       # all dimensions at zero
+     my_seed = 42  # seed for reproducibility
+
+     # Generate text with different emotions using the generate_with_emotion helper.
+     print("Generating with joyful emotion:")
+     joyful_text = generate_with_emotion(model, tokenizer, prompt, joyful_emotion, seed=my_seed)
+     print(joyful_text)
+
+     print("\nGenerating with sad emotion:")
+     sad_text = generate_with_emotion(model, tokenizer, prompt, sad_emotion, seed=my_seed)
+     print(sad_text)
+
+     print("\nGenerating with neutral emotion:")
+     neutral_text = generate_with_emotion(model, tokenizer, prompt, neutral_emotion, seed=my_seed)
+     print(neutral_text)
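
As a quick illustration beyond the three fixed presets, a hedged sketch that sweeps a single axis with the generate_with_emotion helper, reusing model, tokenizer, prompt and my_seed from the main block above:

# Sweep the SADNESS_JOY axis with the same prompt and seed.
for value in (-0.8, 0.0, 0.8):
    vec = [0.0] * EMOTION_DIMENSIONS
    vec[0] = value  # index 0 = SADNESS_JOY
    print(f"\n--- SADNESS_JOY = {value} ---")
    print(generate_with_emotion(model, tokenizer, prompt, vec, seed=my_seed))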