import torch
import torch.nn as nn
from transformers import Gemma3ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Union

MODEL_NAME = "google/gemma-3-1b-it"
EMOTION_DIMENSIONS = 8
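# Human-readable names for the eight emotion axes; an emotion_vector supplies
# one scalar per axis, EMOTION_DIMENSIONS values in total.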
EMOTION_DIMENSIONS_REFERENCE = [
    "SADNESS_JOY", "FEAR_COURAGE", "DISGUST_ACCEPTANCE", "ANGER_CALMNESS",
    "SURPRISE_EXPECTATION", "DISTRUST_TRUST", "BOREDOM_INTEREST", "INDIFFERENCE_EMPATHY",
]


class EmotionalLlamaModel(Gemma3ForCausalLM):
    """Gemma3 Causal Language Model with emotion modulation."""

    def __init__(self, config):
        super().__init__(config)
        self.emotion_dim = EMOTION_DIMENSIONS

        # Two-layer MLP that projects the emotion vector into the model's hidden space.
        intermediate_size = config.hidden_size // 2
        self.emotion_proj_embed = nn.Sequential(
            nn.Linear(self.emotion_dim, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, config.hidden_size),
        )

        # Initialize the projection with Xavier-uniform weights and zero biases.
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

        self.emotion_proj_embed.apply(init_weights)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        emotion_vector: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        # Resolve token embeddings from input_ids unless embeddings were passed directly.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        elif input_ids is not None:
            batch_size, seq_len = input_ids.shape
            inputs_embeds = self.model.embed_tokens(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_len = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Project the emotion vector and add it to the token embeddings.
        if emotion_vector is not None and inputs_embeds is not None:
            if emotion_vector.shape[0] != batch_size:
                raise ValueError("Batch size mismatch between emotion_vector and input.")

            current_seq_len = inputs_embeds.shape[1]
            if emotion_vector.dim() == 2:
                # A single (batch, emotion_dim) vector is broadcast over every token.
                emotion_vector = emotion_vector.unsqueeze(1).expand(-1, current_seq_len, -1)
            elif emotion_vector.shape[1] != current_seq_len:
                # A per-token emotion sequence is truncated to the current length.
                emotion_vector = emotion_vector[:, :current_seq_len, :]

            emotion_offset = self.emotion_proj_embed(emotion_vector)

            # The factor of 3 amplifies the emotion offset relative to the token embeddings.
            inputs_embeds = inputs_embeds + emotion_offset * 3

        outputs = self.model(
            input_ids=input_ids if inputs_embeds is None else None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        hidden_states = outputs.hidden_states[-1]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that each position predicts the next token.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
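

# A minimal usage sketch, not part of the model definition: it assumes the
# "google/gemma-3-1b-it" checkpoint is reachable via from_pretrained and that
# emotion values are small per-axis scalars (the exact range is an assumption).
# Note that from_pretrained leaves emotion_proj_embed randomly initialized,
# since the base checkpoint has no weights for it.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = EmotionalLlamaModel.from_pretrained(MODEL_NAME)
    model.eval()

    inputs = tokenizer("How was your day?", return_tensors="pt")

    # One value per axis in EMOTION_DIMENSIONS_REFERENCE, here leaning toward joy and interest.
    emotion = torch.tensor([[0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0]])

    with torch.no_grad():
        out = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            emotion_vector=emotion,
        )
    print(out.logits.shape)  # (batch, seq_len, vocab_size)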