File size: 9,989 Bytes
61ddbf5 06b1bd7 61ddbf5 06b1bd7 61ddbf5 06b1bd7 61ddbf5 06b1bd7 f5ebd10 06b1bd7 f5ebd10 06b1bd7 61ddbf5 06b1bd7 61ddbf5 06b1bd7 61ddbf5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import torch
import torch.nn as nn
from types import MethodType
from typing import List, Optional, Tuple, Union
from copy import deepcopy
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
class HangulGemmaDeobfuscatorConfig(PretrainedConfig):
    """Configuration for ``HangulGemmaDeobfuscator``.

    Args:
        base_model_name: model identifier that ``HangulGemmaDeobfuscator``
            passes to ``AutoConfig.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained`` when it is built.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "hangul_gemma_deobfuscator"

    def __init__(self, base_model_name: str = "unsloth/gemma-2-2b", **kwargs):
        super().__init__(**kwargs)
        # Read by HangulGemmaDeobfuscator.__init__ to locate the base model.
        self.base_model_name = base_model_name
class HangulGemmaDeobfuscator(PreTrainedModel):
    """Hangul text deobfuscator built on a causal LM (by default ``unsloth/gemma-2-2b``).

    ``__init__`` patches the loaded base model: decoder layers 24 and 25 are
    duplicated, even-indexed layers are flagged ``is_sliding``, and every
    decoder layer's ``forward`` is replaced by the module-level
    ``decoder_forward``. Deobfuscation runs a character stage
    (``pred_char_ids``) and, in ``_deobfuscate``, a following jamo stage
    (``pred_jamo_ids``).

    ``load_hangul_tokenizer`` must be called before ``forward`` or any
    ``deobfuscate`` helper, since they read ``self.tokenizer`` and the id
    tables it creates.
    """

    config_class = HangulGemmaDeobfuscatorConfig

    def __init__(self, config):
        """Load ``config.base_model_name`` and patch its decoder layers.

        Args:
            config: a ``HangulGemmaDeobfuscatorConfig``.
        """
        super().__init__(config)
        self.base_model_config = AutoConfig.from_pretrained(config.base_model_name)
        # NOTE(review): ``training`` is set on the config object, not the
        # module; confirm which code reads it.
        self.base_model_config.training = True
        # NOTE(review): presumably forced to eager attention so the custom
        # boolean 4D masks built in ``forward`` / ``decoder_forward`` are
        # consumed as-is — confirm against the transformers version in use.
        self.base_model_config._attn_implementation = 'eager'
        # Presumably surfaces as ``layer.sliding_window``, which
        # ``decoder_forward`` uses as the half-width of its band mask.
        self.base_model_config.sliding_window = 12
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name, config=self.base_model_config
        )

        # Each listed layer gets a deep copy inserted immediately *before* it.
        # The resulting order defines checkpoint parameter names, so it must
        # stay stable.
        duplicated_layer_indices = {24, 25}
        new_layers = []
        for i, layer in enumerate(base_model.model.layers):
            if i in duplicated_layer_indices:
                new_layers.append(deepcopy(layer))
            new_layers.append(layer)
        base_model.model.layers = nn.ModuleList(new_layers)
        base_model.config.num_hidden_layers = len(base_model.model.layers)

        # Positions are re-counted after the insertion: even positions become
        # sliding-window layers, and every layer uses ``decoder_forward``.
        for layer_idx, layer in enumerate(base_model.model.layers):
            layer.is_sliding = layer_idx % 2 == 0
            layer.forward = MethodType(decoder_forward, layer)
        self.model = base_model

    def load_hangul_tokenizer(self, hangul_tokenizer):
        """Attach the Hangul tokenizer and cache its id tables as frozen tensors.

        The tables are created on the model's current device so that the
        index operations in ``pred_jamo_ids`` / ``pred_char_ids`` do not mix
        devices when the model has already been moved (e.g. to a GPU).

        Args:
            hangul_tokenizer: object exposing ``base_tokenizer`` and the integer
                id tables ``cho_ids``, ``joong_ids``, ``jong_ids``, ``char_1ids``
                and ``char_3ids`` (the last two are 2-D: they are indexed with
                ``[:, k]`` in ``pred_char_ids``).
        """
        self.tokenizer = hangul_tokenizer
        device = self.device

        def frozen_ids(values):
            # Kept as non-trainable Parameters (not buffers) to preserve the
            # original parameters()/state_dict layout.
            return nn.Parameter(torch.LongTensor(values).to(device), requires_grad=False)

        self.cho_ids = frozen_ids(self.tokenizer.cho_ids)
        self.joong_ids = frozen_ids(self.tokenizer.joong_ids)
        self.jong_ids = frozen_ids(self.tokenizer.jong_ids)
        self.char_1ids = frozen_ids(self.tokenizer.char_1ids)
        self.char_3ids = frozen_ids(self.tokenizer.char_3ids)

    def forward(self, input_ids, attention_mask, token_type_ids, output_ids=None):
        """Run the patched base model with a BOS token prepended.

        Args:
            input_ids: 2-D ``(batch, seq)`` token ids.
            attention_mask: 2-D ``(batch, seq)`` mask; the prepended BOS
                column is filled with ones (1 = attend).
            token_type_ids: accepted for interface compatibility; unused here.
            output_ids: optional targets, flattened to match ``logits``; when
                given, a mean cross-entropy loss is computed.

        Returns:
            ``(loss, logits)``. ``loss`` is ``None`` without ``output_ids``.
            The logit for the final (post-BOS-shift) position is dropped, so
            ``logits`` has one position per input token.
        """
        bos_column = torch.full(
            (input_ids.size(0), 1),
            self.tokenizer.base_tokenizer.bos_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        input_ids = torch.cat([bos_column, input_ids], dim=1)
        keep_bos = torch.ones(
            (attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([keep_bos, attention_mask], dim=1)
        # Expand to a 4-D additive mask, then turn it into a boolean mask that
        # is True where the additive mask is 0 (presumably "may attend").
        # No causal triangle is applied, so attention is bidirectional.
        attention_mask = _prepare_4d_attention_mask(attention_mask, self.model.dtype)
        attention_mask = attention_mask == 0
        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).logits[:, :-1, :]
        loss = None
        if output_ids is not None:
            loss = nn.CrossEntropyLoss(reduction='mean')(
                logits.reshape(-1, self.model.config.vocab_size),
                output_ids.reshape(-1),
            )
        return loss, logits

    def pred_jamo_ids(
        self,
        logits,
        input_ids,
        token_type_ids,
    ):
        """Overwrite jamo positions with the best-scoring id from their table.

        Positions with ``token_type_ids`` 1, 2 and 3 are restricted to
        ``cho_ids``, ``joong_ids`` and ``jong_ids`` respectively; the argmax
        logit within that table is chosen. Other positions keep their input id.

        Returns:
            A copy of ``input_ids`` with those positions replaced.
        """
        pred_ids = input_ids.clone()
        for type_id, candidate_ids in ((1, self.cho_ids), (2, self.joong_ids), (3, self.jong_ids)):
            positions = token_type_ids == type_id
            best = logits[positions][:, candidate_ids].argmax(1)
            pred_ids[positions] = candidate_ids[best]
        return pred_ids

    def pred_char_ids(
        self,
        logits,
        input_ids,
        token_type_ids
    ):
        """Overwrite character positions (``token_type_ids == 4``) with predictions.

        Type-4 positions are handled in consecutive groups of three. For
        each group:

        - If the best first-token logit among ``char_1ids`` beats the best
          among ``char_3ids``, that ``char_1ids`` row is used.
        - Otherwise every ``char_3ids`` row is scored by summing its
          per-position log-probabilities, normalized over the candidate
          rows, and the best row is used.

        Returns:
            A copy of ``input_ids`` with the type-4 positions replaced (or an
            unchanged copy when there are none).
        """
        pred_ids = input_ids.clone()
        char_logits = logits[token_type_ids == 4]
        if not len(char_logits):
            return pred_ids
        chosen_rows = []
        for group in char_logits.split(3):
            single_scores = group[0][self.char_1ids[:, 0]]
            if single_scores.max() > group[0][self.char_3ids[:, 0]].max():
                chosen_rows.append(self.char_1ids[single_scores.argmax()])
            else:
                triple_scores = torch.stack(
                    [group[k][self.char_3ids[:, k]] for k in range(3)], 1
                )  # (num_candidates, 3)
                # Normalize over candidates (dim 0) at each position, then sum
                # over the three positions (dim 1). Normalizing and summing over
                # the same dim would make the score ignore a constant shift of a
                # candidate's logits, ranking weak candidates like strong ones.
                best = triple_scores.log_softmax(0).sum(1).argmax()
                chosen_rows.append(self.char_3ids[best])
        # torch.cat keeps everything on the id tables' device; .to() then
        # matches pred_ids' dtype and device.
        pred_ids[token_type_ids == 4] = torch.cat(chosen_rows).to(pred_ids)
        return pred_ids

    @torch.no_grad()
    def _predict_chars(self, sentence):
        """Run the character stage on one sentence.

        Returns:
            ``(char_input_ids, char_attention_mask, char_preds, char_token_types)``.
            The first two are tensors on ``self.device``; the last two are
            nested Python lists, presumably with one row for the single sentence.
        """
        char_input_ids, char_attention_mask, char_token_type_ids = self.tokenizer.batch_encode_char([sentence])
        char_input_ids = char_input_ids.to(self.device)
        char_attention_mask = char_attention_mask.to(self.device)
        char_token_type_ids = char_token_type_ids.to(self.device)
        _, logits_char = self(char_input_ids, char_attention_mask, char_token_type_ids)
        char_preds = self.pred_char_ids(logits_char, char_input_ids, char_token_type_ids)
        return (
            char_input_ids,
            char_attention_mask,
            char_preds.detach().to('cpu').tolist(),
            char_token_type_ids.detach().to('cpu').tolist(),
        )

    @torch.no_grad()
    def _deobfuscate_by_syllable(self, sentence):
        """Deobfuscate ``sentence`` using only the character stage."""
        _, _, char_preds, char_token_types = self._predict_chars(sentence)
        return self.tokenizer.decode_char(char_preds[0], char_token_types[0])

    @torch.no_grad()
    def _deobfuscate(self, sentence):
        """Deobfuscate ``sentence`` with the character stage followed by the jamo stage."""
        char_input_ids, char_attention_mask, char_preds, char_token_types = self._predict_chars(sentence)
        jamo_input_ids, jamo_attention_mask, jamo_token_type_ids = (
            self.tokenizer.batch_encode_jamo_from_char_encoded(char_preds, char_token_types)
        )
        # Match dtype *and* exact device of the char-stage tensors. Token-type
        # ids take the input-id dtype so they stay integral for the == tests
        # in pred_jamo_ids.
        jamo_input_ids = jamo_input_ids.to(char_input_ids)
        jamo_attention_mask = jamo_attention_mask.to(char_attention_mask)
        jamo_token_type_ids = jamo_token_type_ids.to(char_input_ids)
        _, logits_jamo = self(jamo_input_ids, jamo_attention_mask, jamo_token_type_ids)
        jamo_preds = self.pred_jamo_ids(logits_jamo, jamo_input_ids, jamo_token_type_ids)
        jamo_preds = jamo_preds.detach().to('cpu').tolist()
        return self.tokenizer.decode_jamo(jamo_preds[0], jamo_token_type_ids[0].tolist())

    def deobfuscate(self, sentence, sentence_tokenizer=None):
        """Deobfuscate ``sentence``, optionally chunk by chunk.

        Args:
            sentence: input text.
            sentence_tokenizer: optional chunker exposing ``tokenize``,
                ``overlap`` and ``decode_overlap``. When given, each
                overlapping chunk is deobfuscated separately and the results
                are merged by ``decode_overlap``.

        Returns:
            The deobfuscated text.
        """
        if sentence_tokenizer is None:
            return self._deobfuscate(sentence)
        chunks_row = sentence_tokenizer.tokenize(sentence)
        chunks_overlap_row = sentence_tokenizer.overlap(chunks_row)
        row = [
            (start_idx, end_idx, self._deobfuscate(chunk))
            for start_idx, end_idx, chunk in chunks_overlap_row
        ]
        return sentence_tokenizer.decode_overlap(row)
def decoder_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    last_cache_position: int = 0,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Replacement ``forward`` bound onto each decoder layer with ``MethodType``.

    Computes, in order:

    1. ``input_layernorm`` -> ``self_attn`` -> ``post_attention_layernorm``,
       then a residual add.
    2. ``pre_feedforward_layernorm`` -> ``mlp`` -> ``post_feedforward_layernorm``,
       then a residual add.

    On layers whose ``is_sliding`` flag is true, the mask (when given) keeps
    only a symmetric band of half-width ``self.sliding_window`` around the
    diagonal of its last two dims; entries outside the band are zeroed. For
    the boolean masks built in ``HangulGemmaDeobfuscator.forward``, a zero
    (``False``) presumably means "do not attend" — confirm for other callers.

    ``past_key_value``, ``use_cache`` and ``cache_position`` are only
    forwarded to ``self_attn``. ``last_cache_position`` is accepted for
    signature compatibility and otherwise ignored.

    Returns:
        ``(hidden_states,)``, with the attention weights appended when
        ``output_attentions`` is true.
    """
    if self.is_sliding and attention_mask is not None:
        half_width = self.sliding_window
        attention_mask = torch.triu(
            torch.tril(attention_mask, diagonal=half_width), diagonal=-half_width
        )

    # Attention sub-block with post-norm and residual connection.
    attn_input = self.input_layernorm(hidden_states)
    attn_output, attn_weights = self.self_attn(
        hidden_states=attn_input,
        position_embeddings=position_embeddings,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = hidden_states + self.post_attention_layernorm(attn_output)

    # Feed-forward sub-block with post-norm and residual connection.
    mlp_output = self.post_feedforward_layernorm(
        self.mlp(self.pre_feedforward_layernorm(hidden_states))
    )
    hidden_states = hidden_states + mlp_output

    if output_attentions:
        return (hidden_states, attn_weights)
    return (hidden_states,)
|