ZivK committed
Commit 1381033 · 1 Parent(s): 66df394

added model file

Files changed (1)
  1. model.py +90 -0
model.py ADDED
@@ -0,0 +1,90 @@
+import pytorch_lightning as pl
+import torch
+from peft import LoraConfig, get_peft_model
+from torch import nn as nn
+from torchmetrics import Accuracy
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+base_checkpoint = "HuggingFaceTB/SmolLM2-360M"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+criterion = nn.BCEWithLogitsLoss()
+
+
+class SmolLM(pl.LightningModule):
+    def __init__(self, learning_rate=3e-4):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.criterion = criterion
+        self.tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.base_model = AutoModelForCausalLM.from_pretrained(base_checkpoint).to(device)
+        self.base_model.lm_head = nn.Identity()
+        self.classifier = nn.Sequential(
+            nn.Linear(960, 128),
+            nn.ReLU(),
+            nn.Linear(128, 1),
+        )
+        # Freeze smollm2 parameters
+        for param in self.base_model.parameters():
+            param.requires_grad = False
+        # LoRA fine-tuning
+        lora_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
+            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_dropout=0.0,
+            bias="none",
+            use_dora=True,
+        )
+        self.base_model = get_peft_model(self.base_model, lora_config)
+        self.save_hyperparameters()
+        self.val_accuracy = Accuracy(task="binary")
+
+    def forward(self, x):
+        input_ids = x["input_ids"]
+        attention_mask = x["attention_mask"]
+
+        out = self.base_model(input_ids, attention_mask=attention_mask)
+        logits = out.logits  # shape: (batch_size, seq_len, hidden_dim); hidden states, since lm_head is nn.Identity()
+
+        # Calculate the index of the last non-padding token for each sequence
+        last_token_indices = attention_mask.sum(dim=1) - 1
+        real_batch_size = logits.size(0)
+        batch_indices = torch.arange(real_batch_size, device=device)
+
+        # Select logits corresponding to the last non-padding token
+        last_logits = logits[batch_indices, last_token_indices, :]
+
+        output_logits = self.classifier(last_logits)
+        return output_logits.squeeze(-1)
+
+    def training_step(self, batch, batch_idx):
+        sentences = batch["sentence"]
+        labels = batch["eos_label"].to(device)
+        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
+        logits = self(inputs)
+        loss = self.criterion(logits, labels)
+        self.log('Train Step Loss', loss, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        sentences = batch["sentence"]
+        labels = batch["eos_label"].to(device)
+        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
+        logits = self(inputs)
+        loss = self.criterion(logits, labels)
+        preds = (torch.sigmoid(logits) > 0.5).long()
+        self.val_accuracy.update(preds, labels.long())
+        self.log('Validation Step Loss', loss, prog_bar=True)
+        return loss
+
+    def on_validation_epoch_end(self):
+        acc = self.val_accuracy.compute()
+        self.log('Validation Accuracy', acc, prog_bar=True)
+        self.val_accuracy.reset()
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+        return optimizer
+
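
A minimal usage sketch (not part of this commit): the ToySentences dataset, its example sentences, and the Trainer settings below are illustrative assumptions; the only contract taken from model.py is that each batch provides strings under "sentence" and float labels under "eos_label".

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset

from model import SmolLM  # the LightningModule added in this commit


class ToySentences(Dataset):
    # Illustrative in-memory dataset; real training data would come from elsewhere.
    def __init__(self):
        self.items = [
            {"sentence": "The cat sat on the mat.", "eos_label": 1.0},
            {"sentence": "The cat sat on the", "eos_label": 0.0},
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        # Default collation yields a list of strings for "sentence" and a stacked
        # float tensor for "eos_label", matching training_step/validation_step.
        return {"sentence": item["sentence"], "eos_label": torch.tensor(item["eos_label"])}


train_loader = DataLoader(ToySentences(), batch_size=2, shuffle=True)
val_loader = DataLoader(ToySentences(), batch_size=2)

model = SmolLM(learning_rate=3e-4)
trainer = pl.Trainer(max_epochs=1, accelerator="auto", log_every_n_steps=1)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)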