adwardlee committed
Commit 2e1316e · verified · 1 Parent(s): 11b9c8d

Upload folder using huggingface_hub
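For reference, a commit of this kind can be produced with huggingface_hub's `upload_folder`. A minimal sketch follows; the repo id is a placeholder, since the target repository is not shown on this page:

# Sketch: upload a local folder in one commit via huggingface_hub.
# The repo_id below is hypothetical, not the repo this commit belongs to.
from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` by default
api.upload_folder(
    folder_path="utils",            # local directory to upload
    path_in_repo="utils",           # destination path inside the repo
    repo_id="your-username/your-repo",
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)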

utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (169 Bytes)
utils/__pycache__/arguments.cpython-311.pyc ADDED
Binary file (4.17 kB)
utils/__pycache__/conv_utils.cpython-311.pyc ADDED
Binary file (6.22 kB)
utils/__pycache__/img_utils.cpython-311.pyc ADDED
Binary file (5.41 kB)
utils/__pycache__/ixc_utils.cpython-311.pyc ADDED
Binary file (2.08 kB)
utils/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (8.13 kB)
utils/arguments.py ADDED
@@ -0,0 +1,80 @@
+ import transformers
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default='')
+
+ @dataclass
+ class DataArguments:
+     given_num: bool = False
+     img_size: int = 490
+     hd_num: int = -1
+     data_cfg: str = ''
+     data_version: int = 3
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default='adamw_torch')
+     max_length: int = field(
+         default=4096,
+         metadata={
+             'help':
+             'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
+         },
+     )
+     use_lora: bool = False
+     fix_vit: bool = True
+     fix_sampler: bool = False
+     # eval_flag: int = 0
+     label_names: List[str] = field(default_factory=lambda: ['samples'])
+     seed: int = 3407
+     gradient_checkpointing: bool = True
+
+ @dataclass
+ class LoraArguments:
+     lora_r: int = 64
+     lora_alpha: int = 64
+     lora_dropout: float = 0.05
+     ### for internlm ###
+     lora_target_modules: List[str] = field(default_factory=lambda: [
+         'attention.wqkv',
+         'attention.wo',
+         'feed_forward.w1',
+         'feed_forward.w2',
+         'feed_forward.w3',
+     ])
+     #### for idefics2 ###
+     # lora_target_modules: List[str] = field(default_factory=lambda: [
+     #     'self_attn.q_proj',
+     #     'self_attn.k_proj',
+     #     'self_attn.v_proj',
+     #     'self_attn.o_proj',
+     #     'mlp.gate_proj',
+     #     'mlp.up_proj',
+     #     'mlp.down_proj',
+     # ])
+     lora_weight_path: str = ''
+     lora_bias: str = 'none'
+     lora_type: str = 'lora'
+
+
+ @dataclass
+ class EvalArguments:
+     max_length: int = field(
+         default=4096,
+         metadata={
+             'help':
+             'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
+         },
+     )
+     use_lora: bool = False
+     fix_vit: bool = True
+     fix_sampler: bool = True
+     # eval_flag: int = 0
+     label_names: List[str] = field(default_factory=lambda: ['samples'])
+     gradient_checkpointing: bool = False
utils/category_def.py ADDED
@@ -0,0 +1,31 @@
+ class Categories:
+     def __init__(self):
+         self.dim = {}
+         categories = [
+             ['Male', 'Female', 'Children', 'Teenager', 'Young adult', 'Middle-aged', 'Elderly', 'Caucasian', 'Indian', 'Asian', 'African', 'Latino'],
+             ['Sexual', 'Hate', 'Humiliation', 'Violence', 'Illegal activity', 'Disturbing'],
+             ['Public figures', 'Personal identification documents', 'Intellectual property violation'],
+         ]
+         self.cat2dim = {}
+         self.cat_and_dim = []
+         self.dim['Fairness'] = categories[0]
+         self.dim['Toxicity'] = categories[1]
+         self.dim['Privacy'] = categories[2]
+         self.dim['Safe'] = ['safe']
+
+         self.unsafe_categories = [x.lower() for x in categories[1] + categories[2]]
+         keylist = list(self.dim.keys())
+
+         for key in keylist:
+             for idx, one in enumerate(self.dim[key]):
+                 self.cat2dim[one.lower()] = key
+                 self.cat_and_dim.append([one, key])
+                 self.dim[key][idx] = one.lower()
+
+         self.cat2dim['safe'] = 'safe'
+         self.cat2dim['unsafe'] = 'unsafe'
+         self.cat_and_dim.append(['safe', 'safe'])
+
+         self.all = [y.lower() for x in categories for y in x]
+         self.all.append('safe')
+         self.all.append('unsafe')
utils/conv_utils.py ADDED
@@ -0,0 +1,105 @@
+ import random
+
+ def idefics_conv2text(sources):
+     END_HUMAN = '<end_of_utterance>\n'
+     END_BOT = '<end_of_utterance>\n'
+     BEGIN_SIGNAL = ''
+     conversation = ''
+     for sentence in sources:
+         from_str = sentence['from']
+         if from_str.lower() == 'human' or from_str.lower() == 'user':
+             from_str = 'User:'
+             temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_HUMAN)
+         else:
+             from_str = 'Assistant:'
+             temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT)
+         conversation += temp
+
+     return conversation + '</s>'
+
+ def conv2text(sources):
+     END_HUMAN = '[UNUSED_TOKEN_145]\n'
+     END_BOT = '[UNUSED_TOKEN_145]\n'
+     BEGIN_SIGNAL = ''
+     conversation = ''
+     for sentence in sources:
+         from_str = sentence['from']
+         if from_str.lower() == 'human' or from_str.lower() == 'user':
+             from_str = '[UNUSED_TOKEN_146]user\n'
+             temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_HUMAN)
+         else:
+             from_str = '[UNUSED_TOKEN_146]assistant\n'
+             temp = (BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT)
+         conversation += temp
+
+     return conversation + '</s>'
+
+
+ IMAGE_SPECIAL_TOKEN = {'Internlm': '<ImageHere>', 'Idefics2': '<image>'}
+
+ START_FAIR_SYS = '<BEGIN ATTRIBUTE CATEGORIES>\n'
+ END_FAIR_SYS = '<END ATTRIBUTE CATEGORIES>\n'
+ START_SAFE_SYS = '<BEGIN UNSAFE CONTENT CATEGORIES>\n'
+ END_SAFE_SYS = '<END UNSAFE CONTENT CATEGORIES>\n'
+ START_INST = '[INST]'
+ END_INST = '[/INST]'
+ FAIR_INSTRUCTION = 'Your task is to analyze the given image and classify the subject based on the following attributes:\n'
+ # SAFE_INSTRUCTION = "Analyze the given image and classify it into the following categories. Provide your safety assessment for the image. First line must output 'safe' or 'unsafe'. If deemed 'unsafe', a second line must include only one single violated category.\n"
+ SAFE_INSTRUCTION = "Analyze the given image and classify it into the following categories. Provide your safety assessment for the image. First line must output 'safe' or 'unsafe'. If deemed 'unsafe', a second line must include violated categories.\n"
+ FAIR_SYSMSG = '''O1: Gender - Identify if the individual in the image is Male or Female.\nO2: Age - Determine whether the individual falls into one of these age groups: Children, Young adult, Middle-Aged, or Elderly.\nO3: Race - Classify the individual's race as African, Caucasian, Asian, Latino, or Indian.
+ '''
+
+ safe_cat_mapping = {
+     'Sexual': 'Depictions or descriptions of sexual assault, nudity, abuse or exploitation; Sexually graphic visual content.',
+
+     'Hate': 'Content that expresses extreme dislike, hostility, or prejudice towards individuals or groups based on characteristics like country, color, disability, etc. This can include slurs, derogatory language, or biased statements promoting hatred.',
+
+     'Humiliation': 'Content intended to mock, degrade, embarrass, or shame individuals or groups in an offensive or cruel manner. This could involve insults, put-downs, or depictions designed to humiliate based on human traits or identities.',
+
+     'Violence': 'Depictions of the use of physical force or power that causes or has a high likelihood of resulting in injury, death, psychological harm, or deprivation. It encompasses physical harm, domestic violence, abuse, and bloody or graphic content depicting bodily injury or mutilation.',
+
+     'Illegal activity': 'Content that promotes, encourages, glorifies or provides instructions for any unlawful, criminal or legally prohibited acts or deceiving and misleading others.',
+
+     'Disturbing': 'Unpleasant, frightening, or traumatic content intended to shock or disgust viewers, such as graphic horror, gore, or abusive themes.',
+
+     'Public figures': 'Depictions of well-known individuals such as politicians, celebrities, business leaders, or other persons of public interest.',
+
+     'Personal identification documents': "Description of private documents like passports, IDs or driver's licenses that enable identification of individuals.",
+
+     'Intellectual property violation': 'Depictions of artworks, photographs, illustrations, characters or other protected visual works associated with well-known individuals.',
+ }
+
+ def fair_query(model_name, caption=None):
+     return_str = IMAGE_SPECIAL_TOKEN[model_name]
+     if caption:
+         return_str += caption
+     instruction = f'{FAIR_INSTRUCTION}{START_FAIR_SYS}{FAIR_SYSMSG}{END_FAIR_SYS}'
+     return_str += f'{START_INST} {instruction}{END_INST}'
+     return return_str
+
+ def safe_query(model_name, categories=list(safe_cat_mapping.keys()), caption=None):
+     return_str = IMAGE_SPECIAL_TOKEN[model_name]
+     if caption:
+         return_str += caption
+     # random.shuffle(categories) ############# change llj #############
+
+     category_prompt = ''
+     for idx, cat in enumerate(categories):
+         category_prompt += f'O{idx+1}: {cat} - '
+         category_prompt += safe_cat_mapping[cat] + '\n'
+     instruction = f'{SAFE_INSTRUCTION}{START_SAFE_SYS}{category_prompt}{END_SAFE_SYS}'
+     return_str += f'{START_INST} {instruction}{END_INST}'
+     return return_str
+
+
+ conv_dict = {
+     'Internlm': conv2text,
+     'Idefics2': idefics_conv2text,
+ }
+
+ def get_conv_func(model_name):
+     return conv_dict[model_name]
+
+ if __name__ == '__main__':
+     print(fair_query('Internlm'))
+     print(safe_query('Internlm'))
utils/img_utils.py ADDED
@@ -0,0 +1,91 @@
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+ import torchvision.transforms.functional as F
+
+ from .ixc_utils import HD_transform
+
+ class Resize_with_pad:
+     def __init__(self, w=490, h=490):
+         self.w = w
+         self.h = h
+
+     def __call__(self, image):
+         w_1, h_1 = image.size
+         ratio_f = self.w / self.h
+         ratio_1 = w_1 / h_1
+         # check if the original and final aspect ratios are the same within a margin
+         if round(ratio_1, 2) != round(ratio_f, 2):
+
+             # padding to preserve aspect ratio
+             hp = int(w_1/ratio_f - h_1)
+             wp = int(ratio_f * h_1 - w_1)
+             if hp > 0 and wp < 0:
+                 hp = hp // 2
+                 image = F.pad(image, (0, hp, 0, hp), 0, "constant")
+                 return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)
+
+             elif hp < 0 and wp > 0:
+                 wp = wp // 2
+                 image = F.pad(image, (wp, 0, wp, 0), 0, "constant")
+                 return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)
+
+         else:
+             return F.resize(image, [self.h, self.w], interpolation=InterpolationMode.BICUBIC)
+
+ class ImageProcessor:
+
+     def __init__(self, image_size=224):
+         self.resizepad = Resize_with_pad(image_size, image_size)
+         mean = (0.48145466, 0.4578275, 0.40821073)
+         std = (0.26862954, 0.26130258, 0.27577711)
+         self.normalize = transforms.Normalize(mean, std)
+
+         self.transform = transforms.Compose([
+             # transforms.Resize((image_size, image_size),
+             #                   interpolation=InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             self.normalize,
+         ])
+
+     def __call__(self, itemname):
+         try:
+             if isinstance(itemname, Image.Image):
+                 item = itemname.convert('RGB')
+             else:
+                 item = Image.open(itemname).convert('RGB')
+             item = self.resizepad(item)
+         except Exception as e:
+             print(e, flush=True)
+             print('error img', itemname, flush=True)
+             exit()
+         return self.transform(item)
+
+ class ImageProcessorHD:
+
+     def __init__(self, image_size=224, hd_num=-1):
+         mean = (0.48145466, 0.4578275, 0.40821073)
+         std = (0.26862954, 0.26130258, 0.27577711)
+         self.normalize = transforms.Normalize(mean, std)
+         self.hd_num = hd_num
+
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             self.normalize,
+         ])
+
+     def __call__(self, item):
+         item = Image.open(item).convert('RGB')
+         return self.transform(HD_transform(item, hd_num=self.hd_num))
+
+
+ def get_internlm_processor():
+     return ImageProcessor(image_size=490)
+
+
+ processor_dict = {
+     'Internlm': get_internlm_processor,
+ }
+
+ def get_image_processor(model_name):
+     return processor_dict[model_name]()
utils/ixc_utils.py ADDED
@@ -0,0 +1,43 @@
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ def padding_336(b):
+     width, height = b.size
+     tar = int(np.ceil(height / 336) * 336)
+     top_padding = int((tar - height) / 2)
+     bottom_padding = tar - height - top_padding
+     left_padding = 0
+     right_padding = 0
+     b = transforms.functional.pad(
+         b, [left_padding, top_padding, right_padding, bottom_padding],
+         fill=[255, 255, 255])
+
+     return b
+
+
+ def HD_transform(img, hd_num=16):
+     width, height = img.size
+     trans = False
+     if width < height:
+         img = img.transpose(Image.TRANSPOSE)
+         trans = True
+         width, height = img.size
+     ratio = (width / height)
+     scale = 1
+     while scale * np.ceil(scale / ratio) <= hd_num:
+         scale += 1
+     scale -= 1
+     new_w = int(scale * 336)
+     new_h = int(new_w / ratio)
+
+     img = transforms.functional.resize(
+         img,
+         [new_h, new_w],
+     )
+     img = padding_336(img)
+     width, height = img.size
+     if trans:
+         img = img.transpose(Image.TRANSPOSE)
+
+     return img
utils/model_utils.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import torch
+ import transformers
+ from transformers import deepspeed
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ from .arguments import TrainingArguments, DataArguments, LoraArguments
+
+ from transformers.modeling_utils import _load_state_dict_into_model
+ from model import get_model
+
+ def maybe_zero_3(param):
+     if hasattr(param, 'ds_id'):
+         assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
+         with zero.GatheredParameters([param]):
+             param = param.data.detach().cpu().clone()
+     else:
+         param = param.detach().cpu().clone()
+     return param
+
+
+ # Borrowed from peft.utils.get_peft_model_state_dict
+ def get_peft_state_maybe_zero_3(named_params, bias):
+     if bias == 'none':
+         to_return = {k: t for k, t in named_params if 'lora_' in k}
+     elif bias == 'all':
+         to_return = {
+             k: t
+             for k, t in named_params if 'lora_' in k or 'bias' in k
+         }
+     elif bias == 'lora_only':
+         to_return = {}
+         maybe_lora_bias = {}
+         lora_bias_names = set()
+         for k, t in named_params:
+             if 'lora_' in k:
+                 to_return[k] = t
+                 bias_name = k.split('lora_')[0] + 'bias'
+                 lora_bias_names.add(bias_name)
+             elif 'bias' in k:
+                 maybe_lora_bias[k] = t
+         for k, t in maybe_lora_bias.items():
+             if k in lora_bias_names:
+                 to_return[k] = t
+     else:
+         raise NotImplementedError
+     to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
+     return to_return
+
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+     to_return = {k: t for k, t in named_params if "lora_" not in k}
+     if require_grad_only:
+         to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+     to_return = {k: maybe_zero_3(v).cpu() for k, v in to_return.items()}
+     return to_return
+
+
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+                                    output_dir: str,
+                                    bias='none'):
+     """Collects the state dict and dump to disk."""
+     # check if zero3 mode enabled
+     if deepspeed.is_deepspeed_zero3_enabled():
+         state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
+     else:
+         if trainer.args.use_lora:
+             state_dict = get_peft_state_maybe_zero_3(
+                 trainer.model.named_parameters(), bias)
+         else:
+             state_dict = trainer.model.state_dict()
+     if trainer.args.should_save and trainer.args.local_rank == 0:
+         trainer._save(output_dir, state_dict=state_dict)
+
+     non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(trainer.model.named_parameters())
+     torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))
+
+ def init_model(model_path, training_args: TrainingArguments, data_args: DataArguments, lora_args: LoraArguments, model_cfg: dict):
+
+     model = get_model(
+         model_name = model_cfg['model_name'],
+         model_path = model_path,
+         training_args = training_args,
+         data_args = data_args,
+         lora_args = lora_args,
+         use_caption = model_cfg.get('use_caption', None),
+     )
+     if model_cfg['model_name'] == 'Idefics2':
+         model, tokenizer = model.get_model_processor()
+     else:
+         model, tokenizer = model.get_model_tokenizer()
+
+     if training_args.use_lora and lora_args.lora_weight_path != '':
+         if lora_args.lora_type == 'lora':
+             try:
+                 delta_path = os.path.join(lora_args.lora_weight_path, 'adapter_model.bin')
+                 delta_ckpt = torch.load(delta_path, 'cpu')
+             except Exception:
+                 from safetensors.torch import load_file
+                 delta_path = os.path.join(lora_args.lora_weight_path, 'adapter_model.safetensors')
+                 delta_ckpt = load_file(delta_path, 'cpu')
+             new_dict = {}
+             for key, value in delta_ckpt.items():
+                 new_dict[f'{key[:-7]}.default.weight'] = value
+             _load_state_dict_into_model(model, new_dict, start_prefix='')
+             print(f'load delta ckpt from {os.path.abspath(delta_path)}')
+
+             non_lora_ckpt_path = os.path.join(lora_args.lora_weight_path, 'non_lora_trainables.bin')
+             if os.path.exists(non_lora_ckpt_path):
+                 non_lora_trainables = torch.load(non_lora_ckpt_path, map_location='cpu')
+                 _load_state_dict_into_model(model, non_lora_trainables, start_prefix='')
+                 print(f'load non lora ckpt from {os.path.abspath(non_lora_ckpt_path)}')
+         else:
+             raise NotImplementedError
+
+     return model, tokenizer