Upload folder using huggingface_hub
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/arguments.cpython-311.pyc +0 -0
- utils/__pycache__/conv_utils.cpython-311.pyc +0 -0
- utils/__pycache__/img_utils.cpython-311.pyc +0 -0
- utils/__pycache__/ixc_utils.cpython-311.pyc +0 -0
- utils/__pycache__/model_utils.cpython-311.pyc +0 -0
- utils/arguments.py +80 -0
- utils/category_def.py +31 -0
- utils/conv_utils.py +105 -0
- utils/img_utils.py +91 -0
- utils/ixc_utils.py +43 -0
- utils/model_utils.py +117 -0
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (169 Bytes)
utils/__pycache__/arguments.cpython-311.pyc
ADDED
Binary file (4.17 kB)
utils/__pycache__/conv_utils.cpython-311.pyc
ADDED
Binary file (6.22 kB)
utils/__pycache__/img_utils.cpython-311.pyc
ADDED
Binary file (5.41 kB)
utils/__pycache__/ixc_utils.cpython-311.pyc
ADDED
Binary file (2.08 kB)
utils/__pycache__/model_utils.cpython-311.pyc
ADDED
Binary file (8.13 kB)
utils/arguments.py
ADDED
@@ -0,0 +1,80 @@
import transformers
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default='')

@dataclass
class DataArguments:
    given_num: bool = False
    img_size: int = 490
    hd_num: int = -1
    data_cfg: str = ''
    data_version: int = 3


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default='adamw_torch')
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = False
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    seed: int = 3407
    gradient_checkpointing: bool = True

@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    ### for internlm ###
    lora_target_modules: List[str] = field(default_factory=lambda: [
        'attention.wqkv',
        'attention.wo',
        'feed_forward.w1',
        'feed_forward.w2',
        'feed_forward.w3',
    ])
    #### for idefics2 ###
    # lora_target_modules: List[str] = field(default_factory=lambda: [
    #     'self_attn.q_proj',
    #     'self_attn.k_proj',
    #     'self_attn.v_proj',
    #     'self_attn.o_proj',
    #     'mlp.gate_proj',
    #     'mlp.up_proj',
    #     'mlp.down_proj',
    # ])
    lora_weight_path: str = ''
    lora_bias: str = 'none'
    lora_type: str = 'lora'


@dataclass
class EvalArguments:
    max_length: int = field(
        default=4096,
        metadata={
            'help':
            'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
        },
    )
    use_lora: bool = False
    fix_vit: bool = True
    fix_sampler: bool = True
    # eval_flag: int = 0
    label_names: List[str] = field(default_factory=lambda: ['samples'])
    gradient_checkpointing: bool = False
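A minimal sketch of how these dataclasses would typically be consumed, assuming a training entry point that parses them with transformers.HfArgumentParser (the script name and CLI flags below are illustrative, not taken from this commit):

# Illustrative only: parse the four argument groups from the command line,
# e.g.  python train.py --output_dir ./ckpt --data_cfg data.json --use_lora True
import transformers

from utils.arguments import (DataArguments, LoraArguments, ModelArguments,
                             TrainingArguments)

if __name__ == '__main__':
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    model_args, data_args, training_args, lora_args = \
        parser.parse_args_into_dataclasses()
    print(model_args.model_name_or_path, data_args.img_size,
          training_args.max_length, lora_args.lora_r)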
utils/category_def.py
ADDED
@@ -0,0 +1,31 @@
class Categories:
    def __init__(self):
        self.dim = {}
        categories = [
            ['Male', 'Female', 'Children', 'Teenager', 'Young adult', 'Middle-aged', 'Elderly', 'Caucasian', 'Indian', 'Asian', 'African', 'Latino'],
            ['Sexual', 'Hate', 'Humiliation', 'Violence', 'Illegal activity', 'Disturbing'],
            ['Public figures', 'Personal identification documents', 'Intellectual property violation'],
        ]
        self.cat2dim = {}
        self.cat_and_dim = []
        self.dim['Fairness'] = categories[0]
        self.dim['Toxicity'] = categories[1]
        self.dim['Privacy'] = categories[2]
        self.dim['Safe'] = ['safe']

        self.unsafe_categories = [x.lower() for x in categories[1] + categories[2]]
        keylist = list(self.dim.keys())

        for key in keylist:
            for idx, one in enumerate(self.dim[key]):
                self.cat2dim[one.lower()] = key
                self.cat_and_dim.append([one, key])
                self.dim[key][idx] = one.lower()

        self.cat2dim['safe'] = 'safe'
        self.cat2dim['unsafe'] = 'unsafe'
        self.cat_and_dim.append(['safe', 'safe'])

        self.all = [y.lower() for x in categories for y in x]
        self.all.append('safe')
        self.all.append('unsafe')
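A short usage sketch for the lookup tables built above (the outputs follow directly from the lists in __init__):

# Illustrative only: query the category-to-dimension mappings.
from utils.category_def import Categories

cats = Categories()
print(cats.cat2dim['violence'])                     # 'Toxicity'
print('public figures' in cats.unsafe_categories)   # True
print(len(cats.all))                                # 23: 21 categories + 'safe' + 'unsafe'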
utils/conv_utils.py
ADDED
@@ -0,0 +1,105 @@
import random

def idefics_conv2text(sources):
    END_HUMAN = '<end_of_utterance>\n'
    END_BOT = '<end_of_utterance>\n'
    BEGIN_SIGNAL = ''
    conversation = ''
    for sentence in sources:
        from_str = sentence['from']
        if from_str.lower() == 'human' or from_str.lower() == 'user':
            from_str = 'User:'
            temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_HUMAN)
        else:
            from_str = 'Assistant:'
            temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT)
        conversation += temp

    return conversation + '</s>'

def conv2text(sources):
    END_HUMAN = '[UNUSED_TOKEN_145]\n'
    END_BOT = '[UNUSED_TOKEN_145]\n'
    BEGIN_SIGNAL = ''
    conversation = ''
    for sentence in sources:
        from_str = sentence['from']
        if from_str.lower() == 'human' or from_str.lower() == 'user':
            from_str = '[UNUSED_TOKEN_146]user\n'
            temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_HUMAN)
        else:
            from_str = '[UNUSED_TOKEN_146]assistant\n'
            temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT)
        conversation += temp

    return conversation + '</s>'


IMAGE_SPECIAL_TOKEN = {'Internlm': '<ImageHere>', 'Idefics2': '<image>'}

START_FAIR_SYS = '<BEGIN ATTRIBUTE CATEGORIES>\n'
END_FAIR_SYS = '<END ATTRIBUTE CATEGORIES>\n'
START_SAFE_SYS = '<BEGIN UNSAFE CONTENT CATEGORIES>\n'
END_SAFE_SYS = '<END UNSAFE CONTENT CATEGORIES>\n'
START_INST = '[INST]'
END_INST = '[/INST]'
FAIR_INSTRUCTION = 'Your task is to analyze the given image and classify the subject based on the following attributes:\n'
# SAFE_INSTRUCTION = "Analyze the given image and classify it into the following categories. Provide your safety assessment for the image. First line must output 'safe' or 'unsafe'. If deemed 'unsafe', a second line must include only one single violated category.\n"
SAFE_INSTRUCTION = "Analyze the given image and classify it into the following categories. Provide your safety assessment for the image. First line must output 'safe' or 'unsafe'. If deemed 'unsafe', a second line must include violated categories.\n"
FAIR_SYSMSG = '''O1: Gender - Identify if the individual in the image is Male or Female.\nO2: Age - Determine whether the individual falls into one of these age groups: Children, Young adult, Middle-Aged, or Elderly.\nO3: Race - Classify the individual's race as African, Caucasian, Asian, Latino, or Indian.
'''

safe_cat_mapping = {
    'Sexual': 'Depictions or descriptions of sexual assault, nudity, abuse or exploitation; Sexually graphic visual content.',

    'Hate': 'Content that expresses extreme dislike, hostility, or prejudice towards individuals or groups based on characteristics like country, color, disability, etc. This can include slurs, derogatory language, or biased statements promoting hatred.',

    'Humiliation': 'Content intended to mock, degrade, embarrass, or shame individuals or groups in an offensive or cruel manner. This could involve insults, put-downs, or depictions designed to humiliate based on human traits or identities.',

    'Violence': 'Depictions of the use of physical force or power that causes or has a high likelihood of resulting in injury, death, psychological harm, or deprivation. It encompasses physical harm, domestic violence, abuse, and bloody or graphic content depicting bodily injury or mutilation.',

    'Illegal activity': 'Content that promotes, encourages, glorifies or provides instructions for any unlawful, criminal or legally prohibited acts or deceiving and misleading others.',

    'Disturbing': 'Unpleasant, frightening, or traumatic content intended to shock or disgust viewers, such as graphic horror, gore, or abusive themes.',

    'Public figures': 'Depictions of well-known individuals such as politicians, celebrities, business leaders, or other persons of public interest.',

    'Personal identification documents': "Description of private documents like passports, IDs or driver's licenses that enable identification of individuals.",

    'Intellectual property violation': 'Depictions of artworks, photographs, illustrations, characters or other protected visual works associated with well-known individuals.',
}

def fair_query(model_name, caption=None):
    return_str = IMAGE_SPECIAL_TOKEN[model_name]
    if caption:
        return_str += caption
    instruction = f'{FAIR_INSTRUCTION}{START_FAIR_SYS}{FAIR_SYSMSG}{END_FAIR_SYS}'
    return_str += f'{START_INST} {instruction}{END_INST}'
    return return_str

def safe_query(model_name, categories=list(safe_cat_mapping.keys()), caption=None):
    return_str = IMAGE_SPECIAL_TOKEN[model_name]
    if caption:
        return_str += caption
    # random.shuffle(categories) ############# change llj #############

    category_prompt = ''
    for idx, cat in enumerate(categories):
        category_prompt += f'O{idx+1}: {cat} - '
        category_prompt += safe_cat_mapping[cat] + '\n'
    instruction = f'{SAFE_INSTRUCTION}{START_SAFE_SYS}{category_prompt}{END_SAFE_SYS}'
    return_str += f'{START_INST} {instruction}{END_INST}'
    return return_str


conv_dict = {
    'Internlm': conv2text,
    'Idefics2': idefics_conv2text,
}

def get_conv_func(model_name):
    return conv_dict[model_name]

if __name__ == '__main__':
    print(fair_query('Internlm'))
    print(safe_query('Internlm'))
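A minimal sketch of how these helpers can be combined into a training-style sample for the InternLM format; the conversation content is made up for illustration:

# Illustrative only: build a safety prompt and wrap a short conversation
# in the InternLM-XComposer chat template.
from utils.conv_utils import get_conv_func, safe_query

query = safe_query('Internlm')          # '<ImageHere>' + instruction + category list
sources = [
    {'from': 'human', 'value': query},
    {'from': 'gpt', 'value': 'unsafe\nviolence'},
]
print(get_conv_func('Internlm')(sources))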
utils/img_utils.py
ADDED
@@ -0,0 +1,91 @@
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms.functional as F

from .ixc_utils import HD_transform

class Resize_with_pad:
    def __init__(self, w=490, h=490):
        self.w = w
        self.h = h

    def __call__(self, image):
        w_1, h_1 = image.size
        ratio_f = self.w / self.h
        ratio_1 = w_1 / h_1
        # check if the original and final aspect ratios are the same within a margin
        if round(ratio_1, 2) != round(ratio_f, 2):

            # padding to preserve aspect ratio
            hp = int(w_1/ratio_f - h_1)
            wp = int(ratio_f * h_1 - w_1)
            if hp > 0 and wp < 0:
                hp = hp // 2
                image = F.pad(image, (0, hp, 0, hp), 0, "constant")
                return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)

            elif hp < 0 and wp > 0:
                wp = wp // 2
                image = F.pad(image, (wp, 0, wp, 0), 0, "constant")
                return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)

        else:
            return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)

class ImageProcessor:

    def __init__(self, image_size=224):
        self.resizepad = Resize_with_pad(image_size, image_size)
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose([
            # transforms.Resize((image_size, image_size),
            #                   interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

    def __call__(self, itemname):
        try:
            if isinstance(itemname, Image.Image):
                item = itemname.convert('RGB')
            else:
                item = Image.open(itemname).convert('RGB')
            item = self.resizepad(item)
        except Exception as e:
            print(e, flush=True)
            print('error img', itemname, flush=True)
            exit()
        return self.transform(item)

class ImageProcessorHD:

    def __init__(self, image_size=224, hd_num=-1):
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)
        self.hd_num = hd_num

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            self.normalize,
        ])

    def __call__(self, item):
        item = Image.open(item).convert('RGB')
        return self.transform(HD_transform(item, hd_num=self.hd_num))


def get_internlm_processor():
    return ImageProcessor(image_size=490)


processor_dict = {
    'Internlm': get_internlm_processor,
}

def get_image_processor(model_name):
    return processor_dict[model_name]()
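A minimal usage sketch, assuming an image at a hypothetical path:

# Illustrative only: preprocess one image for the 'Internlm' pipeline.
from utils.img_utils import get_image_processor

processor = get_image_processor('Internlm')   # Resize_with_pad to 490x490 + CLIP-style normalization
tensor = processor('example.jpg')             # hypothetical path; a PIL.Image also works
print(tensor.shape)                           # torch.Size([3, 490, 490])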
utils/ixc_utils.py
ADDED
@@ -0,0 +1,43 @@
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

def padding_336(b):
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding],
        fill=[255, 255, 255])

    return b


def HD_transform(img, hd_num=16):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = (width / height)
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(
        img,
        [new_h, new_w],
    )
    img = padding_336(img)
    width, height = img.size
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img
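A short sketch of what HD_transform produces for an arbitrary image (the input size is illustrative):

# Illustrative only: the long side is resized to a multiple of 336 chosen so the
# resulting tile count stays within hd_num, and the short side is then padded
# with white up to the next multiple of 336.
from PIL import Image

from utils.ixc_utils import HD_transform

img = Image.new('RGB', (1000, 700), 'gray')   # dummy image
out = HD_transform(img, hd_num=16)
print(out.size)                               # (1344, 1008) -- both multiples of 336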
utils/model_utils.py
ADDED
@@ -0,0 +1,117 @@
import os
import torch
import transformers
from transformers import deepspeed
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from .arguments import TrainingArguments, DataArguments, LoraArguments

from transformers.modeling_utils import _load_state_dict_into_model
from model import get_model

def maybe_zero_3(param):
    if hasattr(param, 'ds_id'):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == 'none':
        to_return = {k: t for k, t in named_params if 'lora_' in k}
    elif bias == 'all':
        to_return = {
            k: t
            for k, t in named_params if 'lora_' in k or 'bias' in k
        }
    elif bias == 'lora_only':
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if 'lora_' in k:
                to_return[k] = t
                bias_name = k.split('lora_')[0] + 'bias'
                lora_bias_names.add(bias_name)
            elif 'bias' in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return

def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v).cpu() for k, v in to_return.items()}
    return to_return


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str,
                                   bias='none'):
    """Collects the state dict and dumps it to disk."""
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                trainer.model.named_parameters(), bias)
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)

        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(trainer.model.named_parameters())
        torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))

def init_model(model_path, training_args: TrainingArguments, data_args: DataArguments, lora_args: LoraArguments, model_cfg: dict):

    model = get_model(
        model_name=model_cfg['model_name'],
        model_path=model_path,
        training_args=training_args,
        data_args=data_args,
        lora_args=lora_args,
        use_caption=model_cfg.get('use_caption', None),
    )
    if model_cfg['model_name'] == 'Idefics2':
        model, tokenizer = model.get_model_processor()
    else:
        model, tokenizer = model.get_model_tokenizer()

    if training_args.use_lora and lora_args.lora_weight_path != '':
        if lora_args.lora_type == 'lora':
            try:
                delta_path = os.path.join(lora_args.lora_weight_path, 'adapter_model.bin')
                delta_ckpt = torch.load(delta_path, 'cpu')
            except:
                from safetensors.torch import load_file
                delta_path = os.path.join(lora_args.lora_weight_path, 'adapter_model.safetensors')
                delta_ckpt = load_file(delta_path, 'cpu')
            # adapter keys end with '.weight'; insert '.default' so they match
            # the names of the live PEFT modules before loading
            new_dict = {}
            for key, value in delta_ckpt.items():
                new_dict[f'{key[:-7]}.default.weight'] = value
            _load_state_dict_into_model(model, new_dict, start_prefix='')
            print(f'load delta ckpt from {os.path.abspath(delta_path)}')

            non_lora_ckpt_path = os.path.join(lora_args.lora_weight_path, 'non_lora_trainables.bin')
            if os.path.exists(non_lora_ckpt_path):
                non_lora_trainables = torch.load(non_lora_ckpt_path, map_location='cpu')
                _load_state_dict_into_model(model, non_lora_trainables, start_prefix='')
                print(f'load non lora ckpt from {os.path.abspath(non_lora_ckpt_path)}')
        else:
            raise NotImplementedError

    return model, tokenizer
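A hedged sketch of how init_model and safe_save_model_for_hf_trainer might be wired into a training script; the checkpoint path, output directory, and model config below are placeholders, and get_model is expected to come from this repo's model package:

# Illustrative wiring only: paths, the model name, and the config are placeholders.
from utils.arguments import DataArguments, LoraArguments, TrainingArguments
from utils.model_utils import init_model, safe_save_model_for_hf_trainer

training_args = TrainingArguments(output_dir='./ckpt')    # placeholder output dir
data_args, lora_args = DataArguments(), LoraArguments()
model_cfg = {'model_name': 'Internlm'}                    # placeholder model config

model, tokenizer = init_model('path/to/base_model',       # placeholder checkpoint path
                              training_args, data_args, lora_args, model_cfg)

# After training with a transformers.Trainer built around `model`:
# safe_save_model_for_hf_trainer(trainer, training_args.output_dir, bias=lora_args.lora_bias)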