update
- bubogpt/models/mm_gpt4.py +4 -0
- grounding_model.py +2 -0
- tagging_model.py +2 -0
bubogpt/models/mm_gpt4.py
CHANGED
@@ -78,10 +78,12 @@ class MMGPT4(BaseModel):
 
         self.low_resource = low_resource
 
+        import gc
         print('Loading ImageBind')
         self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind,
                                                  with_head=with_bind_head, use_blip_vision=use_blip_vision)
         print('Loading ImageBind Done')
+        gc.collect()
 
         print(f'Loading LLAMA from {llama_model}')
         self.llama_tokenizer = LlamaTokenizer.from_pretrained('magicr/vicuna-7b', use_fast=False, use_auth_token=True)
@@ -94,12 +96,14 @@ class MMGPT4(BaseModel):
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
+        gc.collect()
 
         print('Loading Q-Former and Adapter/Projector')
         self.multimodal_joiner = ImageBindJoiner(joiner_cfg, output_dim=self.llama_model.config.hidden_size)
         if use_blip_vision:
             replace_joiner_vision(self.multimodal_joiner, q_former_model, proj_model)
         print('Loading Q-Former and Adapter/Projector Done')
+        gc.collect()
 
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
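The commit applies one pattern throughout: call gc.collect() immediately after each heavy loading stage, so temporaries left over from one stage are reclaimed before the next stage allocates its weights. A minimal runnable sketch of that pattern, with bytearray buffers standing in for the real loaders (imagebind_huge, the LLAMA loader, ImageBindJoiner), which are not reproduced here:

import gc

def load_encoder():
    # Stand-in for imagebind_huge(...); a real loader also creates
    # short-lived temporaries (state dicts, buffers) along the way.
    return bytearray(100 * 1024 * 1024)

def load_llm():
    # Stand-in for the LLAMA / LlamaTokenizer loading stage.
    return bytearray(100 * 1024 * 1024)

models = []
for load_stage in (load_encoder, load_llm):
    models.append(load_stage())
    # Reclaim unreachable temporaries (including reference cycles, which
    # plain reference counting misses) before the next stage allocates.
    gc.collect()

On a GPU-backed Space, torch.cuda.empty_cache() after moving each model to the device would play a similar role for the CUDA caching allocator, though the commit itself only adds the host-side collections.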
grounding_model.py
CHANGED
@@ -17,11 +17,13 @@ from groundingdino.util.utils import clean_state_dict
 
 
 def load_groundingdino_model(model_config_path, model_checkpoint_path):
+    import gc
     args = CN.load_cfg(open(model_config_path, "r"))
     model = build_groundingdino(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
     print('loading GroundingDINO:', load_res)
+    gc.collect()
     _ = model.eval()
     return model
 
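One caveat: gc.collect() only reclaims objects that are no longer reachable. At the point where it runs in load_groundingdino_model, the checkpoint dict — a full CPU copy of the weights — is still bound to a local name, so it survives the collection. A possible variant, not part of this commit, that drops the reference first (reusing the names already imported in grounding_model.py: CN, build_groundingdino, clean_state_dict):

def load_groundingdino_model(model_config_path, model_checkpoint_path):
    import gc
    args = CN.load_cfg(open(model_config_path, "r"))
    model = build_groundingdino(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print('loading GroundingDINO:', load_res)
    # load_state_dict copies values into the model's own parameters, so the
    # checkpoint tensors are redundant afterwards. Deleting the last
    # reference lets refcounting free them; gc.collect() then sweeps any
    # remaining cycles.
    del checkpoint
    gc.collect()
    _ = model.eval()
    return model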
tagging_model.py
CHANGED
@@ -8,6 +8,7 @@ from ram.models import ram
 class TaggingModule(nn.Module):
     def __init__(self, device='cpu'):
         super().__init__()
+        import gc
         self.device = device
         image_size = 384
         self.transform = transforms.Compose([
@@ -23,6 +24,7 @@ class TaggingModule(nn.Module):
             vit='swin_l'
         ).eval().to(device)
         print('==> Tagging Module Loaded.')
+        gc.collect()
 
     @torch.no_grad()
     def forward(self, original_image):
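To confirm the added collections actually lower the Space's peak footprint, resident memory can be sampled around each loading stage. A small self-contained check (assumes psutil is installed; it is not a dependency of this repo):

import gc
import os

import psutil  # assumed available: pip install psutil

def rss_mb():
    # Resident set size of the current process, in MiB.
    return psutil.Process(os.getpid()).memory_info().rss / (1024 ** 2)

print(f'baseline: {rss_mb():.0f} MiB')
blob = bytearray(200 * 1024 * 1024)  # stand-in for loading one model stage
print(f'loaded:   {rss_mb():.0f} MiB')
# Drop the only reference (refcounting frees it), then sweep any cycles.
del blob
gc.collect()
print(f'after gc: {rss_mb():.0f} MiB')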