ayeshaishaq committed
Commit 28c10a3 · verified · 1 Parent(s): 7c12211

Upload modeling_internvl_chat.py

Files changed (1): modeling_internvl_chat.py (+42 −0)
modeling_internvl_chat.py CHANGED

@@ -16,6 +16,7 @@ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging
+from peft import LoraConfig, get_peft_model

 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template

@@ -71,6 +72,7 @@ class InternVLChatModel(PreTrainedModel):
             self.language_model = InternLM2ForCausalLM(config.llm_config)
         else:
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
+        self.llm_arch_name = config.llm_config.architectures[0]

         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.llm_config.hidden_size

@@ -85,6 +87,46 @@ class InternVLChatModel(PreTrainedModel):
         self.img_context_token_id = None
         self.conv_template = get_conv_template(self.template)
         self.system_message = self.conv_template.system_message
+        self.img_context_token_id = None
+
+        if config.use_backbone_lora:
+            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
+
+        if config.use_llm_lora:
+            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
+
+    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
+        self.vision_model = get_peft_model(self.vision_model, lora_config)
+        self.vision_model.print_trainable_parameters()
+
+    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        # Determine the target modules based on the architecture of the language model
+        if self.llm_arch_name == 'InternLM2ForCausalLM':
+            target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
+        elif self.llm_arch_name == 'Phi3ForCausalLM':
+            target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
+        elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
+            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
+                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
+        else:
+            raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            task_type='CAUSAL_LM'
+        )
+        self.language_model = get_peft_model(self.language_model, lora_config)
+        self.language_model.enable_input_require_grads()
+        self.language_model.print_trainable_parameters()
+

     def forward(
         self,
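
For reference, a minimal usage sketch of the hooks this commit adds. It assumes a checkpoint whose InternVLChatConfig carries the integer use_backbone_lora / use_llm_lora fields the constructor reads: 0 leaves the model unwrapped, and any nonzero value is used as the LoRA rank, with lora_alpha fixed at twice the rank. The model_path below is a hypothetical placeholder, not a path from this repository.

# Sketch only: driving the LoRA hooks added in this commit via the config.
# model_path is a hypothetical placeholder for a local or Hub checkpoint that
# ships this modeling_internvl_chat.py as remote code.
from transformers import AutoConfig, AutoModel

model_path = 'path/to/internvl-chat-checkpoint'  # hypothetical

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.use_backbone_lora = 16  # wrap the ViT backbone: LoRA rank 16, alpha 32
config.use_llm_lora = 16       # wrap the language model: LoRA rank 16, alpha 32

model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Alternatively, wrap after loading with explicit hyperparameters:
# model.wrap_llm_lora(r=8, lora_alpha=16, lora_dropout=0.05)

Both paths end in get_peft_model, so only the injected low-rank matrices are trainable; print_trainable_parameters(), called inside each wrap_*_lora, reports the resulting parameter counts.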