import os
from typing import List, Optional

import torch
from torch import nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    OPTForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)


def kmp_preprocess(pattern):
    # Build the KMP prefix/suffix (failure) table for `pattern`.
    pattern_len = len(pattern)
    prefix_suffix = [0] * pattern_len
    j = 0
    for i in range(1, pattern_len):
        while j > 0 and pattern[i] != pattern[j]:
            j = prefix_suffix[j - 1]
        if pattern[i] == pattern[j]:
            j += 1
        prefix_suffix[i] = j
    return prefix_suffix


def kmp_search(text, pattern):
    # Return the start indices of every occurrence of `pattern` in `text`.
    text_len = len(text)
    pattern_len = len(pattern)
    prefix_suffix = kmp_preprocess(pattern)
    matches = []
    j = 0
    for i in range(text_len):
        while j > 0 and text[i] != pattern[j]:
            j = prefix_suffix[j - 1]
        if text[i] == pattern[j]:
            j += 1
        if j == pattern_len:
            matches.append(i - j + 1)
            j = prefix_suffix[j - 1]
    return matches


class ModelWrapper:
    def __init__(self, model):
        self.model = model

    def __getattr__(self, name):
        return getattr(self.model, name)

    @torch.no_grad()
    def __call__(self, pixel_values):
        return self.model(pixel_values)

    def eval(self):
        pass

    def train(self):
        pass

    def parameters(self):
        return self.model.parameters()


class CrelloModelConfig(PretrainedConfig):
    def __init__(
        self,
        old_vocab_size: int = 32000,
        vocab_size: int = 32000,
        pad_token_id: int = 2,
        ignore_ids: List[int] = [],
        freeze_lm: bool = True,  # lm.eval()
        opt_version: str = 'facebook/opt-6.7b',
        task: str = 'captioning',
        use_lora: bool = False,
        lora_alpha: int = 32,
        lora_r: int = 8,
        lora_dropout: float = 0.05,
        lora_target_modules: str = r'.*\.(q_proj|v_proj)',
        hidden_size: int = -1,
        load_in_4bit: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert old_vocab_size > 0, 'old_vocab_size must be positive'
        assert vocab_size > 0, 'vocab_size must be positive'
        self.old_vocab_size = old_vocab_size
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.freeze_lm = freeze_lm
        self.opt_version = opt_version
        self.task = task
        self.use_lora = use_lora
        self.lora_alpha = lora_alpha
        self.lora_r = lora_r
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules
        self.hidden_size = hidden_size
        self.load_in_4bit = load_in_4bit
        self.ignore_ids = ignore_ids


class CrelloModel(PreTrainedModel):
    config_class = CrelloModelConfig
    supports_gradient_checkpointing = True

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.lm.gradient_checkpointing_enable()

    def __init__(self, config: CrelloModelConfig):  # explicitly declare the config type
        super().__init__(config)
        use_auth_token = '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'

        self.pad_token_id = config.pad_token_id
        self.args = config

        opt_version = config.opt_version
        print(f"Using {opt_version} for the language model.")

        if 'facebook/opt' in opt_version:
            self.lm = OPTForCausalLM.from_pretrained(opt_version)
            word_embed_proj_dim = self.lm.config.word_embed_proj_dim
        else:
            if config.load_in_4bit:
                print("\n would load_in_4bit")
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=config.load_in_4bit
                )
                # This means: fit the entire model on the GPU given by LOCAL_RANK.
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                device_map = {"": local_rank}
                torch_dtype = torch.bfloat16
            else:
                print("\n wouldn't load_in_4bit")
                quantization_config = None
                device_map = None
                torch_dtype = None
            self.lm = AutoModelForCausalLM.from_pretrained(
                "WYBar/LLM_For_Layout_Planning",
                subfolder="Meta-Llama-3-8B",
                # use_auth_token=use_auth_token,
                # quantization_config=quantization_config,
                # device_map=device_map,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
            )
            # self.lm = AutoModelForCausalLM.from_pretrained(
            #     opt_version,
            #     use_auth_token=use_auth_token,
            #     quantization_config=quantization_config,
            #     device_map=device_map,
            #     trust_remote_code=True,
            #     # attn_implementation="flash_attention_2",
            #     # flash_attn=True,
            #     # flash_rotary=True,
            #     # fused_dense=True,
            #     torch_dtype=torch.bfloat16,
            # )
            word_embed_proj_dim = self.lm.config.hidden_size
            self.config.hidden_size = self.lm.config.hidden_size
        self.opt_version = opt_version

        if self.args.freeze_lm:
            self.lm.eval()
            print("Freezing the LM.")
            for param in self.lm.parameters():
                param.requires_grad = False
        else:
            print("\n no freeze lm, so to train lm")
            self.lm.train()
            self.lm.config.gradient_checkpointing = True

        print('resize token embeddings to match the tokenizer', config.vocab_size)
        self.lm.resize_token_embeddings(config.vocab_size)
        self.input_embeddings = self.lm.get_input_embeddings()
        print('after token embeddings to match the tokenizer', config.vocab_size)

    def train(self, mode=True):
        super().train(mode=mode)
        # Override train() to ensure frozen models remain frozen.
        if self.args.freeze_lm:
            self.lm.eval()

    def forward(
        self,
        labels: torch.LongTensor,
    ):
        print("inside Crello")
        batch_size = labels.shape[0]
        full_labels = labels.detach().clone()

        input_embs = self.input_embeddings(labels)  # (N, T, D)
        input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()

        for ignore_id in self.config.ignore_ids:
            full_labels[full_labels == ignore_id] = -100

        pad_idx = []
        # For each sequence in the batch, record its effective length
        # (max_len or the first padding position) in pad_idx.
        # -100 is the ignore index for cross entropy loss.
        # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        for label in full_labels:
            for k, token in enumerate(label):
                # Mask out pad tokens if they exist.
                if token in [self.pad_token_id]:
                    label[k:] = -100  # mask out this token and everything after it
                    pad_idx.append(k)
                    break
                if k == len(label) - 1:  # No padding found.
                    pad_idx.append(k + 1)
        assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

        print("inside Crello, lm1")
        output = self.lm(
            inputs_embeds=input_embs,
            # input_ids=labels,
            labels=full_labels,
            output_hidden_states=True)
        print("inside Crello, lm2")

        return output, full_labels, input_embs_norm


if __name__ == "__main__":
    config = CrelloModelConfig(
        vocab_size=50265,
        image_reg_token=50264,
        image_gt_token=50263,
    )
    print("config: ", config)

    model1 = CrelloModel(config)
    print("\nmodel1: ", model1)
    model1.save_pretrained('test')

    model2 = CrelloModel.from_pretrained('test')
    print("\nmodel2: ", model2)

    # Compare model1 and model2.
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()
    assert set(state_dict1.keys()) == set(state_dict2.keys())
    for k in state_dict1.keys():
        assert torch.equal(state_dict1[k], state_dict2[k])
    print('all parameters are equal')