ariG23498 (HF Staff) committed on
Commit f2c2a4e · 1 Parent(s): 1d47577
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ .venv
2
+ __pycache__
app.py ADDED
@@ -0,0 +1,159 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from PIL import Image
5
+
6
+ # Set random seeds for reproducibility
7
+ torch.manual_seed(0)
8
+ if torch.cuda.is_available():
9
+ torch.cuda.manual_seed_all(0)
10
+
11
+ from models.vision_language_model import VisionLanguageModel
12
+ from data.processors import get_tokenizer, get_image_processor
13
+
14
+
15
+ @spaces.GPU
16
+ def generate_outputs(image, query):
17
+ # Determine device
18
+ if torch.cuda.is_available():
19
+ device = torch.device("cuda")
20
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
21
+ device = torch.device("mps")
22
+ else:
23
+ device = torch.device("cpu")
24
+
25
+ # Load model
26
+ hf_model = "lusxvr/nanoVLM-222M"
27
+ try:
28
+ model = VisionLanguageModel.from_pretrained(hf_model).to(device)
29
+ model.eval()
30
+ except Exception as e:
31
+ return f"Error loading model: {str(e)}", None, None, None, None
32
+
33
+ # Load tokenizer and image processor
34
+ try:
35
+ tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
36
+ image_processor = get_image_processor(model.cfg.vit_img_size)
37
+ except Exception as e:
38
+ return f"Error loading tokenizer or image processor: {str(e)}", None, None, None, None
39
+
40
+ # Prepare text input
41
+ template = f"Question: {query} Answer:"
42
+ encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
43
+ tokens = encoded["input_ids"].to(device)
44
+
45
+ # Process image
46
+ try:
47
+ img = image.convert("RGB")
48
+ img_t = image_processor(img).unsqueeze(0).to(device)
49
+ except Exception as e:
50
+ return f"Error processing image: {str(e)}", None, None, None, None
51
+
52
+ # Generate four outputs
53
+ outputs = []
54
+ max_new_tokens = 50 # Fixed value from provided script
55
+ try:
56
+ for _ in range(4):
57
+ gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
58
+ out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
59
+ outputs.append(out)
60
+ except Exception as e:
61
+ return f"Error during generation: {str(e)}", None, None, None, None
62
+
63
+ return None, outputs[0], outputs[1], outputs[2], outputs[3]
64
+
65
+
66
+ def main():
67
+ # Define minimal CSS for subtle aesthetic enhancements
68
+ css = """
69
+ .gradio-container {
70
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
71
+ padding: 20px;
72
+ }
73
+ h1 {
74
+ color: #333;
75
+ text-align: center;
76
+ margin-bottom: 20px;
77
+ }
78
+ .description {
79
+ margin-bottom: 20px;
80
+ line-height: 1.6;
81
+ }
82
+ .gr-button {
83
+ padding: 10px 20px;
84
+ }
85
+ """
86
+
87
+ # Define Gradio interface
88
+ with gr.Blocks(css=css, title="nanoVLM Image-to-Text Generator") as app:
89
+ gr.Markdown(
90
+ "# nanoVLM Image-to-Text Generator"
91
+ )
92
+ gr.Markdown(
93
+ """
94
+ <div class="description">
95
+ This demo showcases <b>nanoVLM</b>, a lightweight vision-language model by Hugging Face.
96
+ Upload an image and provide a query to generate four text descriptions.
97
+ The model is based on the <a href="https://github.com/huggingface/nanoVLM/" target="_blank">nanoVLM repository</a>
98
+ and uses the pretrained model <a href="https://huggingface.co/lusxvr/nanoVLM-222M" target="_blank">lusxvr/nanoVLM-222M</a>.
99
+ nanoVLM is designed for efficient image-to-text generation, ideal for resource-constrained environments.
100
+ </div>
101
+ """
102
+ )
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ image_input = gr.Image(
107
+ type="pil",
108
+ label="Upload Image",
109
+ value="cat.jpg" # Set example image
110
+ )
111
+ query_input = gr.Textbox(
112
+ label="Query",
113
+ value="What is this?",
114
+ placeholder="Enter your query here",
115
+ lines=2
116
+ )
117
+ submit_button = gr.Button("Generate")
118
+
119
+ with gr.Column():
120
+ error_output = gr.Textbox(
121
+ label="Errors (if any)",
122
+ placeholder="No errors",
123
+ visible=True,
124
+ interactive=False
125
+ )
126
+ output1 = gr.Textbox(
127
+ label="Generation 1",
128
+ placeholder="Output 1 will appear here...",
129
+ lines=3
130
+ )
131
+ output2 = gr.Textbox(
132
+ label="Generation 2",
133
+ placeholder="Output 2 will appear here...",
134
+ lines=3
135
+ )
136
+ output3 = gr.Textbox(
137
+ label="Generation 3",
138
+ placeholder="Output 3 will appear here...",
139
+ lines=3
140
+ )
141
+ output4 = gr.Textbox(
142
+ label="Generation 4",
143
+ placeholder="Output 4 will appear here...",
144
+ lines=3
145
+ )
146
+
147
+ # Define action on submit
148
+ submit_button.click(
149
+ fn=generate_outputs,
150
+ inputs=[image_input, query_input],
151
+ outputs=[error_output, output1, output2, output3, output4]
152
+ )
153
+
154
+ # Launch the app
155
+ app.launch()
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
cat.jpg ADDED

Git LFS Details

  • SHA256: dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
data/__init__.py ADDED
File without changes
data/collators.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+
3
+ class VQACollator(object): # Visual Question Answering Collator
4
+ def __init__(self, tokenizer, max_length):
5
+ self.tokenizer = tokenizer
6
+ self.max_length = max_length
7
+
8
+ def __call__(self, batch):
9
+ images = [item["image"] for item in batch]
10
+ texts = [item["text_data"] for item in batch]
11
+ answers = [item["answer"] for item in batch]
12
+
13
+ # Stack images
14
+ images = torch.stack(images)
15
+
16
+ # Create inputs by concatenating the question and answer
17
+ input_sequences = []
18
+ for i in range(len(texts)):
19
+ input_sequences.append(f"{texts[i]}{answers[i]}")
20
+
21
+ encoded_full_sequences = self.tokenizer.batch_encode_plus(
22
+ input_sequences,
23
+ padding="max_length",
24
+ padding_side="left",
25
+ return_tensors="pt",
26
+ truncation=True,
27
+ max_length=self.max_length,
28
+ )
29
+
30
+ # Create labels where only answer tokens are predicted
31
+ input_ids = encoded_full_sequences["input_ids"]
32
+ attention_mask = encoded_full_sequences["attention_mask"]
33
+ labels = input_ids.clone()
34
+ labels[:, :-1] = input_ids[:, 1:].clone()
35
+ labels[:, -1] = -100 #self.tokenizer.pad_token_id
36
+
37
+ # The tokenizer has different behavior for padding and truncation:
38
+ # 1. If the full text (question + answer) is shorter than the max length, it gets padded on the left
39
+ # 2. If the full text is longer than the max length, it gets truncated on the right
40
+ # Therefore, I need to handle multiple cases; these are the different scenarios:
41
+ # If the full text is longer than the max length, we need to set the labels to -100 for the whole sample (we want to ignore the whole sample)
42
+ # If the full text is shorter than the max length, we need to set the labels to -100 only for the question part, and create causal language modeling labels for the answer part, taking into account the padding
43
+
44
+ # Determine if sequences were truncated
45
+ original_lengths = [len(self.tokenizer.encode(seq)) for seq in input_sequences]
46
+
47
+ for i in range(len(batch)):
48
+ # Get the length of the question for this sample
49
+ question_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))
50
+
51
+ # Case 1: If sequence was truncated (original is longer than max_length)
52
+ if original_lengths[i] > self.max_length:
53
+ # Set all labels to -100 to ignore this sample entirely
54
+ labels[i, :] = -100
55
+ #print(f"Sample {i} was truncated. Setting all labels to -100.")
56
+ continue
57
+
58
+ # Case 2: Sequence fits within max_length
59
+ # Use attention mask to find first non-padding token
60
+ # The first 1 in the attention mask marks the first non-padding token
61
+ first_token_pos = attention_mask[i].nonzero(as_tuple=True)[0][0].item()
62
+
63
+ # Set labels for padding and question part to -100 (don't predict these), subtracting 1 to account for the left shift
64
+ question_end = first_token_pos + question_length - 1
65
+ labels[i, :question_end] = -100
66
+ # labels[i, original_lengths[i]-1:] = -100 # If you are using right padding
67
+
68
+ return {
69
+ "image": images,
70
+ "input_ids": input_ids,
71
+ "attention_mask": attention_mask,
72
+ "labels": labels
73
+ }
74
+
75
+ class MMStarCollator(object): # https://huggingface.co/datasets/Lin-Chen/MMStar
76
+ def __init__(self, tokenizer):
77
+ self.tokenizer = tokenizer
78
+
79
+ def __call__(self, batch):
80
+ images = [item["image"] for item in batch]
81
+ questions = [item["text_data"] for item in batch]
82
+ answers = [item["answer"] for item in batch]
83
+
84
+ # Stack images
85
+ images = torch.stack(images)
86
+
87
+ encoded_question_sequences = self.tokenizer.batch_encode_plus(
88
+ questions,
89
+ padding=True,
90
+ padding_side="left",
91
+ return_tensors="pt"
92
+ )
93
+
94
+ encoded_answer_sequences = self.tokenizer.batch_encode_plus(
95
+ answers,
96
+ padding=True,
97
+ padding_side="left",
98
+ return_tensors="pt"
99
+ )
100
+
101
+ return {
102
+ "images": images,
103
+ "input_ids": encoded_question_sequences['input_ids'],
104
+ "attention_mask": encoded_question_sequences['attention_mask'],
105
+ "labels": encoded_answer_sequences['input_ids'],
106
+ }
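To make the label-masking logic in `VQACollator` above concrete, here is a minimal sketch with made-up token ids (no real tokenizer): a 3-token question and a 2-token answer, left-padded to a max length of 8. It mirrors the same steps as the collator: shift the labels one position to the left, then mask the padding and the question part with -100 so that only the answer tokens contribute to the loss.

```python
import torch

# Toy values for illustration only: pad id = 0, question = [11, 12, 13], answer = [21, 22]
input_ids      = torch.tensor([[0, 0, 0, 11, 12, 13, 21, 22]])  # left-padded question + answer
attention_mask = torch.tensor([[0, 0, 0,  1,  1,  1,  1,  1]])
question_length = 3

# Shift labels one position to the left (next-token prediction)
labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:].clone()
labels[:, -1] = -100

# Mask padding + question (minus 1 to account for the shift); only the answer labels remain
first_token_pos = attention_mask[0].nonzero(as_tuple=True)[0][0].item()  # -> 3
labels[0, :first_token_pos + question_length - 1] = -100

print(labels)  # tensor([[-100, -100, -100, -100, -100, 21, 22, -100]])
```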
data/datasets.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from PIL import Image
3
+ from torch.utils.data import Dataset
4
+
5
+ import models.config as cfg
6
+
7
+
8
+ class VQADataset(Dataset): # Visual Question Answering Dataset
9
+ def __init__(self, dataset, tokenizer, image_processor):
10
+ self.dataset = dataset
11
+ self.tokenizer = tokenizer
12
+ self.image_processor = image_processor
13
+
14
+ def __len__(self):
15
+ return len(self.dataset)
16
+
17
+ def __getitem__(self, idx):
18
+ item = self.dataset[idx]
19
+
20
+ # Handle image (it's a list)
21
+ image_data = item['images']
22
+ if isinstance(image_data, list) and len(image_data) > 0:
23
+ image = image_data[0]
24
+ else:
25
+ image = image_data
26
+
27
+ # Now process the image
28
+ if isinstance(image, Image.Image):
29
+ if image.mode != 'RGB':
30
+ image = image.convert('RGB')
31
+ processed_image = self.image_processor(image)
32
+ else:
33
+ print(f"Error processing image at index {idx}")
34
+ # Create empty tensor with right dimensions as fallback
35
+ processed_image = torch.zeros(
36
+ 3, cfg.VLMConfig.vit_img_size, cfg.VLMConfig.vit_img_size)
37
+
38
+ # Process text (also a list)
39
+ text_data = item['texts']
40
+ if isinstance(text_data, list) and len(text_data) > 0:
41
+ text = text_data[0]
42
+ else:
43
+ text = text_data
44
+
45
+ question = text['user']
46
+ # Add EOS token to the answer to train the model to predict it, enabling correct stopping during generation
47
+ answer = text['assistant'] + self.tokenizer.eos_token
48
+
49
+ formatted_text = f"Question: {question} Answer:"
50
+
51
+ return {
52
+ "image": processed_image,
53
+ "text_data": formatted_text,
54
+ "answer": answer
55
+ }
56
+
57
+
58
+ class MMStarDataset(Dataset): # https://huggingface.co/datasets/Lin-Chen/MMStar
59
+ def __init__(self, dataset, tokenizer, image_processor):
60
+ self.dataset = dataset
61
+ self.tokenizer = tokenizer
62
+ self.image_processor = image_processor
63
+
64
+ def __len__(self):
65
+ return len(self.dataset)
66
+
67
+ def __getitem__(self, idx):
68
+ item = self.dataset[idx]
69
+
70
+ image = item['image']
71
+
72
+ # Now process the image
73
+ if isinstance(image, Image.Image):
74
+ if image.mode != 'RGB':
75
+ image = image.convert('RGB')
76
+ processed_image = self.image_processor(image)
77
+ else:
78
+ print(f"Error processing image at index {idx}")
79
+ # Create empty tensor with right dimensions as fallback
80
+ processed_image = torch.zeros(3, cfg.VLMConfig.vit_img_size, cfg.VLMConfig.vit_img_size)
81
+
82
+ question = item['question']
83
+ answer = item['answer'] + self.tokenizer.eos_token # Add EOS token to the answer to train the model to predict it, enabling correct stopping during generation
84
+
85
+ formatted_text = f"Question: {question} \nAnswer only with the letter! \nAnswer:"
86
+
87
+ return {
88
+ "image": processed_image,
89
+ "text_data": formatted_text,
90
+ "answer": answer
91
+ }
92
+
data/processors.py ADDED
@@ -0,0 +1,17 @@
1
+ from transformers import AutoTokenizer
2
+ import torchvision.transforms as transforms
3
+
4
+ TOKENIZERS_CACHE = {}
5
+
6
+ def get_tokenizer(name):
7
+ if name not in TOKENIZERS_CACHE:
8
+ tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
9
+ tokenizer.pad_token = tokenizer.eos_token
10
+ TOKENIZERS_CACHE[name] = tokenizer
11
+ return TOKENIZERS_CACHE[name]
12
+
13
+ def get_image_processor(img_size):
14
+ return transforms.Compose([
15
+ transforms.Resize((img_size, img_size)),
16
+ transforms.ToTensor()
17
+ ])
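A quick usage sketch for the two helpers above, using the tokenizer name and image size that `models/config.py` in this commit defaults to (and the `cat.jpg` added here); it assumes the repository root is on the Python path.

```python
from PIL import Image
from data.processors import get_tokenizer, get_image_processor

tokenizer = get_tokenizer("HuggingFaceTB/cosmo2-tokenizer")  # cached across repeat calls
image_processor = get_image_processor(224)                   # Resize to 224x224 + ToTensor

tokens = tokenizer.batch_encode_plus(["Question: What is this? Answer:"], return_tensors="pt")["input_ids"]
pixels = image_processor(Image.open("cat.jpg").convert("RGB")).unsqueeze(0)  # shape (1, 3, 224, 224)
```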
models/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # Models
2
+
3
+ ## Vision Backbone (ViT)
4
+
5
+ This is a very lightweight Vision Transformer in native PyTorch. I took inspiration from the following sources:
6
+ - https://github.com/karpathy/nanoGPT (General Transformer Decoder)
7
+ - https://arxiv.org/abs/2010.11929 (ViT Paper)
8
+ - https://arxiv.org/abs/2303.15343 (SigLiP Paper)
9
+ - https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/modeling_siglip.py (HF SigLiP Implementation)
10
+
11
+ ## Language Model (Llama / SmolLM)
12
+
13
+ This is a decoder-only LM, following the Llama 2/3 architecture. It takes inspiration from the following sources:
14
+ - https://arxiv.org/pdf/2307.09288 (Original Llama Paper)
15
+ - https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py (HF Llama Implementation)
16
+
17
+ ## Modality Projection
18
+
19
+ This is a simple MLP (a single linear layer) for the Modality Projection from the Image Patch Encodings into the Language Embedding Space, combined with a simple Pixel Shuffle (https://arxiv.org/pdf/2504.05299).
20
+
21
+ ## Vision-Language-Model
22
+
23
+ This brings all the individual parts together and handles the concatenation of images and text. It is built as a simplified version of SmolVLM (https://arxiv.org/pdf/2504.05299).
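As a rough shape sanity check for the projection described above, assuming the default `VLMConfig` values in this commit (224x224 images, 16x16 patches, pixel-shuffle factor 2, ViT width 768, LM width 576):

```python
img_size, patch_size, shuffle = 224, 16, 2
vit_dim, lm_dim = 768, 576

num_patches = (img_size // patch_size) ** 2        # 196 patch embeddings of width 768
num_image_tokens = num_patches // (shuffle ** 2)   # 49 tokens after the pixel shuffle
proj_in_dim = vit_dim * shuffle ** 2               # 3072, projected down to lm_dim = 576

print(num_patches, num_image_tokens, proj_in_dim)  # 196 49 3072
```

Those 49 image tokens are also why `lm_max_length` in `models/config.py` is set to `128 - 49`.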
models/__init__.py ADDED
File without changes
models/config.py ADDED
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class VLMConfig:
6
+ vit_hidden_dim: int = 768
7
+ vit_inter_dim: int = 4 * vit_hidden_dim
8
+ vit_patch_size: int = 16
9
+ vit_img_size: int = 224
10
+ vit_n_heads: int = 12
11
+ vit_dropout: float = 0.0
12
+ vit_n_blocks: int = 12
13
+ vit_ln_eps: float = 1e-6
14
+ vit_cls_flag: bool = False
15
+ vit_model_type: str = 'google/siglip-base-patch16-224'
16
+
17
+ lm_hidden_dim: int = 576
18
+ lm_inter_dim: int = 1536
19
+ lm_rms_eps: float = 1e-5
20
+ lm_re_base: int = 100000
21
+ lm_max_position_embeddings: int = 8192
22
+ lm_vocab_size: int = 49152
23
+ lm_n_heads: int = 9
24
+ lm_n_kv_heads: int = 3
25
+ lm_dropout: float = 0.0
26
+ lm_n_blocks: int = 30
27
+ lm_attn_scaling: float = 1.0
28
+ lm_max_length: int = 128 - 49 # Deduct the image token length to achieve a 'nice number'
29
+ lm_use_tokens: bool = False # Decide if the LM expects tokens or embeddings as input (if using as a backbone for the VLM, set to False)
30
+ lm_tie_weights: bool = True # Decide if you want to tie the LM Head weight to the token embedding weights
31
+ lm_model_type: str = 'HuggingFaceTB/SmolLM2-135M'
32
+ lm_tokenizer: str = 'HuggingFaceTB/cosmo2-tokenizer'
33
+ lm_eos_token_id: int = 0
34
+
35
+ mp_pixel_shuffle_factor: int = 2
36
+
37
+ vlm_load_backbone_weights: bool = True
38
+ vlm_checkpoint_path: str = 'checkpoints/nanoVLM-222M'
39
+
40
+
41
+ @dataclass
42
+ class TrainConfig:
43
+ lr_mp: float = 2e-3
44
+ lr_backbones: float = 1e-4
45
+ data_cutoff_idx: int = None
46
+ val_ratio: float = 0.01
47
+ batch_size: int = 256
48
+ mmstar_batch_size: int = 32
49
+ eval_in_epochs: bool = True
50
+ epochs: int = 5
51
+ compile: bool = True
52
+ resume_from_vlm_checkpoint: bool = False # Indicate whether training should resume from a checkpoint of the whole VLM or start from scratch
53
+ train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
54
+ train_dataset_name: tuple[str, ...] = ("ai2d", "aokvqa", "chart2text", "chartqa", "clevr", "cocoqa", "datikz", "diagram_image_to_text", "docvqa", "dvqa", "figureqa", "finqa", "geomverse", "hateful_memes", "hitab", "iam", "iconqa", "infographic_vqa", "intergps", "localized_narratives", "mapqa", "multihiertt", "ocrvqa", "plotqa", "raven", "rendered_text", "robut_sqa", "robut_wikisql", "robut_wtq", "scienceqa", "screen2words", "st_vqa", "tabmwp", "tallyqa", "tat_qa", "textcaps", "textvqa", "tqa", "vistext", "visual7w", "visualmrc", "vqarad", "vqav2", "vsr", "websight") # "clevr_math", "okvqa", "spot_the_diff", "nlvr2", "mimic_cgd",
55
+ test_dataset_path: str = "Lin-Chen/MMStar"
56
+ wandb_entity: str = "HuggingFace" # Indicate the entity to log to in wandb
57
+ log_wandb: bool = True
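Both configs are plain dataclasses, so overriding a field for an experiment is just keyword arguments or `dataclasses.replace`; a small sketch (the overridden values here are only examples):

```python
from dataclasses import replace
from models.config import VLMConfig, TrainConfig

cfg = VLMConfig()                                   # defaults from above
small_run = replace(TrainConfig(), batch_size=32, compile=False)

print(cfg.lm_max_length, small_run.batch_size)      # 79 32
```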
models/language_model.py ADDED
@@ -0,0 +1,399 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L69
7
+ class RMSNorm(nn.Module):
8
+ def __init__(self, cfg):
9
+ super().__init__()
10
+ self.weight = nn.Parameter(torch.ones(cfg.lm_hidden_dim))
11
+ self.eps = cfg.lm_rms_eps
12
+
13
+ def forward(self, x):
14
+ irms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) # inverse of RMS
15
+ x = x * irms * self.weight
16
+
17
+ return x
18
+
19
+ # There are multiple derivatives of rotary embeddings by now; this is a basic one with linear scaling to the context length
20
+ # e.g. https://github.com/huggingface/smollm/blob/main/vision/m4/models/vllama3/modeling_vllama3.py#L190
21
+ class RotaryEmbedding(nn.Module):
22
+ def __init__(self, cfg):
23
+ super().__init__()
24
+ assert cfg.lm_hidden_dim % cfg.lm_n_heads == 0, "Hidden dimension must be divisible by number of heads"
25
+
26
+ self.dim = cfg.lm_hidden_dim // cfg.lm_n_heads # dim of each head
27
+ self.base = cfg.lm_re_base
28
+ self.max_seq_len = cfg.lm_max_position_embeddings
29
+ # Standard RoPE implementation - create frequencies for each dimension
30
+ # freq_i = 1 / (base^(2i/dim)) where i is the dimension index
31
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
32
+ self.register_buffer("inv_freq", inv_freq)
33
+ self.original_max_seq_len = cfg.lm_max_position_embeddings
34
+ self.attention_scaling = cfg.lm_attn_scaling
35
+
36
+ @torch.no_grad()
37
+ def forward(self, position_ids):
38
+ batch_size, seq_len = position_ids.shape
39
+ # Dynamic scaling for longer sequences
40
+ max_seq = position_ids.max() + 1
41
+ if max_seq > self.original_max_seq_len:
42
+ scale = max_seq / self.original_max_seq_len
43
+ inv_freq = self.inv_freq / scale
44
+ else:
45
+ inv_freq = self.inv_freq
46
+
47
+ # Compute theta = position * frequency
48
+ # Flatten position_ids for batch processing
49
+ flat_position_ids = position_ids.reshape(-1).float()
50
+
51
+ # Element-wise outer product: [seq_len] x [dim/2] => [seq_len, dim/2]
52
+ freqs = flat_position_ids.unsqueeze(-1) * inv_freq.unsqueeze(0)
53
+
54
+ # Reshape to include batch dimension
55
+ freqs = freqs.reshape(batch_size, seq_len, -1)
56
+
57
+ # Now create interleaved pattern
58
+ emb = torch.cat([freqs, freqs], dim=-1)
59
+
60
+ # Compute cos and sin
61
+ cos = torch.cos(emb) * self.attention_scaling
62
+ sin = torch.sin(emb) * self.attention_scaling
63
+
64
+ return cos, sin
65
+
66
+ # Rotates half the hidden dims of the input by swapping and negating dimensions.
67
+ def rotate_half(x):
68
+ x1, x2 = x.chunk(2, dim=-1)
69
+ return torch.cat((-x2, x1), dim=-1)
70
+
71
+ # Apply rotary position embeddings to queries and keys.
72
+ def apply_rotary_pos_embd(q, k, cos, sin, unsqueeze_dim=1):
73
+ # We need to make sure cos and sin can be properly broadcast
74
+ # to the shape of q and k by adding the heads dimension
75
+ cos = cos.unsqueeze(unsqueeze_dim) # [batch_size, 1, seq_len, head_dim]
76
+ sin = sin.unsqueeze(unsqueeze_dim) # [batch_size, 1, seq_len, head_dim]
77
+
78
+ # Apply complex multiplication:
79
+ # (q * cos) + (rotate_half(q) * sin)
80
+ q_embed = (q * cos) + (rotate_half(q) * sin)
81
+ k_embed = (k * cos) + (rotate_half(k) * sin)
82
+
83
+ return q_embed, k_embed
84
+
85
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L214
86
+ # https://github.com/huggingface/smollm/blob/main/vision/m4/models/vllama3/modeling_vllama3.py#L382
87
+ class LanguageModelGroupedQueryAttention(nn.Module):
88
+ def __init__(self, cfg):
89
+ super().__init__()
90
+
91
+ self.n_heads = cfg.lm_n_heads
92
+ self.n_kv_heads = cfg.lm_n_kv_heads
93
+ self.embd_dim = cfg.lm_hidden_dim
94
+ self.dropout = cfg.lm_dropout
95
+
96
+ assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
97
+ assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by num_heads"
98
+
99
+ self.n_kv_groups = self.n_heads // self.n_kv_heads
100
+ self.head_dim = self.embd_dim // self.n_heads
101
+
102
+ self.q_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)
103
+ self.k_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
104
+ self.v_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
105
+ self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)
106
+
107
+ self.attn_dropout = nn.Dropout(self.dropout)
108
+ self.resid_dropout = nn.Dropout(self.dropout)
109
+
110
+ # Use scaled dot product attention if available
111
+ self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
112
+ if not self.sdpa:
113
+ print("Warning: scaled dot product attention not available, using standard attention in LM.")
114
+
115
+ def forward(self, x, cos, sin, attention_mask=None):
116
+ B, T, C = x.size()
117
+
118
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
119
+ k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)
120
+ v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T, head_dim)
121
+
122
+ # Use precomputed positional embeddings
123
+ q, k = apply_rotary_pos_embd(q, k, cos, sin)
124
+
125
+ k = k.repeat_interleave(self.n_kv_groups, dim=1)
126
+ v = v.repeat_interleave(self.n_kv_groups, dim=1)
127
+
128
+ # Process attention mask if provided
129
+ if attention_mask is not None:
130
+ # Create a 4D attention mask [batch_size, 1, 1, seq_length], In this format, 1 = attend, 0 = mask
131
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # [B, 1, 1, T]
132
+ padding_mask = (attention_mask == 0).transpose(-1, -2) # Use this for the manual path
133
+ # Convert to attention mask where 0 keeps values and -inf masks
134
+ attention_mask = (1.0 - attention_mask) * torch.finfo(q.dtype).min
135
+
136
+ if self.sdpa:
137
+ y = torch.nn.functional.scaled_dot_product_attention(
138
+ q, k, v,
139
+ attn_mask=attention_mask,
140
+ dropout_p=self.dropout if self.training else 0.0,
141
+ is_causal=True # LM attention is causal (masked)
142
+ )
143
+ else:
144
+ attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
145
+ causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
146
+ attn = attn.masked_fill(causal_mask == 0, float('-inf'))
147
+ if attention_mask is not None:
148
+ attn = attn + attention_mask
149
+
150
+ attn = F.softmax(attn, dim=-1)
151
+ attn = self.attn_dropout(attn)
152
+ y = attn @ v
153
+
154
+ if attention_mask is not None:
155
+ y = y.masked_fill(padding_mask, 0.0) # Zero out the padded positions in the output
156
+
157
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
158
+ y = self.out_proj(y)
159
+ y = self.resid_dropout(y)
160
+
161
+ return y
162
+
163
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L160
164
+ class LanguageModelMLP(nn.Module):
165
+ def __init__(self, cfg):
166
+ super().__init__()
167
+ self.embd_dim = cfg.lm_hidden_dim
168
+ self.inter_dim = cfg.lm_inter_dim
169
+
170
+ self.activation_fn = F.silu
171
+ self.gate_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
172
+ self.up_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
173
+ self.down_proj = nn.Linear(self.inter_dim, self.embd_dim, bias=False)
174
+
175
+ def forward(self, x):
176
+ gate = self.activation_fn(self.gate_proj(x))
177
+ x = self.up_proj(x)
178
+ x = self.down_proj(gate * x)
179
+
180
+ return x
181
+
182
+ # https://github.com/meta-llama/llama3/blob/main/llama/model.py#L222
183
+ class LanguageModelBlock(nn.Module):
184
+ def __init__(self, cfg):
185
+ super().__init__()
186
+ self.mlp = LanguageModelMLP(cfg)
187
+ self.attn = LanguageModelGroupedQueryAttention(cfg)
188
+ self.norm1 = RMSNorm(cfg) # Input Norm
189
+ self.norm2 = RMSNorm(cfg) # Post Attention Norm
190
+
191
+ def forward(self, x, cos, sin, attention_mask=None):
192
+ res = x
193
+ x = self.norm1(x)
194
+ x = self.attn(x, cos, sin, attention_mask)
195
+ x = res + x
196
+
197
+ res = x
198
+ x = self.norm2(x)
199
+ x = self.mlp(x)
200
+ x = res + x
201
+
202
+ return x
203
+
204
+ # https://github.com/meta-llama/llama3/blob/main/llama/model.py#L251
205
+ class LanguageModel(nn.Module):
206
+ def __init__(self, cfg):
207
+ super().__init__()
208
+ self.cfg = cfg
209
+ self.lm_use_tokens = cfg.lm_use_tokens
210
+ self.lm_tie_weights = cfg.lm_tie_weights
211
+
212
+ self.token_embedding = nn.Embedding(cfg.lm_vocab_size, cfg.lm_hidden_dim)
213
+ self.rotary_embd = RotaryEmbedding(cfg)
214
+ self.blocks = nn.ModuleList([
215
+ LanguageModelBlock(cfg) for _ in range(cfg.lm_n_blocks)
216
+ ])
217
+ self.norm = RMSNorm(cfg) # Final Norm
218
+ self.head = nn.Linear(cfg.lm_hidden_dim, cfg.lm_vocab_size, bias=False)
219
+ if self.lm_tie_weights:
220
+ self.head.weight = self.token_embedding.weight
221
+
222
+ self.apply(self._init_weights)
223
+
224
+ def _init_weights(self, module):
225
+ if isinstance(module, nn.Linear):
226
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
227
+ if module.bias is not None:
228
+ torch.nn.init.zeros_(module.bias)
229
+ elif isinstance(module, nn.Embedding):
230
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
231
+ elif isinstance(module, RMSNorm):
232
+ module.weight.data.fill_(1.0)
233
+
234
+ def forward(self, x, attention_mask=None):
235
+ if self.lm_use_tokens:
236
+ x = self.token_embedding(x) # Only embed the inputs when using tokens
237
+
238
+ B , T, _ = x.size()
239
+
240
+ # Note: You could also cache these input embeddings if you want to avoid recomputing them
241
+ position_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1) # Create position ids [0, 1, 2, ..., seq_len-1]
242
+ cos, sin = self.rotary_embd(position_ids) # Get rotary position embeddings
243
+
244
+ for block in self.blocks:
245
+ x = block(x, cos, sin, attention_mask)
246
+ x = self.norm(x)
247
+
248
+ if self.lm_use_tokens:
249
+ x = self.head(x) # Compute logits if we are using tokens, otherwise stay in the embedding space
250
+
251
+ return x
252
+
253
+ @torch.no_grad()
254
+ def generate(self, inputs, max_new_tokens=20):
255
+ # Add batch dimension if needed
256
+ if inputs.dim() == 1:
257
+ inputs = inputs.unsqueeze(0)
258
+
259
+ generated = inputs.clone()
260
+
261
+ for _ in range(max_new_tokens):
262
+ # Forward pass through the model
263
+ outputs = self.forward(generated)
264
+ last_output = outputs[:, -1, :]
265
+
266
+ if self.lm_use_tokens:
267
+ # Now the model outputs logits
268
+ next_token = torch.argmax(last_output, dim=-1, keepdim=True)
269
+ generated = torch.cat((generated, next_token), dim=-1)
270
+ else:
271
+ # Now the model outputs embeddings
272
+ next_token_embedding = last_output.unsqueeze(1) # Shape: [batch_size, 1, hidden_dim]
273
+ generated = torch.cat((generated, next_token_embedding), dim=1)
274
+
275
+ # Note: You could let the generation stop earlier than max_new_tokens when it detects an EOS token, but this does not work in batched generation (output tensors need to have the same size)
276
+
277
+ return generated
278
+
279
+ # Load the model from a pretrained HuggingFace model (we don't want to have to train the Language Backbone from scratch)
280
+ @classmethod
281
+ def from_pretrained(cls, cfg):
282
+ from transformers import AutoConfig
283
+ from huggingface_hub import hf_hub_download
284
+ import safetensors
285
+ import torch.nn.init as init
286
+
287
+ # Load the HuggingFace config
288
+ hf_config = AutoConfig.from_pretrained(cfg.lm_model_type)
289
+
290
+ # Store original HF vocab size before we modify it
291
+ original_vocab_size = hf_config.vocab_size
292
+ # print(f"Original vocabulary size from pretrained model: {original_vocab_size}")
293
+
294
+ # Configure model parameters from HF config
295
+ cfg.lm_hidden_dim = hf_config.hidden_size
296
+ cfg.lm_inter_dim = hf_config.intermediate_size
297
+ cfg.lm_rms_eps = hf_config.rms_norm_eps
298
+ cfg.lm_re_base = hf_config.rope_theta
299
+ cfg.lm_max_position_embeddings = hf_config.max_position_embeddings
300
+ # We're keeping our own vocab size in cfg, but checking it's larger than original
301
+ if hasattr(cfg, 'lm_vocab_size'):
302
+ if cfg.lm_vocab_size < original_vocab_size:
303
+ raise ValueError(f"Config vocab size ({cfg.lm_vocab_size}) is smaller than pretrained model vocab size ({original_vocab_size})")
304
+ # print(f"Using vocabulary size: {cfg.lm_vocab_size}")
305
+ else:
306
+ # If not specified, use the original
307
+ cfg.lm_vocab_size = original_vocab_size
308
+ # print(f"Using original vocabulary size: {cfg.lm_vocab_size}")
309
+
310
+ cfg.lm_n_heads = hf_config.num_attention_heads
311
+ cfg.lm_n_kv_heads = hf_config.num_key_value_heads
312
+ cfg.lm_dropout = hf_config.attention_dropout
313
+ cfg.lm_n_blocks = hf_config.num_hidden_layers
314
+
315
+ # Create our model with potentially larger vocabulary
316
+ model = cls(cfg)
317
+ safetensors_file = hf_hub_download(repo_id=cfg.lm_model_type, filename="model.safetensors")
318
+
319
+ sd = model.state_dict()
320
+
321
+ mapping = {
322
+ 'model.embed_tokens.weight': 'token_embedding.weight',
323
+ 'model.norm.weight': 'norm.weight'
324
+ }
325
+
326
+ for i in range(cfg.lm_n_blocks):
327
+ layer_prefix = f'model.layers.{i}.'
328
+ block_prefix = f'blocks.{i}.'
329
+
330
+ mapping.update({
331
+ f"{layer_prefix}self_attn.q_proj.weight": f"{block_prefix}attn.q_proj.weight",
332
+ f"{layer_prefix}self_attn.k_proj.weight": f"{block_prefix}attn.k_proj.weight",
333
+ f"{layer_prefix}self_attn.v_proj.weight": f"{block_prefix}attn.v_proj.weight",
334
+ f"{layer_prefix}self_attn.o_proj.weight": f"{block_prefix}attn.out_proj.weight",
335
+ f"{layer_prefix}mlp.gate_proj.weight": f"{block_prefix}mlp.gate_proj.weight",
336
+ f"{layer_prefix}mlp.up_proj.weight": f"{block_prefix}mlp.up_proj.weight",
337
+ f"{layer_prefix}mlp.down_proj.weight": f"{block_prefix}mlp.down_proj.weight",
338
+ f"{layer_prefix}input_layernorm.weight": f"{block_prefix}norm1.weight",
339
+ f"{layer_prefix}post_attention_layernorm.weight": f"{block_prefix}norm2.weight"
340
+ })
341
+
342
+ # Special handling for token embeddings with extended vocabulary
343
+ has_extended_embeddings = False
344
+ with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
345
+ for hf_key, our_key in mapping.items():
346
+ if hf_key in f.keys() and our_key in sd:
347
+ tensor = f.get_tensor(hf_key)
348
+
349
+ # Special handling for token embeddings if vocab sizes differ
350
+ if hf_key == 'model.embed_tokens.weight' and tensor.shape[0] != sd[our_key].shape[0]:
351
+ has_extended_embeddings = True
352
+ print(f"Extending token embeddings from {tensor.shape} to {sd[our_key].shape}")
353
+
354
+ # Copy existing embeddings to the beginning of our larger embedding matrix
355
+ sd[our_key][:tensor.shape[0]].copy_(tensor)
356
+
357
+ # Initialize the new embeddings using the same approach as the original model
358
+ std = 0.02 # Common value, but you might want to adjust based on model
359
+ init.normal_(sd[our_key][tensor.shape[0]:], mean=0.0, std=std)
360
+
361
+ print(f"Initialized {sd[our_key].shape[0] - tensor.shape[0]} new token embeddings")
362
+ sd['head.weight'].copy_(sd[our_key]) # Update the head weights as well
363
+ elif tensor.shape == sd[our_key].shape:
364
+ sd[our_key].copy_(tensor)
365
+ else:
366
+ print(f"Shape mismatch for {hf_key} -> {our_key}: {tensor.shape} vs {sd[our_key].shape}")
367
+ else:
368
+ if hf_key not in f.keys():
369
+ print(f"Warning: Key {hf_key} not found in safetensors file")
370
+ if our_key not in sd:
371
+ print(f"Warning: Key {our_key} not found in model state dict")
372
+
373
+ # Load the state dict
374
+ model.load_state_dict(sd)
375
+
376
+ # Handle output projection / language modeling head
377
+ if has_extended_embeddings and hasattr(model, 'head') and 'head.weight' in sd:
378
+ # If we have a separate output projection layer and extended the vocab
379
+ # we should handle it similarly to the input embeddings
380
+ with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
381
+ if 'lm_head.weight' in f.keys():
382
+ lm_head = f.get_tensor('lm_head.weight')
383
+ if lm_head.shape[0] != sd['head.weight'].shape[0]:
384
+ print(f"Extending LM head from {lm_head.shape} to {sd['head.weight'].shape}")
385
+ # Copy existing weights
386
+ sd['head.weight'][:lm_head.shape[0]].copy_(lm_head)
387
+ # Initialize new weights
388
+ std = 0.02
389
+ init.normal_(sd['head.weight'][lm_head.shape[0]:], mean=0.0, std=std)
390
+ # Load updated weights
391
+ model.load_state_dict(sd)
392
+
393
+ # Handle weight tying (if needed)
394
+ if cfg.lm_tie_weights and hasattr(model, 'head') and hasattr(model, 'token_embedding'):
395
+ model.head.weight = model.token_embedding.weight
396
+ # print("Tied token embedding and LM head weights")
397
+
398
+ print(f"Successfully loaded {cfg.lm_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
399
+ return model
models/modality_projector.py ADDED
@@ -0,0 +1,46 @@
1
+ # Modality Projection from Vision to Language
2
+ import torch.nn as nn
3
+
4
+ class ModalityProjector(nn.Module):
5
+ def __init__(self, cfg):
6
+ super().__init__()
7
+ self.cfg = cfg
8
+ self.input_dim = cfg.vit_hidden_dim * (cfg.mp_pixel_shuffle_factor**2)
9
+ self.output_dim = cfg.lm_hidden_dim
10
+ self.scale_factor = cfg.mp_pixel_shuffle_factor
11
+
12
+ self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)
13
+
14
+ self.apply(self._init_weights)
15
+
16
+ def _init_weights(self, module):
17
+ if isinstance(module, nn.Linear):
18
+ nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
19
+ if module.bias is not None:
20
+ nn.init.zeros_(module.bias)
21
+
22
+ # https://github.com/huggingface/smollm/blob/main/vision/m4/models/vllama3/modeling_vllama3.py#L1281
23
+ def pixel_shuffle(self, x):
24
+ bsz, seq, embed_dim = x.size()
25
+ seq_root = int(seq**0.5)
26
+ assert seq_root**2 == seq # Sequence length must be a perfect square for pixel shuffle
27
+ assert seq_root % self.scale_factor == 0 # Sequence root must be divisible by scale factor
28
+
29
+ height = width = seq_root
30
+ x = x.view(bsz, height, width, embed_dim)
31
+ h_out = height // self.scale_factor
32
+ w_out = width // self.scale_factor
33
+
34
+ x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
35
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
36
+ x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor**2)
37
+
38
+ return x
39
+
40
+ def forward(self, x):
41
+ x = self.pixel_shuffle(x)
42
+ x = self.proj(x)
43
+
44
+ return x
45
+
46
+
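A minimal forward-pass sketch for the projector above with the default config: 196 SigLIP patch embeddings go in, 49 language-space tokens come out.

```python
import torch
from models.config import VLMConfig
from models.modality_projector import ModalityProjector

mp = ModalityProjector(VLMConfig())
patches = torch.randn(1, 196, 768)   # (batch, num_patches, vit_hidden_dim)
print(mp(patches).shape)             # torch.Size([1, 49, 576])
```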
models/utils.py ADDED
@@ -0,0 +1,22 @@
1
+ import re
2
+
3
+ # Used to check our model's performance on multiple-choice tasks. This could also be done in a more involved way, e.g. with LLM-as-a-judge
4
+ def check_multiple_choice_with_regex(model_outputs, correct_answers):
5
+ results = []
6
+ for model_output, correct_answer in zip(model_outputs, correct_answers):
7
+ correct_answer = correct_answer.upper()
8
+
9
+ # Look for the answer letter at the beginning of a line or as the last word
10
+ patterns = [
11
+ rf"\b{correct_answer}\b", # Word boundary around the answer letter
12
+ rf"\b{correct_answer}[.,)]", # Answer followed by punctuation
13
+ rf"\(.*{correct_answer}.*\)", # Answer within parentheses
14
+ ]
15
+
16
+ match_found = False
17
+ for pattern in patterns:
18
+ if re.search(pattern, model_output):
19
+ match_found = True
20
+ break # Exit inner loop once a match is found
21
+ results.append(match_found)
22
+ return results
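For reference, a tiny example of how the checker above behaves (assuming the repository root is on the Python path):

```python
from models.utils import check_multiple_choice_with_regex

outputs = ["The correct answer is B.", "I think it is (C)", "Definitely A", "blah"]
answers = ["b", "c", "d", "a"]

print(check_multiple_choice_with_regex(outputs, answers))  # [True, True, False, False]
```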
models/vision_language_model.py ADDED
@@ -0,0 +1,237 @@
1
+ import json
2
+ import os
3
+ import tempfile
4
+ from dataclasses import asdict
5
+ from typing import Optional
6
+
7
+
8
+ from models.vision_transformer import ViT
9
+ from models.language_model import LanguageModel
10
+ from models.modality_projector import ModalityProjector
11
+ from models.config import VLMConfig
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from safetensors.torch import load_model, save_model
17
+
18
+ class VisionLanguageModel(nn.Module):
19
+ def __init__(self, cfg: VLMConfig, load_backbone=True):
20
+ super().__init__()
21
+ self.cfg = cfg
22
+ if load_backbone:
23
+ print("Loading from backbone weights")
24
+ self.vision_encoder = ViT.from_pretrained(cfg)
25
+ self.decoder = LanguageModel.from_pretrained(cfg)
26
+ else:
27
+ self.vision_encoder = ViT(cfg)
28
+ self.decoder = LanguageModel(cfg)
29
+ self.MP = ModalityProjector(cfg)
30
+ self.load_backbone = load_backbone
31
+
32
+ def forward(self, input_ids, image, attention_mask=None, targets=None):
33
+ image_embd = self.vision_encoder(image)
34
+ image_embd = self.MP(image_embd)
35
+
36
+ token_embd = self.decoder.token_embedding(input_ids)
37
+
38
+ combined_embd = torch.cat((image_embd, token_embd), dim=1) # Concatenate image embeddings to token embeddings
39
+
40
+ # Adjust attention mask to account for image tokens
41
+ if attention_mask is not None:
42
+ # Create mask of 1s for image tokens (all image tokens should be attended to)
43
+ batch_size = image_embd.size(0)
44
+ img_seq_len = image_embd.size(1)
45
+ image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)
46
+
47
+ # Combine image and token attention masks
48
+ attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)
49
+
50
+ logits = self.decoder(combined_embd, attention_mask) # Not logits yet, but easier to return like this
51
+
52
+ loss = None
53
+ if targets is not None:
54
+ # Only use the token part of the logits for loss computation
55
+ logits = self.decoder.head(logits)
56
+ logits = logits[:, image_embd.size(1):, :]
57
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)
58
+
59
+ return logits, loss
60
+
61
+ @torch.no_grad()
62
+ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):
63
+ # Process image through vision encoder and projection
64
+ image_embd = self.vision_encoder(image)
65
+ image_embd = self.MP(image_embd)
66
+
67
+ # Embed initial tokens
68
+ token_embd = self.decoder.token_embedding(input_ids)
69
+
70
+ # Concatenate image embeddings with token embeddings
71
+ combined_embd = torch.cat((image_embd, token_embd), dim=1)
72
+
73
+ batch_size = image_embd.size(0)
74
+ img_seq_len = image_embd.size(1)
75
+ # Adjust attention mask to account for image tokens
76
+ if attention_mask is not None:
77
+ # Create mask of 1s for image tokens (all image tokens should be attended to)
78
+ image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)
79
+ attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)
80
+
81
+ # Generate from combined embeddings using the decoder
82
+ # We need to use the decoder's forward function and not its generate method
83
+ # because we want to keep track of the image prefix
84
+ outputs = combined_embd
85
+ generated_tokens = torch.zeros((batch_size, max_new_tokens), device=input_ids.device, dtype=input_ids.dtype)
86
+
87
+ # Note: Here you could implement improvements such as KV caching
88
+ for i in range(max_new_tokens):
89
+ model_out = self.decoder(outputs, attention_mask)
90
+
91
+ # Get predictions for the last token only (normally this is the embedding, not the logits)
92
+ last_token_logits = model_out[:, -1, :]
93
+
94
+ # Apply head to get logits (if model is in embedding mode)
95
+ if not self.decoder.lm_use_tokens:
96
+ last_token_logits = self.decoder.head(last_token_logits)
97
+
98
+ probs = torch.softmax(last_token_logits, dim=-1)
99
+ next_token = torch.multinomial(probs, num_samples=1)
100
+
101
+ generated_tokens[:, i] = next_token.squeeze(-1)
102
+
103
+ # Convert to embedding and append
104
+ next_embd = self.decoder.token_embedding(next_token)
105
+ outputs = torch.cat((outputs, next_embd), dim=1)
106
+
107
+ if attention_mask is not None:
108
+ attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)), dim=1)
109
+
110
+ return generated_tokens
111
+
112
+ @classmethod
113
+ def from_pretrained(
114
+ cls, repo_id_or_path: str, *, revision: Optional[str] = None
115
+ ) -> "VisionLanguageModel":
116
+ """
117
+ Load a VisionLanguageModel from a local directory or a repo on the Hugging Face Hub.
118
+
119
+ Args:
120
+ repo_id_or_path (str): The path to the local directory or the Hugging Face Hub repo ID.
121
+
122
+ Returns:
123
+ VisionLanguageModel: The loaded model.
124
+ """
125
+ # If local folder exists => load from there
126
+ if os.path.exists(repo_id_or_path):
127
+ config_path = os.path.join(repo_id_or_path, "config.json")
128
+ weights_path = os.path.join(repo_id_or_path, "model.safetensors")
129
+
130
+ if not os.path.exists(config_path):
131
+ raise ValueError(
132
+ f"Config file not found at {config_path}. Please provide a valid path."
133
+ )
134
+ if not os.path.exists(weights_path):
135
+ raise ValueError(
136
+ f"Weights file not found at {weights_path}. Please provide a valid path."
137
+ )
138
+ # Otherwise, assume it's a Hugging Face Hub repo
139
+ else:
140
+ from huggingface_hub import hf_hub_download
141
+
142
+ config_path = hf_hub_download(
143
+ repo_id=repo_id_or_path, filename="config.json", revision=revision
144
+ )
145
+ weights_path = hf_hub_download(
146
+ repo_id=repo_id_or_path, filename="model.safetensors", revision=revision
147
+ )
148
+
149
+ # Load config
150
+ with open(config_path, "r") as f:
151
+ cfg = VLMConfig(**json.load(f))
152
+
153
+ # Initialize model without loading the backbone
154
+ model = cls(cfg, load_backbone=False)
155
+
156
+ # Load safetensors weights
157
+ load_model(model, weights_path)
158
+
159
+ # Done!
160
+ return model
161
+
162
+ def save_pretrained(self, save_directory: str) -> None:
163
+ """
164
+ Save the model and configuration to a directory.
165
+
166
+ Args:
167
+ save_directory (str): The directory to save the model and config.
168
+ """
169
+ # Create directory if it doesn't exist
170
+ os.makedirs(save_directory, exist_ok=True)
171
+
172
+ # Save config
173
+ with open(os.path.join(save_directory, "config.json"), "w") as f:
174
+ f.write(json.dumps(asdict(self.cfg), indent=4))
175
+
176
+ # Save weights as safetensors
177
+ save_model(self, os.path.join(save_directory, "model.safetensors"))
178
+
179
+ def push_to_hub(self, repo_id: str, private: bool = False) -> None:
180
+ """
181
+ Push the model and configuration to the Hugging Face Hub.
182
+
183
+ Args:
184
+ repo_id (str): The repo ID on the Hugging Face Hub.
185
+ """
186
+ from huggingface_hub import create_repo, upload_folder
187
+
188
+ # Create repo
189
+ repo_url = create_repo(repo_id=repo_id, private=private, exist_ok=True)
190
+ repo_id = repo_url.repo_id
191
+ print("Created repo: ", repo_url)
192
+
193
+ with tempfile.TemporaryDirectory() as save_path:
194
+ # Save to tmp directory
195
+ self.save_pretrained(save_path)
196
+
197
+ # Save model card
198
+ with open(os.path.join(save_path, "README.md"), "w") as f:
199
+ f.write(MODEL_CARD_TEMPLATE.format(repo_id=repo_id))
200
+
201
+ # Upload
202
+ return upload_folder(
203
+ repo_id=repo_id,
204
+ repo_type="model",
205
+ folder_path=save_path,
206
+ commit_message="Upload nanoVLM using push_to_hub",
207
+ )
208
+
209
+
210
+ MODEL_CARD_TEMPLATE = """
211
+ ---
212
+ # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
213
+ # Doc / guide: https://huggingface.co/docs/hub/model-cards
214
+ library_name: nanovlm
215
+ license: mit
216
+ pipeline_tag: image-text-to-text
217
+ tags:
218
+ - vision-language
219
+ - multimodal
220
+ - research
221
+ ---
222
+
223
+ **nanoVLM** is a minimal and lightweight Vision-Language Model (VLM) designed for efficient training and experimentation. Built using pure PyTorch, the entire model architecture and training logic fits within ~750 lines of code. It combines a ViT-based image encoder (SigLIP-B/16-224-85M) with a lightweight causal language model (SmolLM2-135M), resulting in a compact 222M parameter model.
224
+
225
+ For more information, check out the base model on https://huggingface.co/lusxvr/nanoVLM-222M.
226
+
227
+ **Usage:**
228
+
229
+ Clone the nanoVLM repository: https://github.com/huggingface/nanoVLM.
230
+ Follow the install instructions and run the following code:
231
+
232
+ ```python
233
+ from models.vision_language_model import VisionLanguageModel
234
+
235
+ model = VisionLanguageModel.from_pretrained("{repo_id}")
236
+ ```
237
+ """
models/vision_transformer.py ADDED
@@ -0,0 +1,251 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/modeling_siglip.py#L245
7
+ class ViTPatchEmbeddings(nn.Module):
8
+ def __init__(self, cfg):
9
+ super().__init__()
10
+
11
+ self.img_size = cfg.vit_img_size
12
+ self.patch_size = cfg.vit_patch_size
13
+ self.num_patches = (self.img_size // self.patch_size) ** 2
14
+ self.cls_flag = cfg.vit_cls_flag
15
+ self.embd_dim = cfg.vit_hidden_dim
16
+
17
+ # Conv layer to extract the patches
18
+ self.conv = nn.Conv2d(
19
+ in_channels=3,
20
+ out_channels=self.embd_dim,
21
+ kernel_size=self.patch_size,
22
+ stride=self.patch_size,
23
+ padding="valid",
24
+ )
25
+
26
+ if self.cls_flag:
27
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embd_dim))
28
+ self.position_embedding = nn.Parameter(torch.rand(1, self.num_patches + 1, self.embd_dim))
29
+ else:
30
+ self.position_embedding = nn.Parameter(torch.rand(1, self.num_patches, self.embd_dim))
31
+
32
+
33
+ def forward(self, x):
34
+ x = self.conv(x) # extract patches
35
+ x = x.flatten(2) # flatten the patches into a single dimension
36
+ x = x.transpose(1, 2) # transpose to (batch_size, num_patches, hidden_dim)
37
+
38
+ # Add CLS token (according to original ViT Paper) and position embeddings
39
+ if self.cls_flag:
40
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
41
+ x = torch.cat((cls_token, x), dim=1)
42
+ x = x + self.position_embedding
43
+ return x
44
+
45
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/modeling_siglip.py#L381
46
+ # https://github.com/karpathy/nanoGPT/blob/master/model.py#L29
47
+ class ViTMultiHeadAttention(nn.Module):
48
+ def __init__(self, cfg):
49
+ super().__init__()
50
+
51
+ self.n_heads = cfg.vit_n_heads
52
+ self.embd_dim = cfg.vit_hidden_dim
53
+ assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by num_heads"
54
+ self.head_dim = self.embd_dim // self.n_heads
55
+ self.dropout = cfg.vit_dropout
56
+
57
+ # Combined projections for all heads
58
+ self.qkv_proj = nn.Linear(self.embd_dim, 3 * self.embd_dim, bias=True)
59
+ self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=True)
60
+
61
+ # Dropout layers
62
+ self.attn_dropout = nn.Dropout(self.dropout)
63
+ self.resid_dropout = nn.Dropout(self.dropout)
64
+
65
+ # Use scaled dot product attention if available
66
+ self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
67
+ if not self.sdpa:
68
+ print("Warning: scaled dot product attention not available. Using standard attention in ViT.")
69
+
70
+ def forward(self, x):
71
+ B, T, C = x.size()
72
+
73
+ qkv = self.qkv_proj(x)
74
+ q, k, v = qkv.split(C, dim=2)
75
+ # Reshape [B, T, C] -> [B, T, n_heads, head_dim] and transpose -> [B, n_heads, T, head_dim]
76
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
77
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
78
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B, n_heads, T, head_dim)
79
+
80
+ if self.sdpa:
81
+ y = torch.nn.functional.scaled_dot_product_attention(
82
+ q, k, v,
83
+ attn_mask=None,
84
+ dropout_p=self.dropout if self.training else 0.0,
85
+ is_causal=False # ViT attention is bidirectional
86
+ )
87
+ else:
88
+ attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
89
+ attn = F.softmax(attn, dim=-1)
90
+ attn = self.attn_dropout(attn)
91
+ y = attn @ v # (B, n_heads, T, T) x (B, n_heads, T, head_dim) -> (B, n_heads, T, head_dim)
92
+
93
+ # Transpose back from [B, n_heads, T, head_dim] to [B, T, n_heads * head_dim] and combine all heads to [B, T, C]
94
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
95
+ y = self.out_proj(y)
96
+ y = self.resid_dropout(y)
97
+
98
+ return y
99
+
100
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/modeling_siglip.py#L453
101
+ class ViTMLP(nn.Module):
102
+ def __init__(self, cfg):
103
+ super().__init__()
104
+ self.activation_fn = nn.GELU(approximate='tanh')
105
+ self.fc1 = nn.Linear(cfg.vit_hidden_dim, cfg.vit_inter_dim)
106
+ self.fc2 = nn.Linear(cfg.vit_inter_dim, cfg.vit_hidden_dim)
107
+ self.dropout = nn.Dropout(cfg.vit_dropout)
108
+
109
+ def forward(self, x):
110
+ x = self.fc1(x)
111
+ x = self.activation_fn(x)
112
+ x = self.fc2(x)
113
+ x = self.dropout(x)
114
+ return x
115
+
116
+ # https://github.com/karpathy/nanoGPT/blob/master/model.py#L94
117
+ class ViTBlock(nn.Module):
118
+ def __init__(self, cfg):
119
+ super().__init__()
120
+ self.ln1 = nn.LayerNorm(cfg.vit_hidden_dim, eps=cfg.vit_ln_eps)
121
+ self.attn = ViTMultiHeadAttention(cfg)
122
+ self.ln2 = nn.LayerNorm(cfg.vit_hidden_dim, eps=cfg.vit_ln_eps)
123
+ self.mlp = ViTMLP(cfg)
124
+
125
+ def forward(self, x):
126
+ x = x + self.attn(self.ln1(x))
127
+ x = x + self.mlp(self.ln2(x))
128
+ return x
129
+
130
+
131
+ class ViT(nn.Module):
132
+ def __init__(self, cfg):
133
+ super().__init__()
134
+ self.cfg = cfg
135
+ self.patch_embedding = ViTPatchEmbeddings(cfg)
136
+ self.cls_flag = cfg.vit_cls_flag
137
+ self.dropout = nn.Dropout(cfg.vit_dropout)
138
+ self.blocks = nn.ModuleList([ViTBlock(cfg) for _ in range(cfg.vit_n_blocks)])
139
+ self.layer_norm = nn.LayerNorm(cfg.vit_hidden_dim, eps=cfg.vit_ln_eps)
140
+
141
+ self.apply(self._init_weights)
142
+
143
+ def _init_weights(self, module):
144
+ if isinstance(module, nn.Linear):
145
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
146
+ if module.bias is not None:
147
+ torch.nn.init.zeros_(module.bias)
148
+ elif isinstance(module, nn.LayerNorm):
149
+ module.bias.data.zero_()
150
+ module.weight.data.fill_(1.0)
151
+ elif isinstance(module, nn.Conv2d):
152
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
153
+ if module.bias is not None:
154
+ torch.nn.init.zeros_(module.bias)
155
+
156
+ def forward(self, x):
157
+ x = self.patch_embedding(x)
158
+ x = self.dropout(x)
159
+ for block in self.blocks:
160
+ x = block(x)
161
+
162
+ if self.cls_flag:
163
+ x = self.layer_norm(x[:, 0])
164
+ else:
165
+ x = self.layer_norm(x)
166
+ #x = x.mean(dim=1)
167
+
168
+ return x
169
+
170
+ # Load the model from a pretrained HuggingFace model (we don't want to have to train the Vision Backbone from scratch)
171
+ @classmethod
172
+ def from_pretrained(cls, cfg):
173
+ from transformers import SiglipVisionConfig
174
+ from huggingface_hub import hf_hub_download
175
+ import safetensors
176
+
177
+ hf_config = SiglipVisionConfig.from_pretrained(cfg.vit_model_type)
178
+ cfg.vit_dropout=hf_config.attention_dropout
179
+ cfg.vit_hidden_dim=hf_config.hidden_size
180
+ cfg.vit_img_size=hf_config.image_size
181
+ cfg.vit_inter_dim=hf_config.intermediate_size
182
+ cfg.vit_ln_eps=hf_config.layer_norm_eps
183
+ cfg.vit_n_heads=hf_config.num_attention_heads
184
+ cfg.vit_n_blocks=hf_config.num_hidden_layers
185
+ cfg.vit_patch_size=hf_config.patch_size
186
+ model = cls(cfg)
187
+ safetensors_file = hf_hub_download(repo_id=cfg.vit_model_type, filename="model.safetensors")
188
+
189
+ sd = model.state_dict()
190
+
191
+ mapping = {
192
+ 'vision_model.embeddings.patch_embedding.weight': 'patch_embedding.conv.weight',
193
+ 'vision_model.embeddings.patch_embedding.bias': 'patch_embedding.conv.bias',
194
+ 'vision_model.embeddings.position_embedding.weight': 'patch_embedding.position_embedding',
195
+ 'vision_model.post_layernorm.weight': 'layer_norm.weight',
196
+ 'vision_model.post_layernorm.bias': 'layer_norm.bias',
197
+ }
198
+
199
+ for i in range(cfg.vit_n_blocks):
200
+ # Layer norms
201
+ mapping[f'vision_model.encoder.layers.{i}.layer_norm1.weight'] = f'blocks.{i}.ln1.weight'
202
+ mapping[f'vision_model.encoder.layers.{i}.layer_norm1.bias'] = f'blocks.{i}.ln1.bias'
203
+ mapping[f'vision_model.encoder.layers.{i}.layer_norm2.weight'] = f'blocks.{i}.ln2.weight'
204
+ mapping[f'vision_model.encoder.layers.{i}.layer_norm2.bias'] = f'blocks.{i}.ln2.bias'
205
+
206
+ # MLP
207
+ mapping[f'vision_model.encoder.layers.{i}.mlp.fc1.weight'] = f'blocks.{i}.mlp.fc1.weight'
208
+ mapping[f'vision_model.encoder.layers.{i}.mlp.fc1.bias'] = f'blocks.{i}.mlp.fc1.bias'
209
+ mapping[f'vision_model.encoder.layers.{i}.mlp.fc2.weight'] = f'blocks.{i}.mlp.fc2.weight'
210
+ mapping[f'vision_model.encoder.layers.{i}.mlp.fc2.bias'] = f'blocks.{i}.mlp.fc2.bias'
211
+
212
+ # Output projection
213
+ mapping[f'vision_model.encoder.layers.{i}.self_attn.out_proj.weight'] = f'blocks.{i}.attn.out_proj.weight'
214
+ mapping[f'vision_model.encoder.layers.{i}.self_attn.out_proj.bias'] = f'blocks.{i}.attn.out_proj.bias'
215
+
216
+ with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
217
+ for hf_key, our_key in mapping.items():
218
+ if hf_key in f.keys() and our_key in sd:
219
+ tensor = f.get_tensor(hf_key)
220
+ if tensor.shape == sd[our_key].shape:
221
+ sd[our_key].copy_(tensor)
222
+ else:
223
+ if 'position_embedding' in hf_key:
224
+ sd[our_key].copy_(tensor.unsqueeze(0))
225
+ else:
226
+ print(f"Shape mismatch for {hf_key} -> {our_key}: {tensor.shape} vs {sd[our_key].shape}")
227
+ else:
228
+ if hf_key not in f.keys():
229
+ print(f"Warning: Key {hf_key} not found in safetensors file")
230
+ if our_key not in sd:
231
+ print(f"Warning: Key {our_key} not found in model state dict")
232
+
233
+ # Manually handle QKV concatenation since our implementation combines Q, K, V into one
234
+ for i in range(model.cfg.vit_n_blocks):
235
+ q_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.q_proj.weight')
236
+ k_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.k_proj.weight')
237
+ v_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.v_proj.weight')
238
+
239
+ qkv_weight = torch.cat((q_weight, k_weight, v_weight), dim=0)
240
+ sd[f'blocks.{i}.attn.qkv_proj.weight'].copy_(qkv_weight)
241
+
242
+ q_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.q_proj.bias')
243
+ k_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.k_proj.bias')
244
+ v_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.v_proj.bias')
245
+
246
+ qkv_bias = torch.cat((q_bias, k_bias, v_bias), dim=0)
247
+ sd[f'blocks.{i}.attn.qkv_proj.bias'].copy_(qkv_bias)
248
+
249
+ model.load_state_dict(sd)
250
+ print(f"Successfully loaded {cfg.vit_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
251
+ return model
requirements.txt ADDED
@@ -0,0 +1,80 @@
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ certifi==2025.4.26
5
+ charset-normalizer==3.4.2
6
+ click==8.1.8
7
+ fastapi==0.115.12
8
+ ffmpy==0.5.0
9
+ filelock==3.18.0
10
+ fsspec==2025.3.2
11
+ gradio==5.29.1
12
+ gradio-client==1.10.1
13
+ groovy==0.1.2
14
+ h11==0.16.0
15
+ httpcore==1.0.9
16
+ httpx==0.28.1
17
+ huggingface-hub==0.31.2
18
+ idna==3.10
19
+ jinja2==3.1.6
20
+ markdown-it-py==3.0.0
21
+ markupsafe==3.0.2
22
+ mdurl==0.1.2
23
+ mpmath==1.3.0
24
+ networkx==3.4.2
25
+ numpy==2.2.5
26
+ nvidia-cublas-cu12==12.6.4.1
27
+ nvidia-cuda-cupti-cu12==12.6.80
28
+ nvidia-cuda-nvrtc-cu12==12.6.77
29
+ nvidia-cuda-runtime-cu12==12.6.77
30
+ nvidia-cudnn-cu12==9.5.1.17
31
+ nvidia-cufft-cu12==11.3.0.4
32
+ nvidia-cufile-cu12==1.11.1.6
33
+ nvidia-curand-cu12==10.3.7.77
34
+ nvidia-cusolver-cu12==11.7.1.2
35
+ nvidia-cusparse-cu12==12.5.4.2
36
+ nvidia-cusparselt-cu12==0.6.3
37
+ nvidia-nccl-cu12==2.26.2
38
+ nvidia-nvjitlink-cu12==12.6.85
39
+ nvidia-nvtx-cu12==12.6.77
40
+ orjson==3.10.18
41
+ packaging==25.0
42
+ pandas==2.2.3
43
+ pillow==11.2.1
44
+ psutil==5.9.8
45
+ pydantic==2.11.4
46
+ pydantic-core==2.33.2
47
+ pydub==0.25.1
48
+ pygments==2.19.1
49
+ python-dateutil==2.9.0.post0
50
+ python-multipart==0.0.20
51
+ pytz==2025.2
52
+ pyyaml==6.0.2
53
+ regex==2024.11.6
54
+ requests==2.32.3
55
+ rich==14.0.0
56
+ ruff==0.11.10
57
+ safehttpx==0.1.6
58
+ safetensors==0.5.3
59
+ semantic-version==2.10.0
60
+ setuptools==80.7.1
61
+ shellingham==1.5.4
62
+ six==1.17.0
63
+ sniffio==1.3.1
64
+ spaces==0.36.0
65
+ starlette==0.46.2
66
+ sympy==1.14.0
67
+ tokenizers==0.21.1
68
+ tomlkit==0.13.2
69
+ torch==2.7.0
70
+ torchvision==0.22.0
71
+ tqdm==4.67.1
72
+ transformers==4.51.3
73
+ triton==3.3.0
74
+ typer==0.15.4
75
+ typing-extensions==4.13.2
76
+ typing-inspection==0.4.0
77
+ tzdata==2025.2
78
+ urllib3==2.4.0
79
+ uvicorn==0.34.2
80
+ websockets==15.0.1