Upload modeling_internvl_chat.py
Browse files- modeling_internvl_chat.py +42 -0
modeling_internvl_chat.py
CHANGED
@@ -16,6 +16,7 @@ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
|
|
16 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
from transformers.utils import ModelOutput, logging
|
|
|
19 |
|
20 |
from .configuration_internvl_chat import InternVLChatConfig
|
21 |
from .conversation import get_conv_template
|
@@ -71,6 +72,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
71 |
self.language_model = InternLM2ForCausalLM(config.llm_config)
|
72 |
else:
|
73 |
raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
|
|
|
74 |
|
75 |
vit_hidden_size = config.vision_config.hidden_size
|
76 |
llm_hidden_size = config.llm_config.hidden_size
|
@@ -85,6 +87,46 @@ class InternVLChatModel(PreTrainedModel):
|
|
85 |
self.img_context_token_id = None
|
86 |
self.conv_template = get_conv_template(self.template)
|
87 |
self.system_message = self.conv_template.system_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
def forward(
|
90 |
self,
|
|
|
16 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
from transformers.utils import ModelOutput, logging
|
19 |
+
from peft import LoraConfig, get_peft_model
|
20 |
|
21 |
from .configuration_internvl_chat import InternVLChatConfig
|
22 |
from .conversation import get_conv_template
|
|
|
72 |
self.language_model = InternLM2ForCausalLM(config.llm_config)
|
73 |
else:
|
74 |
raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
|
75 |
+
self.llm_arch_name = config.llm_config.architectures[0]
|
76 |
|
77 |
vit_hidden_size = config.vision_config.hidden_size
|
78 |
llm_hidden_size = config.llm_config.hidden_size
|
|
|
87 |
self.img_context_token_id = None
|
88 |
self.conv_template = get_conv_template(self.template)
|
89 |
self.system_message = self.conv_template.system_message
|
90 |
+
self.img_context_token_id = None
|
91 |
+
|
92 |
+
if config.use_backbone_lora:
|
93 |
+
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
94 |
+
|
95 |
+
if config.use_llm_lora:
|
96 |
+
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
97 |
+
|
98 |
+
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
    """Wrap the vision backbone (`self.vision_model`) with LoRA adapters.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: dropout probability applied inside the LoRA layers.

    NOTE(review): assumes the ViT submodules are named `attn.qkv`,
    `attn.proj`, `mlp.fc1`, `mlp.fc2` — confirm against the vision model.
    """
    vit_targets = ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
    backbone_lora_cfg = LoraConfig(
        r=r,
        target_modules=vit_targets,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    peft_backbone = get_peft_model(self.vision_model, backbone_lora_cfg)
    # Log how many parameters the adapters made trainable.
    peft_backbone.print_trainable_parameters()
    self.vision_model = peft_backbone
|
107 |
+
|
108 |
+
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
    """Wrap the language model (`self.language_model`) with LoRA adapters.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: dropout probability applied inside the LoRA layers.

    Raises:
        NotImplementedError: if `self.llm_arch_name` has no known
            target-module mapping.
    """
    # Determine the target modules based on the architecture of the language model
    if self.llm_arch_name == 'InternLM2ForCausalLM':
        target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
    elif self.llm_arch_name == 'Phi3ForCausalLM':
        target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
    elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
        target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                          'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
    else:
        # BUG FIX: the original `raise NotImplemented` raises a TypeError at
        # runtime (NotImplemented is a sentinel value, not an exception class).
        # Raise NotImplementedError with a message, matching the constructor's
        # existing error style for unsupported architectures.
        raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')
    lora_config = LoraConfig(
        r=r,
        target_modules=target_modules,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        task_type='CAUSAL_LM'
    )
    self.language_model = get_peft_model(self.language_model, lora_config)
    # Required so gradients flow into the (frozen) embedding inputs under PEFT.
    self.language_model.enable_input_require_grads()
    # Log how many parameters the adapters made trainable.
    self.language_model.print_trainable_parameters()
|
129 |
+
|
130 |
|
131 |
def forward(
|
132 |
self,
|