# Source: Emotional-Gemma-3-1B / emotional_gemma.py
# Uploaded by FelixTheWhale ("Upload 2 files", commit 9c873e6, verified)
# emotional_gemma.py
import torch
import torch.nn as nn
from transformers import Gemma3ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Union
# ---- Module-level constants ----

# Name of the base model (presumably the Hugging Face Hub id that the
# training/inference scripts load — confirm against those scripts).
MODEL_NAME = "google/gemma-3-1b-it"

# Labels for the emotion axes, in index order. Each label names a pair of
# opposing poles (presumably index i labels component i of the emotion
# vector — confirm against the training code).
EMOTION_DIMENSIONS_REFERENCE = [
    "SADNESS_JOY",
    "FEAR_COURAGE",
    "DISGUST_ACCEPTANCE",
    "ANGER_CALMNESS",
    "SURPRISE_EXPECTATION",
    "DISTRUST_TRUST",
    "BOREDOM_INTEREST",
    "INDIFFERENCE_EMPATHY",
]

# Number of emotion axes (8); derived from the label list so the two
# cannot drift apart.
EMOTION_DIMENSIONS = len(EMOTION_DIMENSIONS_REFERENCE)
class EmotionalLlamaModel(Gemma3ForCausalLM):
    """Gemma3 causal language model with emotion modulation.

    Despite its name, this class extends ``Gemma3ForCausalLM``. An optional
    emotion vector is projected to the model's hidden size by a small MLP
    (``emotion_proj_embed``) and added, scaled by ``EMOTION_SCALE``, to the
    input embeddings before they are passed to the base model's core layers.
    """

    # Multiplier applied to the projected emotion offset before it is added
    # to the input embeddings. Can be adjusted during training.
    EMOTION_SCALE = 3

    def __init__(self, config):
        """Build the base model and the emotion projection MLP.

        Args:
            config: Model configuration; ``config.hidden_size`` sets the
                output width of the emotion projection.
        """
        super().__init__(config)
        self.emotion_dim = EMOTION_DIMENSIONS

        # Emotion projection layer: a two-layer MLP mapping the emotion
        # vector (emotion_dim) to the hidden size of the model.
        intermediate_size = config.hidden_size // 2
        self.emotion_proj_embed = nn.Sequential(
            nn.Linear(self.emotion_dim, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, config.hidden_size),
        )

        # Run the base-class post-initialization BEFORE the custom init:
        # its weight-initialization pass also visits the freshly created
        # projection layers, so running it afterwards would overwrite the
        # Xavier initialization applied below.
        self.post_init()
        self.emotion_proj_embed.apply(self._init_emotion_proj)

    @staticmethod
    def _init_emotion_proj(module):
        """Xavier-initialize the weight and zero the bias of a Linear layer."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _apply_emotion(self, inputs_embeds, emotion_vector):
        """Return ``inputs_embeds`` shifted by the projected emotion vector.

        Args:
            inputs_embeds: Input embeddings, ``[batch, seq_len, hidden_size]``.
            emotion_vector: Either ``[batch, emotion_dim]`` (one vector for
                the whole sequence) or ``[batch, n, emotion_dim]``.

        Raises:
            ValueError: If the emotion vector's rank, batch size, feature
                size or sequence length is incompatible with the input.
        """
        batch_size, seq_len = inputs_embeds.shape[:2]

        if emotion_vector.dim() not in (2, 3):
            raise ValueError(
                "emotion_vector must have shape [batch, emotion_dim] or "
                f"[batch, seq_len, emotion_dim], got {tuple(emotion_vector.shape)}."
            )
        if emotion_vector.shape[0] != batch_size:
            raise ValueError("Batch size mismatch between emotion_vector and input.")
        if emotion_vector.shape[-1] != self.emotion_dim:
            raise ValueError(
                f"emotion_vector has {emotion_vector.shape[-1]} features, "
                f"expected {self.emotion_dim}."
            )

        if emotion_vector.dim() == 2:
            # A single vector per sample is applied to every position.
            emotion_vector = emotion_vector.unsqueeze(1).expand(-1, seq_len, -1)
        elif emotion_vector.shape[1] != seq_len:
            # The emotion vector may be longer than the current input chunk
            # (e.g. token-by-token generation after the prompt); keep the
            # leading ``seq_len`` positions.
            # NOTE(review): during cached generation the current tokens sit
            # at the END of the sequence, so a prefix slice may not be the
            # intended alignment — confirm against the inference code.
            emotion_vector = emotion_vector[:, :seq_len, :]
            # A length-1 sequence is still accepted and broadcast over all
            # positions; anything else cannot be aligned with the input.
            if emotion_vector.shape[1] not in (1, seq_len):
                raise ValueError(
                    f"emotion_vector covers {emotion_vector.shape[1]} positions "
                    f"but the input has {seq_len}."
                )

        # Align dtype/device with the projection so that e.g. a float32
        # emotion vector works with a bf16/fp16 model. Non-floating weights
        # (quantized layers) are left to handle the input dtype themselves.
        proj_weight = self.emotion_proj_embed[0].weight
        if proj_weight.is_floating_point():
            emotion_vector = emotion_vector.to(device=proj_weight.device, dtype=proj_weight.dtype)
        else:
            emotion_vector = emotion_vector.to(device=proj_weight.device)

        # -> [batch, seq_len (or 1), hidden_size]
        emotion_offset = self.emotion_proj_embed(emotion_vector) * self.EMOTION_SCALE
        if emotion_offset.device != inputs_embeds.device:
            emotion_offset = emotion_offset.to(inputs_embeds.device)
        # Only down-cast a float32 offset onto half-precision embeddings so
        # the embeddings keep the model's dtype; in every other case the
        # addition promotes exactly as it did before.
        if emotion_offset.dtype == torch.float32 and inputs_embeds.dtype in (torch.float16, torch.bfloat16):
            emotion_offset = emotion_offset.to(inputs_embeds.dtype)
        return inputs_embeds + emotion_offset

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        emotion_vector: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        """Run the model, optionally modulating the input embeddings by emotion.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: Forwarded unchanged to the base model.
            position_ids: Forwarded unchanged to the base model.
            past_key_values: Forwarded unchanged to the base model.
            inputs_embeds: Pre-computed input embeddings; mutually exclusive
                with ``input_ids``.
            labels: Target token ids. When given, a next-token cross-entropy
                loss is computed (logits and labels are shifted by one).
            use_cache: Forwarded unchanged to the base model.
            output_attentions: Forwarded unchanged to the base model.
            output_hidden_states: Accepted for interface compatibility; the
                hidden states are always requested and returned.
            return_dict: Accepted for interface compatibility; a
                ``CausalLMOutputWithPast`` is always returned.
            emotion_vector: Optional ``[batch, emotion_dim]`` or
                ``[batch, seq_len, emotion_dim]`` tensor. When omitted the
                embeddings are passed through unmodified.
            **kwargs: Extra keyword arguments forwarded to the base model.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without
            labels), ``logits``, ``past_key_values``, ``hidden_states`` and
            ``attentions``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given, or if ``emotion_vector`` is
                incompatible with the input.
        """
        # 1. Prepare input embeddings.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # 2. Apply emotion modulation to the embeddings.
        if emotion_vector is not None:
            inputs_embeds = self._apply_emotion(inputs_embeds, emotion_vector)

        # 3. Pass the (potentially modified) embeddings to the core layers.
        #    Embeddings are always materialized above, so input_ids is never
        #    forwarded to the base model.
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # kept: callers always receive hidden_states
            return_dict=True,
            **kwargs,
        )

        # 4. Compute logits from the final hidden state.
        logits = self.lm_head(outputs.last_hidden_state)

        # 5. Compute the loss if labels are provided.
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on another device when the model is sharded.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# This file only contains the model definition and constants.
# Training and inference logic are handled in separate files.