Update README.md
README.md
CHANGED
---
license: apache-2.0
tags:
- vision
- image-classification
- image-to-text
- image-captioning
base_model:
- internlm/internlm-xcomposer2-vl-7b
pipeline_tag: image-to-text
---

<p align="center">
  <img src="logo_en.png" width="400"/>
</p>

<p align="center">
  <b><font size="6">ImageGuard</font></b>
</p>

<div align="center">

[💻Github Repo](https://github.com/adwardlee/t2i_safety)

[Paper](https://arxiv.org/abs/)

</div>

**ImageGuard** is a vision-language model (VLM) based on [InternLM-XComposer2](https://github.com/InternLM/InternLM-XComposer) for advanced image safety evaluation.

### Import from Transformers
ImageGuard works with transformers>=4.42.
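
A quick way to check this before running the quickstart (a minimal sketch, not part of the original card; it only verifies the installed version):

```python
# Verify that the installed transformers version satisfies the >= 4.42 requirement.
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.42.0"), \
    f"transformers {transformers.__version__} found, but ImageGuard needs >= 4.42"
```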

## Quickstart
We provide a simple example to show how to use ImageGuard with 🤗 Transformers.
```python
import os
import json
import torch
import time
import numpy as np
import argparse
import yaml

from PIL import Image
from utils.img_utils import ImageProcessor
from utils.arguments import ModelArguments, DataArguments, EvalArguments, LoraArguments
from utils.model_utils import init_model
from utils.conv_utils import fair_query, safe_query

def load_yaml(cfg_path):
    with open(cfg_path, 'r', encoding='utf-8') as f:
        result = yaml.load(f.read(), Loader=yaml.FullLoader)
    return result

def textprocess(safe=True):
    if safe:
        conversation = safe_query('Internlm')
    else:
        conversation = fair_query('Internlm')
    return conversation

def model_init(
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: EvalArguments,
        lora_args: LoraArguments,
        model_cfg):
    model, tokenizer = init_model(model_args.model_name_or_path, training_args, data_args, lora_args, model_cfg)
    model.cuda().eval().half()
    model.tokenizer = tokenizer
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_dir', required=False, type=str, default='lora/')
    parser.add_argument('--base_model', type=str, default='internlm/internlm-xcomposer2-vl-7b')
    args = parser.parse_args()
    load_dir = args.load_dir
    config = load_yaml(os.path.join(load_dir, 'config.yaml'))
    model_cfg = config['model_cfg']
    data_cfg = config['data_cfg']['data_cfg']
    model_cfg['model_name'] = 'Internlm'
    data_cfg['train']['model_name'] = 'Internlm'
    lora_cfg = config['lora_cfg']
    training_cfg = config['training_cfg']

    model_args = ModelArguments()
    model_args.model_name_or_path = args.base_model
    lora_args = LoraArguments()
    lora_args.lora_alpha = lora_cfg['lora_alpha']
    lora_args.lora_bias = lora_cfg['lora_bias']
    lora_args.lora_dropout = lora_cfg['lora_dropout']
    lora_args.lora_r = lora_cfg['lora_r']
    lora_args.lora_target_modules = lora_cfg['lora_target_modules']
    lora_args.lora_weight_path = load_dir  # comment out this line to test the base model without LoRA weights
    train_args = EvalArguments()
    train_args.max_length = training_cfg['max_length']
    train_args.fix_vit = training_cfg['fix_vit']
    train_args.fix_sampler = training_cfg['fix_sampler']
    train_args.use_lora = training_cfg['use_lora']
    train_args.gradient_checkpointing = training_cfg['gradient_checkpointing']
    data_args = DataArguments()

    model = model_init(model_args, data_args, train_args, lora_args, model_cfg)
    print(' model device: ', model.device, flush=True)

    img = Image.open('punch.png')
    safe = True  # True for toxicity and privacy, False for fairness
    prompt = textprocess(safe=safe)
    vis_processor = ImageProcessor(image_size=490)
    image = vis_processor(img)[None, :, :, :]
    with torch.cuda.amp.autocast():
        response, _ = model.chat(model.tokenizer, prompt, image, history=[], do_sample=False, meta_instruction=None)
    print(response)
    # expected output: unsafe\n violence
```
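
The model answers with a safety verdict followed by the violated category, as the `# unsafe\n violence` comment above shows. If you want to post-process that reply, a helper along these lines should work; this is a minimal sketch assuming the two-line `verdict\ncategory` format from the example, and `parse_imageguard_response` is our own name, not a function from the repo:

```python
def parse_imageguard_response(response: str):
    """Split an ImageGuard reply into (verdict, category).

    Assumes the two-line format shown in the quickstart; category is None
    when the reply contains only a verdict (e.g. for safe images).
    """
    lines = [line.strip() for line in response.strip().splitlines() if line.strip()]
    verdict = lines[0] if lines else ''
    category = lines[1] if len(lines) > 1 else None
    return verdict, category


print(parse_imageguard_response('unsafe\n violence'))  # ('unsafe', 'violence')
```

Set `safe = False` in the script above to switch the prompt from the toxicity/privacy query (`safe_query`) to the fairness query (`fair_query`).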

### Open Source License
The code is licensed under Apache-2.0, while model weights are fully open for academic research and also allow free commercial usage.