seemggoel committed on
Commit 74f9db3 · verified · 1 Parent(s): 41a91e8

Upload 3 files
Files changed (3)
  1. config.py +65 -0
  2. features.py +163 -0
  3. modal.py +300 -0
config.py ADDED
@@ -0,0 +1,65 @@
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    EarlyStoppingCallback
)

# def get_training_args(output_dir):
#     return TrainingArguments(
#         output_dir=output_dir,
#         num_train_epochs=5,  # Increased from 3
#         per_device_train_batch_size=4,
#         per_device_eval_batch_size=4,
#         gradient_accumulation_steps=8,  # Increased from 4
#         evaluation_strategy="steps",
#         eval_steps=50,  # More frequent evaluation
#         save_strategy="steps",
#         save_steps=50,
#         logging_dir=f"{output_dir}/logs",
#         logging_strategy="steps",
#         logging_steps=10,
#         learning_rate=5e-5,  # Lower learning rate for continued training
#         weight_decay=0.02,  # Increased from 0.01
#         warmup_ratio=0.1,  # Increased from previous value
#         lr_scheduler_type="cosine_with_restarts",  # Changed from cosine
#         load_best_model_at_end=True,
#         metric_for_best_model="eval_loss",
#         greater_is_better=False,
#         fp16=True,
#         gradient_checkpointing=True,
#         gradient_checkpointing_kwargs={"use_reentrant": False},
#         report_to="tensorboard",
#         remove_unused_columns=False,
#         optim="adamw_torch_fused",  # Using fused optimizer
#         max_grad_norm=0.5,  # Added gradient clipping
#     )


def get_training_args(output_dir):
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,  # Reduced epochs for continued training
        per_device_train_batch_size=2,  # Reduced batch size
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=16,  # Increased for stability
        evaluation_strategy="steps",
        eval_steps=25,  # More frequent evaluation
        save_strategy="steps",
        save_steps=25,
        learning_rate=1e-5,  # Lower learning rate for fine-tuning
        weight_decay=0.03,  # Increased for better regularization
        warmup_ratio=0.15,  # Increased warmup
        lr_scheduler_type="cosine_with_restarts",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        report_to="tensorboard",
        remove_unused_columns=False,
        optim="adamw_torch_fused",
        max_grad_norm=0.3,  # Reduced for stability
    )
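For context, a minimal sketch (not part of the committed files) of how get_training_args could be wired into a Trainer together with the EarlyStoppingCallback that config.py imports. The model, tokenizer, datasets, data collator, and output path below are placeholders.

from transformers import Trainer, EarlyStoppingCallback

def build_trainer(model, tokenizer, train_ds, eval_ds, data_collator,
                  output_dir="./phi3-multimodal"):  # hypothetical output path
    args = get_training_args(output_dir)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        # Stop if eval_loss has not improved for 3 consecutive evaluations
        # (every 25 steps with the arguments above).
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )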
features.py ADDED
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
"""prepare_dataset_tokenise.py - Optimized for Multimodal Fine-tuning"""

import os
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, WhisperProcessor, WhisperForConditionalGeneration, PreTrainedModel, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType
from datasets import Dataset, DatasetDict
from tqdm import tqdm
import json
import librosa
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import gc
from transformers import EarlyStoppingCallback
from torch.utils.checkpoint import checkpoint_sequential

# Initialize Whisper components for audio transcription
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

# Load embeddings with error handling
def load_embeddings(file_path):
    try:
        data = np.load(file_path)
        if 'image_ids' in data and 'embeddings' in data:
            return {'ids': data['image_ids'], 'embeddings': data['embeddings']}
        else:
            raise ValueError(f"Unexpected structure in {file_path}.")
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None

# Process audio files
def transcribe_speech(audiopath):
    try:
        speech, rate = librosa.load(audiopath, sr=16000)
        audio_input = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)
        with torch.no_grad():
            generated_ids = whisper_model.generate(audio_input["input_features"])
        return whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error transcribing audio: {e}")
        return None


# Previous collator implementation, kept for reference:
# def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
#     batch = {"input_ids": self.tokenizer.pad({"input_ids": [f["input_ids"] for f in features]}, padding=True, return_tensors="pt")["input_ids"]}
#     batch["attention_mask"] = torch.ones_like(batch["input_ids"])
#     batch["labels"] = batch["input_ids"].clone()
#     if "image_embeddings" in features[0]:
#         batch["image_embeddings"] = torch.stack([f["image_embeddings"] for f in features])
#     if "audio_embeddings" in features[0]:
#         batch["audio_embeddings"] = torch.stack([f["audio_embeddings"] for f in features])
#     return batch
#
# Updated on 30th November to handle mismatched shapes
# ("boolean index did not match indexed array along dimension 1;
#   dimension is 591 but corresponding boolean dimension is 590").
@dataclass
class MultimodalDataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Extract input_ids, attention_mask, and labels
        input_ids = [f["input_ids"] for f in features]
        attention_mask = [f["attention_mask"] for f in features]
        labels = [f["labels"] for f in features]

        # Convert tensors to lists if they are tensors
        input_ids = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids]
        attention_mask = [mask.tolist() if isinstance(mask, torch.Tensor) else mask for mask in attention_mask]
        labels = [lab.tolist() if isinstance(lab, torch.Tensor) else lab for lab in labels]

        # Pad sequences to the maximum length in the batch
        max_length = max(len(ids) for ids in input_ids)
        padded_input_ids = [ids + [self.tokenizer.pad_token_id] * (max_length - len(ids)) for ids in input_ids]
        padded_attention_mask = [mask + [0] * (max_length - len(mask)) for mask in attention_mask]
        padded_labels = [lab + [-100] * (max_length - len(lab)) for lab in labels]

        # Create a batch dictionary
        batch = {
            "input_ids": torch.tensor(padded_input_ids),
            "attention_mask": torch.tensor(padded_attention_mask),
            "labels": torch.tensor(padded_labels)
        }

        # Handle image and audio embeddings if present
        if "image_embeddings" in features[0]:
            batch["image_embeddings"] = torch.stack([f["image_embeddings"] for f in features])
        if "audio_embeddings" in features[0]:
            batch["audio_embeddings"] = torch.stack([f["audio_embeddings"] for f in features])

        return batch

# Dataset preparation with better error handling and modularization
def prepare_dataset(image_embeddings_path, dataset_path, cache_dir=None):
    image_embeddings = load_embeddings(image_embeddings_path)
    with open(dataset_path, 'r') as f:
        data = json.load(f)
    processed_data = [
        {
            "conversation": item["conversations"],
            "image_embedding": (
                image_embeddings['embeddings'][np.where(image_embeddings['ids'] == item['image'])[0][0]]
                if image_embeddings and "image" in item else None
            ),
            "audio_path": item.get("audio"),
        }
        for item in data
    ]
    dataset = Dataset.from_dict({
        "conversation": [item["conversation"] for item in processed_data],
        "image_embedding": [item.get("image_embedding") for item in processed_data],
        "audio_path": [item.get("audio_path") for item in processed_data],
    })
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # tokenizer.chat_template = """
    # {% for message in messages %}
    # {% if message.role == 'system' %}<|system|>{{message.content}}<|endoftext|>{% elif message.role == 'user' %}<|user|>{{message.content}}<|endoftext|>{% elif message.role == 'assistant' %}<|assistant|>{{message.content}}<|endoftext|>{% endif %}{% endfor %}
    # """
    tokenizer.chat_template = """
    {% for message in messages %}
    {% if message.role == 'system' %}<|system|>{{message.content}}<|endofsystem|>{% elif message.role == 'user' %}<|user|>{{message.content}}<|endoftext|>{% elif message.role == 'assistant' %}<|assistant|>{{message.content}}<|endoftext|>{% endif %}{% endfor %}
    """
    prepared_dataset = dataset.map(
        lambda examples: prepare_example(examples, tokenizer),
        batched=True,
        remove_columns=dataset.column_names,
        batch_size=1,
    ).with_format("torch")
    # dataset_dict = DatasetDict({"train": prepared_dataset.train_test_split(test_size=0.1)["train"], "test": prepared_dataset.train_test_split(test_size=0.1)["test"]})
    # Split into train and a combined validation/test pool, then split that pool
    # once so validation and test cannot overlap (two independent random splits could).
    dataset_dict = prepared_dataset.train_test_split(test_size=0.2)
    holdout = dataset_dict["test"].train_test_split(test_size=0.5)
    dataset_dict["validation"] = holdout["train"]
    dataset_dict["test"] = holdout["test"]

    # Persist the prepared splits
    drive_path = "/content/drive/MyDrive/Cap_dataset"  # Replace with your desired path in Google Drive
    dataset_dict.save_to_disk(drive_path)

    # if cache_dir:
    #     os.makedirs(cache_dir, exist_ok=True)
    #     dataset_dict.save_to_disk(cache_dir)
    return dataset_dict, tokenizer

# Example preparation for dataset rows
def prepare_example(examples, tokenizer):
    image_embeddings, audio_embeddings, tokenized_inputs = [], [], []
    for idx, conv in enumerate(examples["conversation"]):
        image_embedding = torch.tensor(examples["image_embedding"][idx]) if examples["image_embedding"][idx] is not None else None
        transcription = transcribe_speech(examples["audio_path"][idx]) if "audio_path" in examples and examples["audio_path"][idx] else None
        for i in range(0, len(conv), 2):
            if i + 1 < len(conv):
                human_msg = conv[i]["value"].replace("<image>", "").replace("<audio>", "").strip()
                if transcription:
                    human_msg += f"\nAudio Transcription: {transcription}"
                gpt_msg = conv[i + 1]["value"]
                tokenized_input = tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": f"{human_msg}"},
                        {"role": "assistant", "content": gpt_msg},
                    ],
                    return_tensors="pt",
                    padding=True,
                )
                tokenized_inputs.append(tokenized_input.squeeze(0))
                if image_embedding is not None:
                    image_embeddings.append(image_embedding)
    max_length = max(seq.shape[0] for seq in tokenized_inputs)
    # Pad with the pad token id so the attention mask below marks padded positions as 0
    padded_inputs = [torch.nn.functional.pad(seq, (0, max_length - seq.shape[0]), value=tokenizer.pad_token_id) for seq in tokenized_inputs]
    input_ids = torch.stack(padded_inputs)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # Ignore padding in the loss
    result = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
    if image_embeddings:
        result["image_embeddings"] = torch.stack(image_embeddings)
    if audio_embeddings:
        result["audio_embeddings"] = torch.stack(audio_embeddings)
    return result
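As a quick sanity check, a small self-contained sketch (not part of the committed files) of how MultimodalDataCollator pads a ragged batch. The stand-in tokenizer only provides the one attribute the collator reads, and the pad id 32000 is an arbitrary placeholder, not a value fixed by this repository.

import torch
from types import SimpleNamespace

dummy_tokenizer = SimpleNamespace(pad_token_id=32000)  # placeholder pad id
collator = MultimodalDataCollator(tokenizer=dummy_tokenizer)

features = [
    {"input_ids": torch.tensor([5, 6, 7]), "attention_mask": torch.tensor([1, 1, 1]), "labels": torch.tensor([5, 6, 7])},
    {"input_ids": torch.tensor([8, 9]), "attention_mask": torch.tensor([1, 1]), "labels": torch.tensor([8, 9])},
]
batch = collator(features)
print(batch["input_ids"].shape)    # torch.Size([2, 3])
print(batch["attention_mask"][1])  # tensor([1, 1, 0]) - second example padded by one position
print(batch["labels"][1])          # tensor([8, 9, -100]) - padding ignored by the loss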
modal.py ADDED
@@ -0,0 +1,300 @@
import os
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, WhisperProcessor, WhisperForConditionalGeneration, PreTrainedModel, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import json
import librosa
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import gc
import torch.nn.functional as F

# # Define multimodal projector class (original version, kept for reference)
# class ProjectionBlock(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
#         self.pre_norm = nn.LayerNorm(input_dim)
#         self.proj = nn.Sequential(nn.Linear(input_dim, output_dim), nn.GELU(), nn.Linear(output_dim, output_dim))
#
#     def forward(self, x):
#         return self.proj(self.pre_norm(x))


class CrossAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context):
        # Cross-attention: queries come from x, keys/values from the other modality
        attended, _ = self.attention(
            query=self.norm1(x),
            key=self.norm1(context),
            value=self.norm1(context)
        )
        x = x + attended

        # FFN
        x = x + self.ffn(self.norm2(x))
        return x

## Updated on 23rd November
class ProjectionBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim, output_dim * 2),  # Increase intermediate dimension
            nn.GELU(),
            nn.Linear(output_dim * 2, output_dim)  # Project to final dimension
        )

    def forward(self, x):
        # Add shape validation
        if len(x.shape) == 2:  # If input is [batch_size, features]
            return self.proj(self.pre_norm(x))
        elif len(x.shape) == 3:  # If input is [batch_size, seq_len, features]
            return self.proj(self.pre_norm(x.mean(dim=1)))  # Pool sequence dimension
        else:
            raise ValueError(f"Unexpected input shape: {x.shape}")

## Updated on 23rd November
# class EnhancedMultimodalProjector(nn.Module):
#     def __init__(self, image_input_dim, audio_input_dim, output_dim, num_heads=8):
#         super().__init__()
#
#         # Adjust projectors to match Phi-3's hidden size (1024)
#         self.image_proj = ProjectionBlock(image_input_dim, output_dim)
#         self.audio_proj = ProjectionBlock(audio_input_dim, output_dim)
#
#         # Cross-attention blocks
#         self.image_audio_cross_attn = CrossAttentionBlock(output_dim, num_heads)
#         self.audio_image_cross_attn = CrossAttentionBlock(output_dim, num_heads)
#
#         # Final fusion layer
#         self.fusion_layer = nn.Sequential(
#             nn.LayerNorm(output_dim * 2),
#             nn.Linear(output_dim * 2, output_dim),
#             nn.GELU(),
#             nn.Linear(output_dim, output_dim)
#         )
class EnhancedMultimodalProjector(nn.Module):
    def __init__(self, image_input_dim, audio_input_dim=1024, output_dim=1024, num_heads=8):
        super().__init__()
        self.image_proj = ProjectionBlock(image_input_dim, output_dim)
        self.audio_proj = ProjectionBlock(audio_input_dim, output_dim)
        self.image_audio_cross_attn = CrossAttentionBlock(output_dim, num_heads)
        self.audio_image_cross_attn = CrossAttentionBlock(output_dim, num_heads)
        self.fusion_layer = nn.Sequential(
            nn.LayerNorm(output_dim * 2),
            nn.Linear(output_dim * 2, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim)
        )

    def forward(self, image_embedding=None, audio_embedding=None):
        # Add shape validation and adjustment
        if image_embedding is not None and image_embedding.dim() < 2:
            raise ValueError("Expected `image_embedding` to have at least 2 dimensions.")
        if audio_embedding is not None and audio_embedding.dim() < 2:
            raise ValueError("Expected `audio_embedding` to have at least 2 dimensions.")
        if image_embedding is not None and len(image_embedding.shape) == 2:
            image_embedding = image_embedding.unsqueeze(1)  # Add sequence dimension
        if audio_embedding is not None and len(audio_embedding.shape) == 2:
            audio_embedding = audio_embedding.unsqueeze(1)  # Add sequence dimension

        # Initial projections
        projected_image = self.image_proj(image_embedding) if image_embedding is not None else None
        projected_audio = self.audio_proj(audio_embedding) if audio_embedding is not None else None

        if projected_image is not None and projected_audio is not None:
            # Ensure correct shapes for cross-attention
            attended_image = self.image_audio_cross_attn(projected_image, projected_audio)
            attended_audio = self.audio_image_cross_attn(projected_audio, projected_image)

            # Combine the attended features
            fused_features = torch.cat([attended_image, attended_audio], dim=-1)
            final_output = self.fusion_layer(fused_features)

            return final_output, final_output

        elif projected_image is not None:
            return projected_image, None
        elif projected_audio is not None:
            return None, projected_audio
        else:
            return None, None

# Update the Phi3WithProjector to use the enhanced projector
# (note: this definition is superseded by the fuller Phi3WithProjector below)
class Phi3WithProjector(PreTrainedModel):
    def __init__(self, config, phi3_model, projector):
        super().__init__(config)
        self.phi3_model = phi3_model
        self.projector = projector
        self.supports_gradient_checkpointing = True

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
                image_embeddings=None, audio_embeddings=None, labels=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.phi3_model.get_input_embeddings()(input_ids)

        # Get fused embeddings from enhanced projector
        projected_features, _ = self.projector(image_embeddings, audio_embeddings)

        # Concatenate embeddings if we have projected features
        if projected_features is not None:
            combined_embeddings = torch.cat([inputs_embeds, projected_features.unsqueeze(1)], dim=1)
            # Extend attention mask
            extended_attention_mask = torch.cat([
                attention_mask,
                torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)
            ], dim=1)
        else:
            combined_embeddings = inputs_embeds
            extended_attention_mask = attention_mask

        # Adjust labels if needed
        if labels is not None and projected_features is not None:
            labels = torch.cat([
                labels,
                torch.full((labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device)
            ], dim=1)

        return self.phi3_model(
            inputs_embeds=combined_embeddings,
            attention_mask=extended_attention_mask,
            labels=labels,
            **kwargs
        )


class MultimodalProjector(nn.Module):
    def __init__(self, image_input_dim, audio_input_dim, output_dim):
        super().__init__()
        self.image_proj = ProjectionBlock(image_input_dim, output_dim)
        self.audio_proj = ProjectionBlock(audio_input_dim, output_dim)

    def forward(self, image_embedding=None, audio_embedding=None):
        projected_image = self.image_proj(image_embedding) if image_embedding is not None else None
        projected_audio = self.audio_proj(audio_embedding) if audio_embedding is not None else None
        return projected_image, projected_audio


class Phi3WithProjector(PreTrainedModel):
    def __init__(self, config, phi3_model, projector):
        super().__init__(config)
        self.phi3_model = phi3_model
        self.projector = projector
        self.supports_gradient_checkpointing = True

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, image_embeddings=None, audio_embeddings=None, labels=None, **kwargs):
        # Use get_input_embeddings() to retrieve the embeddings layer
        if inputs_embeds is None:
            inputs_embeds = self.phi3_model.get_input_embeddings()(input_ids)

        # Project both image and audio embeddings to the appropriate dimension
        projected_image, projected_audio = self.projector(image_embeddings, audio_embeddings)

        # Concatenate the embeddings
        embeddings_to_concat = [inputs_embeds]
        if projected_image is not None:
            embeddings_to_concat.append(projected_image.unsqueeze(1))
        if projected_audio is not None:
            embeddings_to_concat.append(projected_audio.unsqueeze(1))

        combined_embeddings = torch.cat(embeddings_to_concat, dim=1)

        # Extend the attention mask to cover the appended embedding positions
        extended_attention_mask = attention_mask.clone()  # Start with a copy

        # Extend for image and audio, if present
        if projected_image is not None:
            extended_attention_mask = torch.cat([extended_attention_mask, torch.ones_like(extended_attention_mask[:, :1])], dim=1)
        if projected_audio is not None:
            extended_attention_mask = torch.cat([extended_attention_mask, torch.ones_like(extended_attention_mask[:, :1])], dim=1)

        # Adjust labels to match the extended input sequence length
        if labels is not None:
            # Pad labels with -100 to ignore the added tokens in the loss calculation
            num_added_tokens = sum(1 for emb in [projected_image, projected_audio] if emb is not None)
            labels = torch.cat([labels, torch.full((labels.shape[0], num_added_tokens), -100, dtype=labels.dtype, device=labels.device)], dim=1)

        outputs = self.phi3_model(
            inputs_embeds=combined_embeddings,
            attention_mask=extended_attention_mask,
            labels=labels,
            **kwargs
        )

        # Add auxiliary losses for multimodal alignment
        if image_embeddings is not None or audio_embeddings is not None:
            loss = outputs.loss

            # Add contrastive loss for multimodal alignment
            if image_embeddings is not None and audio_embeddings is not None:
                img_proj, audio_proj = self.projector(image_embeddings, audio_embeddings)
                contrastive_loss = self.compute_contrastive_loss(img_proj, audio_proj)
                loss = loss + 0.1 * contrastive_loss  # Weight the auxiliary loss

            outputs.loss = loss

        return outputs

    def get_input_embeddings(self):
        """Returns the model's input embeddings."""
        return self.phi3_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Sets the model's input embeddings."""
        self.phi3_model.set_input_embeddings(value)

    # Instead, use the built-in gradient checkpointing
    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for the model."""
        if hasattr(self.phi3_model, "gradient_checkpointing_enable"):
            self.phi3_model.gradient_checkpointing_enable()
        else:
            self.phi3_model.config.use_cache = False
            self.phi3_model.train()  # Ensure model is in training mode

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing for the model."""
        if hasattr(self.phi3_model, "gradient_checkpointing_disable"):
            self.phi3_model.gradient_checkpointing_disable()
        else:
            self.phi3_model.config.use_cache = True

    def compute_contrastive_loss(self, img_features, audio_features):
        # Normalize features
        img_features = F.normalize(img_features, dim=-1)
        audio_features = F.normalize(audio_features, dim=-1)

        # Compute similarity matrix
        similarity = torch.matmul(img_features, audio_features.transpose(0, 1))

        # Temperature-scaled cross entropy loss
        temperature = 0.07
        labels = torch.arange(similarity.size(0)).to(similarity.device)
        loss = F.cross_entropy(similarity / temperature, labels)

        return loss
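To make the intended wiring concrete, here is a hedged assembly sketch (not part of the committed files): it loads the Phi-3 backbone, builds the enhanced projector at the backbone's hidden size, wraps both in Phi3WithProjector, and runs a dummy forward pass. The image and audio embedding widths (512 and 768) are placeholder assumptions, quantization/LoRA are omitted, and running it requires downloading the model.

import torch
from transformers import AutoModelForCausalLM

phi3 = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    trust_remote_code=True,
)

# Project both modalities to the language model's hidden size so the
# concatenation in Phi3WithProjector.forward lines up.
projector = EnhancedMultimodalProjector(
    image_input_dim=512,   # assumed image-encoder width (placeholder)
    audio_input_dim=768,   # assumed audio-encoder width (placeholder)
    output_dim=phi3.config.hidden_size,
)

model = Phi3WithProjector(phi3.config, phi3, projector)

# Dummy forward pass: one image and one audio embedding per example.
input_ids = torch.randint(0, phi3.config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    image_embeddings=torch.randn(2, 512),
    audio_embeddings=torch.randn(2, 768),
    labels=input_ids.clone(),
)
print(outputs.loss)  # language-modeling loss + 0.1 * contrastive alignment loss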