import torch
import torch.nn as nn

from types import MethodType
from typing import List, Optional, Tuple, Union
from copy import deepcopy

from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

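# Configuration wrapper: records only which Gemma-2 checkpoint the deobfuscator is built on.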
class HangulGemmaDeobfuscatorConfig(PretrainedConfig):
    model_type = "hangul_gemma_deobfuscator"

    def __init__(
        self,
        base_model_name='unsloth/gemma-2-2b',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name

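# Deobfuscator model: wraps a Gemma-2 causal LM, duplicates two decoder layers to deepen
# the stack, marks alternating layers as sliding-window, and binds the custom
# `decoder_forward` defined at the bottom of this file to every layer.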
class HangulGemmaDeobfuscator(PreTrainedModel):
    config_class = HangulGemmaDeobfuscatorConfig

    def __init__(self, config):
        super().__init__(config)
        self.base_model_config = AutoConfig.from_pretrained(config.base_model_name)
        self.base_model_config.training = True
        self.base_model_config._attn_implementation = 'eager'
        self.base_model_config.sliding_window = 12
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name, config=self.base_model_config)

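        # Deepen the decoder stack: a deepcopy of layers 24 and 25 is inserted directly
        # before each original layer.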
        new_layers = []
        layer_indices = [24, 25]
        for i in range(len(base_model.model.layers)):
            if i in layer_indices:
                new_layers.append(deepcopy(base_model.model.layers[i]))
            new_layers.append(base_model.model.layers[i])
        new_layers = nn.ModuleList(new_layers)
        base_model.model.layers = new_layers
        base_model.config.num_hidden_layers = len(base_model.model.layers)

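        # Duplication shifted the layer indices, so re-mark even-indexed layers as
        # sliding-window, mirroring Gemma-2's alternating sliding/global attention pattern.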
        for layer_idx, layer in enumerate(base_model.model.layers):
            layer.is_sliding = not bool(layer_idx % 2)

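        # Replace every layer's forward with the patched `decoder_forward` below.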
        for layer in base_model.model.layers:
            layer.forward = MethodType(decoder_forward, layer)

        self.model = base_model

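    # Attach the Hangul tokenizer and register its jamo/character id tables as frozen
    # parameters so they move with the model across devices.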
    def load_hangul_tokenizer(self, hangul_tokenizer):
        self.tokenizer = hangul_tokenizer
        self.cho_ids = nn.Parameter(torch.LongTensor(self.tokenizer.cho_ids), requires_grad=False)
        self.joong_ids = nn.Parameter(torch.LongTensor(self.tokenizer.joong_ids), requires_grad=False)
        self.jong_ids = nn.Parameter(torch.LongTensor(self.tokenizer.jong_ids), requires_grad=False)
        self.char_1ids = nn.Parameter(torch.LongTensor(self.tokenizer.char_1ids), requires_grad=False)
        self.char_3ids = nn.Parameter(torch.LongTensor(self.tokenizer.char_3ids), requires_grad=False)

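    # Prepend BOS, expand the padding mask to a non-causal 4D boolean mask (True = attend),
    # run the patched Gemma-2, and drop the last position so the logits align one-to-one
    # with the input tokens. token_type_ids is accepted for interface symmetry but is not
    # used here. Returns (loss, logits); loss is only computed when output_ids is provided.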
    def forward(self, input_ids, attention_mask, token_type_ids, output_ids=None):
        input_ids = torch.cat([
            torch.full((input_ids.size(0), 1), self.tokenizer.base_tokenizer.bos_token_id, dtype=input_ids.dtype, device=input_ids.device),
            input_ids,
        ], dim=1)

        attention_mask = torch.cat([
            torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask,
        ], dim=1)
        attention_mask = _prepare_4d_attention_mask(attention_mask, self.model.dtype)
        attention_mask = attention_mask == 0

        logits = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False
        ).logits[:, :-1, :]

        loss = None
        if output_ids is not None:
            loss = nn.CrossEntropyLoss(reduction='mean')(
                logits.reshape(-1, self.model.config.vocab_size),
                output_ids.reshape(-1),
            )
        return loss, logits

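    # For each position, restrict the vocabulary to the admissible jamo set for its token
    # type (1 = initial/cho, 2 = medial/joong, 3 = final/jong) and take the argmax within it.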
    def pred_jamo_ids(
        self,
        logits,
        input_ids,
        token_type_ids,
    ):
        pred_ids = input_ids.clone()
        logits_cho = logits[token_type_ids == 1][:, self.cho_ids]
        logits_joong = logits[token_type_ids == 2][:, self.joong_ids]
        logits_jong = logits[token_type_ids == 3][:, self.jong_ids]
        pred_cho_ids = self.cho_ids[logits_cho.argmax(1)]
        pred_joong_ids = self.joong_ids[logits_joong.argmax(1)]
        pred_jong_ids = self.jong_ids[logits_jong.argmax(1)]
        pred_ids[token_type_ids == 1] = pred_cho_ids
        pred_ids[token_type_ids == 2] = pred_joong_ids
        pred_ids[token_type_ids == 3] = pred_jong_ids
        return pred_ids

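    # Token-type-4 positions come in groups of three slots per character. For each group,
    # the best first-slot logit under char_1ids is compared against the best under
    # char_3ids; the winner is decoded either from the first slot alone (char_1ids) or
    # from a combined log-softmax score over all three slots (char_3ids).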
    def pred_char_ids(
        self,
        logits,
        input_ids,
        token_type_ids
    ):
        pred_ids = input_ids.clone()
        logits_char = logits[token_type_ids == 4]
        if not len(logits_char):
            return pred_ids
        logits_char_chunks = logits_char.split(3)
        pred_char_ids = []
        for logits_char_chunk in logits_char_chunks:
            if logits_char_chunk[0][self.char_1ids[:, 0]].max() > logits_char_chunk[0][self.char_3ids[:, 0]].max():
                pred_char_ids.extend(self.char_1ids[logits_char_chunk[0][self.char_1ids[:, 0]].argmax()])
            else:
                logits_char_3ids = torch.stack([
                    logits_char_chunk[0][self.char_3ids[:, 0]],
                    logits_char_chunk[1][self.char_3ids[:, 1]],
                    logits_char_chunk[2][self.char_3ids[:, 2]]
                ], 1)
                pred_char_ids.extend(self.char_3ids[logits_char_3ids.log_softmax(-1).sum(1).argmax()])
        pred_ids[token_type_ids == 4] = torch.LongTensor(pred_char_ids).type_as(pred_ids)
        return pred_ids

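    # Single-pass decoding: encode the sentence at character level, predict character ids,
    # and decode straight back to text.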
    def _deobfuscate_by_syllable(self, sentence):
        sentences = [sentence]
        char_input_ids, char_attention_mask, char_token_type_ids = self.tokenizer.batch_encode_char(sentences)
        char_input_ids, char_attention_mask, char_token_type_ids = char_input_ids.to(self.device), char_attention_mask.to(self.device), char_token_type_ids.to(self.device)
        _, logits_char = self(char_input_ids, char_attention_mask, char_token_type_ids)
        pred_char_ids = self.pred_char_ids(logits_char, char_input_ids, char_token_type_ids)
        pred_char_ids = pred_char_ids.detach().to('cpu').tolist()
        char_token_type_ids = char_token_type_ids.detach().to('cpu').tolist()
        decoded = self.tokenizer.decode_char(pred_char_ids[0], char_token_type_ids[0])
        return decoded

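    # Two-pass decoding: predict at the character level first, then re-encode those
    # predictions as jamo and refine the initial/medial/final jamo before decoding.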
    def _deobfuscate(self, sentence):
        sentences = [sentence]
        char_input_ids, char_attention_mask, char_token_type_ids = self.tokenizer.batch_encode_char(sentences)
        char_input_ids, char_attention_mask, char_token_type_ids = char_input_ids.to(self.device), char_attention_mask.to(self.device), char_token_type_ids.to(self.device)
        _, logits_char = self(char_input_ids, char_attention_mask, char_token_type_ids)
        pred_char_ids = self.pred_char_ids(logits_char, char_input_ids, char_token_type_ids)
        pred_char_ids = pred_char_ids.detach().to('cpu').tolist()
        char_token_type_ids = char_token_type_ids.detach().to('cpu').tolist()

        jamo_input_ids, jamo_attention_mask, jamo_token_type_ids = self.tokenizer.batch_encode_jamo_from_char_encoded(pred_char_ids, char_token_type_ids)
        # Match the jamo tensors to the device/dtype of their character-level counterparts;
        # token_type_ids is cast against char_input_ids so that it also lands on the model device.
        jamo_input_ids, jamo_attention_mask, jamo_token_type_ids = jamo_input_ids.type_as(char_input_ids), jamo_attention_mask.type_as(char_attention_mask), jamo_token_type_ids.type_as(char_input_ids)
        _, logits_jamo = self(jamo_input_ids, jamo_attention_mask, jamo_token_type_ids)
        pred_jamo_ids = self.pred_jamo_ids(logits_jamo, jamo_input_ids, jamo_token_type_ids)
        pred_jamo_ids = pred_jamo_ids.detach().to('cpu').tolist()
        y_pred = [self.tokenizer.decode_jamo(pred_jamo_id, jamo_token_type_id) for pred_jamo_id, jamo_token_type_id in zip(pred_jamo_ids, jamo_token_type_ids.tolist())]
        return y_pred[0]

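    # Public entry point. With a sentence_tokenizer, long inputs are split into overlapping
    # chunks that are deobfuscated independently and merged back together; otherwise the
    # whole sentence is processed in one shot.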
    def deobfuscate(self, sentence, sentence_tokenizer=None):
        if sentence_tokenizer is not None:
            chunks_row = sentence_tokenizer.tokenize(sentence)
            chunks_overlap_row = sentence_tokenizer.overlap(chunks_row)
            row = []
            for start_idx, end_idx, chunk_overlap_row in chunks_overlap_row:
                row.append((start_idx, end_idx, self._deobfuscate(chunk_overlap_row)))
            return sentence_tokenizer.decode_overlap(row)
        else:
            return self._deobfuscate(sentence)

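# Replacement for the Gemma-2 decoder layer forward, bound to each layer in __init__.
# It expects the boolean 4D mask built in HangulGemmaDeobfuscator.forward; sliding-window
# layers additionally restrict attention to a +/- sliding_window band around the diagonal.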
def decoder_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    last_cache_position: int = 0,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    if self.is_sliding and attention_mask is not None:
        attention_mask = torch.tril(torch.triu(attention_mask, diagonal=-self.sliding_window), diagonal=self.sliding_window)

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    hidden_states, self_attn_weights = self.self_attn(
        hidden_states=hidden_states,
        position_embeddings=position_embeddings,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.pre_feedforward_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = self.post_feedforward_layernorm(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    return outputs
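# Example usage (sketch): `hangul_tokenizer` and `obfuscated_sentence` are assumptions not
# defined in this file; the tokenizer must expose the cho/joong/jong and char id tables plus
# the batch_encode_*/decode_* methods used above.
#
#   config = HangulGemmaDeobfuscatorConfig(base_model_name='unsloth/gemma-2-2b')
#   model = HangulGemmaDeobfuscator(config).to('cuda').eval()
#   model.load_hangul_tokenizer(hangul_tokenizer)
#   with torch.no_grad():
#       restored = model.deobfuscate(obfuscated_sentence)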