Spaces: MUTED64 (Sleeping)
Commit cefcefa · 1 Parent(s): 0668dff
committed by MUTED64: change scorer
Files changed:
- __init__.py (+0 -0)
- api.py (+5 -0)
- app.py (+78 -26)
- requirements.txt (+2 -1)
- setup.py (+28 -0)
- waifu_scorer/__init__.py (+1 -0)
- waifu_scorer/__pycache__/__init__.cpython-310.pyc (+0 -0)
- waifu_scorer/__pycache__/__init__.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/mlp.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/predict.cpython-310.pyc (+0 -0)
- waifu_scorer/__pycache__/predict.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/train.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/train_utils.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/ui.cpython-312.pyc (+0 -0)
- waifu_scorer/__pycache__/utils.cpython-312.pyc (+0 -0)
- waifu_scorer/mlp.py (+127 -0)
- waifu_scorer/predict.py (+63 -0)
- waifu_scorer/train.py (+307 -0)
- waifu_scorer/train_utils.py (+333 -0)
- waifu_scorer/ui.py (+91 -0)
- waifu_scorer/utils.py (+72 -0)
__init__.py
ADDED
File without changes
api.py
ADDED
@@ -0,0 +1,5 @@
from waifu_scorer.ui import launch, parse_args

if __name__ == '__main__':
    args = parse_args()
    launch(args)
app.py
CHANGED
@@ -1,36 +1,88 @@
 import gradio as gr
 import torch
 from PIL import Image
-from torchvision.transforms import functional as F
 from typing import List
-from
+from waifu_scorer.mlp import MLP
+import clip

 # Load the pre-trained model
-model_path = "1024_MLP_best-MSE4.1636_ep75.pth"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+model_path = "./1024_MLP_best-MSE4.1636_ep75.pth"
+device = "cpu"
+dtype = torch.float32
+s = torch.load(model_path, map_location=device)
+model = MLP(input_size=768)
+model.load_state_dict(s)
+model.to(device=device, dtype=dtype)
+
+model2, preprocess = clip.load("ViT-L/14", device=device)
+
+def normalized(a: torch.Tensor, order=2, dim=-1):
+    l2 = a.norm(order, dim, keepdim=True)
+    l2[l2 == 0] = 1
+    return a / l2
+
+@torch.no_grad()
+def encode_images(images: List[Image.Image], model2, preprocess, device='cpu') -> torch.Tensor:
+    if not isinstance(images, list):
+        images = [images]
+    image_tensors = [preprocess(img).unsqueeze(0) for img in images]
+    image_batch = torch.cat(image_tensors).to(device)
+    image_features = model2.encode_image(image_batch)
+    im_emb_arr = normalized(image_features).cpu().float()
+    return im_emb_arr
+
+@torch.no_grad()
+def predict(inputs: List[Image.Image]) -> float:
+    images = encode_images(inputs, model2, preprocess, device=device).to(device=device, dtype=dtype)
+    predictions = model(images)
+    scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
     return scores

-
-
-
-
-
-
-    description="Predict the score of a kemono based on aesthetic features.",
+
+from waifu_scorer.predict import WaifuScorer, load_model
+scorer = WaifuScorer(
+    model_path=model_path,
+    model_type="mlp",
+    device=device,
 )

-
-
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(
+                label='Image',
+                type='pil',
+                height=512,
+                sources=['upload', 'clipboard'],
+            )
+        with gr.Column():
+            with gr.Row():
+                model_path = gr.Textbox(
+                    label='Model Path',
+                    value=model_path,
+                    placeholder='Path or URL to the model file',
+                    # interactive=not fix_model_path,
+                )
+            with gr.Row():
+                score = gr.Number(
+                    label='Score',
+                )
+
+    def change_model(model_path):
+        scorer.mlp = load_model(model_path, model_type="mlp", device=device)
+        print(f"Model changed to `{model_path}`")
+        return gr.update()
+
+    model_path.submit(
+        fn=change_model,
+        inputs=model_path,
+        outputs=model_path,
+    )
+
+    image.change(
+        fn=lambda image: predict([image]*2)[0] if image is not None else None,
+        inputs=image,
+        outputs=score,
+    )
+
+demo.launch()
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ torch
 Pillow
 torchvision
 typing
-
+pytorch_lightning
+clip
setup.py
ADDED
@@ -0,0 +1,28 @@
from setuptools import setup, find_packages
with open('./requirements.txt') as f:
    requirements = f.read().splitlines()

for i, req in enumerate(requirements):
    if req.startswith('git+'):
        package_name = req.split('/')[-1].split('.')[0]  # Extract package name from URL
        requirements[i] = f"{package_name} @ {req}"

setup(
    name='waifu-scorer',
    version='0.1',
    packages=find_packages(),
    include_package_data=True,
    description='Image caption tools',
    long_description='',
    author='euge',
    author_email='[email protected]',
    url='https://github.com/Eugeoter/waifu-scorer',
    install_requires=requirements,
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
    ],
)
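A note on the git+ handling above: the loop rewrites any VCS requirement found in requirements.txt into PEP 508 "name @ url" form so that `install_requires` accepts it. A minimal trace of that rewrite, using a hypothetical URL that is not part of this commit:

# Illustration only (not part of the commit): how setup.py's rewrite behaves.
req = "git+https://github.com/openai/CLIP.git"    # hypothetical VCS requirement
package_name = req.split('/')[-1].split('.')[0]   # -> "CLIP"
print(f"{package_name} @ {req}")                  # -> "CLIP @ git+https://github.com/openai/CLIP.git"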
waifu_scorer/__init__.py
ADDED
@@ -0,0 +1 @@
from .predict import WaifuScorer
waifu_scorer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (233 Bytes).

waifu_scorer/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (240 Bytes).

waifu_scorer/__pycache__/mlp.cpython-312.pyc
ADDED
Binary file (5.52 kB).

waifu_scorer/__pycache__/predict.cpython-310.pyc
ADDED
Binary file (3.04 kB).

waifu_scorer/__pycache__/predict.cpython-312.pyc
ADDED
Binary file (4.98 kB).

waifu_scorer/__pycache__/train.cpython-312.pyc
ADDED
Binary file (14.2 kB).

waifu_scorer/__pycache__/train_utils.cpython-312.pyc
ADDED
Binary file (14.3 kB).

waifu_scorer/__pycache__/ui.cpython-312.pyc
ADDED
Binary file (3.89 kB).

waifu_scorer/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (3.75 kB).
waifu_scorer/mlp.py
ADDED
@@ -0,0 +1,127 @@
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class MLP(pl.LightningModule):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        # self.layers = nn.Sequential(
        #     nn.Linear(self.input_size, 2048),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(2048),
        #     nn.Dropout(0.4),

        #     nn.Linear(2048, 512),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(512),
        #     nn.Dropout(0.3),

        #     nn.Linear(512, 256),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(256),
        #     nn.Dropout(0.2),

        #     nn.Linear(256, 128),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(128),
        #     nn.Dropout(0.1),

        #     nn.Linear(128, 32),
        #     nn.ReLU(),
        #     nn.Linear(32, 1)
        # )
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(64, 16),
            # nn.ReLU(),

            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    #     return optimizer


class ResidualBlock(nn.Module):
    def __init__(self, input_size, output_size, batch_norm=True, dropout_rate=0.0):
        super(ResidualBlock, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.relu = nn.ReLU()
        self.batch_norm = nn.BatchNorm1d(output_size) if batch_norm else nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)
        self.adjust_dims = nn.Linear(input_size, output_size) if input_size != output_size else nn.Identity()

    def forward(self, x):
        identity = self.adjust_dims(x)
        out = self.linear(x)
        out = self.relu(out)
        out = self.batch_norm(out)
        out = self.dropout(out)
        out += identity
        out = self.relu(out)
        return out


class ResMLP(pl.LightningModule):
    def __init__(self, input_size, xcol='emb', ycol='avg_rating', batch_norm=True):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            ResidualBlock(input_size, 2048, batch_norm, dropout_rate=0.3),
            ResidualBlock(2048, 512, batch_norm, dropout_rate=0.3),
            ResidualBlock(512, 256, batch_norm, dropout_rate=0.2),
            ResidualBlock(256, 128, batch_norm, dropout_rate=0.1),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss
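Note that every `nn.ReLU()` inside the new `MLP.layers` stack is commented out, so at inference time (dropout inactive) the module reduces to a composition of linear layers mapping a 768-dim CLIP embedding to one score. A minimal sketch of exercising the module on its own; the batch size and random input below are illustrative only, not part of the commit:

import torch
from waifu_scorer.mlp import MLP

# Scorer head for ViT-L/14 CLIP image embeddings (768-dim).
model = MLP(input_size=768)
model.eval()

# Stand-in for a batch of 4 L2-normalized CLIP image embeddings.
dummy_embeddings = torch.randn(4, 768)

with torch.no_grad():
    scores = model(dummy_embeddings)  # shape: (4, 1)
print(scores.squeeze(-1))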
waifu_scorer/predict.py
ADDED
@@ -0,0 +1,63 @@
import torch
import clip
import os
from PIL import Image
from typing import List
from .utils import get_model_cls

WAIFU_FILTER_V1_MODEL_REPO = 'Eugeoter/waifu-filter-v1/waifu-filter-v1.pth'


def download_from_url(url):
    from huggingface_hub import hf_hub_download
    split = url.split("/")
    username, repo_id, model_name = split[-3], split[-2], split[-1]
    model_path = hf_hub_download(f"{username}/{repo_id}", model_name)
    return model_path


def load_model(model_path: str = None, model_type='mlp', input_size=768, device: str = 'cuda', dtype=torch.float32):
    model_cls = get_model_cls(model_type)
    model = model_cls(input_size=input_size)
    if not os.path.isfile(model_path):
        model_path = download_from_url(model_path)
    s = torch.load(model_path, map_location=device)
    model.load_state_dict(s)
    model.to(device=device, dtype=dtype)
    return model


def normalized(a: torch.Tensor, order=2, dim=-1):
    l2 = a.norm(order, dim, keepdim=True)
    l2[l2 == 0] = 1
    return a / l2


@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
    if isinstance(images, Image.Image):
        images = [images]
    image_tensors = [preprocess(img).unsqueeze(0) for img in images]
    image_batch = torch.cat(image_tensors).to(device)
    image_features = model2.encode_image(image_batch)
    im_emb_arr = normalized(image_features).cpu().float()
    return im_emb_arr


class WaifuScorer:
    def __init__(self, model_path: str = WAIFU_FILTER_V1_MODEL_REPO, model_type='mlp', device: str = None, dtype=torch.float32):
        print(f"loading model from `{model_path}`...")
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.mlp = load_model(model_path, model_type=model_type, input_size=768, device=device, dtype=dtype)
        self.mlp.eval()
        self.model2, self.preprocess = clip.load("ViT-L/14", device=device)
        self.device = self.mlp.device
        self.dtype = self.mlp.dtype
        print(f"model loaded: cls={model_type} | device={self.device} | dtype={self.dtype}")

    @torch.no_grad()
    def predict(self, images: List[Image.Image]) -> float:
        images = encode_images(images, self.model2, self.preprocess, device=self.device).to(device=self.device, dtype=self.dtype)
        predictions = self.mlp(images)
        scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
        return scores
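`load_model` plus the `WaifuScorer` class above form the scoring API that app.py and ui.py build on. A hedged usage sketch; the image filename is a placeholder, and the checkpoint name is the one this Space ships with:

from PIL import Image
from waifu_scorer.predict import WaifuScorer

# A local .pth path is loaded directly; a 'user/repo/file.pth' spec would instead
# be resolved via download_from_url / hf_hub_download.
scorer = WaifuScorer(
    model_path="./1024_MLP_best-MSE4.1636_ep75.pth",
    model_type="mlp",
    device="cpu",
)

image = Image.open("example.jpg")        # placeholder path
scores = scorer.predict([image, image])  # the Gradio app likewise passes [image]*2
print(scores[0])                         # aesthetic score clamped to [0, 10]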
waifu_scorer/train.py
ADDED
@@ -0,0 +1,307 @@
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # in case you are using a multi GPU workstation, choose your GPU here

import os
import torch
import random
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from typing import Literal, Callable, Optional, Union
from waifuset.utils import log_utils
from waifuset.classes import Dataset, ImageInfo
from . import mlp, utils, train_utils

StrPath = Union[str, Path]


def train(
    dataset_source,
    save_path,
    resume_path: StrPath = None,
    data_preprocessor: Optional[Callable[[ImageInfo], float]] = None,
    rating_func_type: Union[Callable[[ImageInfo], float], Literal['direct', 'label', 'quality']] = 'quality',
    num_train_epochs=50,
    learning_rate=1e-3,
    train_batch_size=256,
    shuffle=True,
    flip_aug=True,
    val_batch_size=512,
    val_every_n_epochs=1,
    val_percentage=0.05,  # 5% of the training data will be used for validation
    save_best_model=True,
    clip_batch_size=1,
    cache_to_disk: bool = False,
    cache_path: StrPath = None,
    mixed_precision=None,
    max_data_loader_n_workers: int = 4,
    persistent_workers=False,
    mlp_model_type: Literal['default', 'large'] = 'default',
    clip_model_name: str = "ViT-L/14",
    input_size: int = 768,
    batch_norm: bool = True,
):
    r"""
    :param dataset_source: any dataset source, e.g. path to the dataset.
    :param save_path: path to save the trained model.
    :param resume_path: path to the model to resume from.
    :param cache_to_disk: whether to cache the training data to disk.
    :param cache_path: path to the cached training data. If not exists, will be created from `dataset_source`. If exists, will be loaded from disk.
    :param num_train_epochs: number of training epochs.
    :param learning_rate: learning rate.
    :param train_batch_size: training batch size.
    :param val_batch_size: validation batch size.
    :param val_every_n_epochs: validation frequency.
    :param val_percentage: percentage of the training data to be used for validation.
    :param encoder_batch_size: batch size for encoding images.
    :param mixed_precision: whether to use mixed precision training.
    :param max_data_loader_n_workers: maximum number of workers for data loaders.
    :param persistent_workers: whether to use persistent workers for data loaders.
    :param input_size: input size of the model.
    """
    log_utils.info(f"prepare for training")
    accelerator = Accelerator(mixed_precision=mixed_precision)
    weight_dtype = train_utils.prepare_dtype(mixed_precision)
    device = accelerator.device
    max_data_loader_n_workers = min(max_data_loader_n_workers, os.cpu_count()-1)
    if callable(rating_func_type):
        rating_func = rating_func_type
    else:
        rating_func = train_utils.get_rating_func(rating_func_type)

    model2, preprocess = utils.load_clip_models(name=clip_model_name, device=device)  # RN50x64

    dataset = Dataset(dataset_source, verbose=True, condition=lambda img_info: img_info.image_path.is_file())
    if data_preprocessor:
        for img_key, img_info in dataset.items():
            img_info = data_preprocessor(img_info)
    keys = list(dataset.keys())
    random.shuffle(keys)
    dataset = Dataset({k: dataset[k] for k in keys})

    num_pos = 0
    num_neg = 0
    num_mid = 0
    for img_key, img_info in dataset.items():
        rating = rating_func(img_info)
        if rating == 10:
            num_pos += 1
        elif rating == 0:
            num_neg += 1
        else:
            num_mid += 1
    log_utils.info(f"num_pos: {num_pos} | num_mid: {num_mid} | num_neg: {num_neg}")

    train_size = int(len(dataset) * (1 - val_percentage))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = Dataset(dataset.values()[:train_size]), Dataset(dataset.values()[train_size:])

    log_utils.info(f"train_size: {train_size} | val_size: {val_size}")

    train_dataset, train_loader = train_utils.prepare_dataloader(
        train_dataset,
        batch_size=train_batch_size,
        clip_batch_size=clip_batch_size,
        model2=model2,
        preprocess=preprocess,
        input_size=input_size,
        rating_func=rating_func,
        shuffle=shuffle,
        flip_aug=flip_aug,
        cache_to_disk=cache_to_disk,
        cache_path=cache_path,
        max_data_loader_n_workers=max_data_loader_n_workers,
        persistent_workers=persistent_workers,
        device=device,
    )

    val_dataset, val_loader = train_utils.prepare_dataloader(
        val_dataset,
        batch_size=val_batch_size,
        clip_batch_size=clip_batch_size,
        model2=model2,
        preprocess=preprocess,
        rating_func=rating_func,
        shuffle=shuffle,
        flip_aug=flip_aug,
        cache_to_disk=cache_to_disk,
        cache_path=cache_path,
        max_data_loader_n_workers=max_data_loader_n_workers,
        persistent_workers=persistent_workers,
        device=device,
    )

    rating_stat = {}
    for i in range(len(train_dataset)):
        # to list
        ratings: torch.Tensor = train_dataset[i]['ratings']
        ratings = ratings.squeeze().tolist()
        for rating in ratings:
            if rating not in rating_stat:
                rating_stat[rating] = 0
            rating_stat[rating] += 1

    log_utils.info("rating_stat:\n", '\n'.join(f'{k}: {v}' for k, v in rating_stat.items()))

    # prepare model

    model: mlp.MLP = utils.load_model(resume_path, model_type=mlp_model_type, input_size=input_size, batch_norm=batch_norm, device=device, dtype=weight_dtype)

    # import prodigyopt
    # print(f"use Prodigy optimizer | {optimizer_kwargs}")
    # optimizer_class = prodigyopt.Prodigy
    # optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # choose the loss you want to optimize for
    criterion = nn.MSELoss(reduction='mean')
    criterion2 = nn.L1Loss(reduction='mean')

    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )

    log_utils.info(f"device: {accelerator.device}")

    # training loop
    best_loss = 999  # best validation loss
    total_train_steps = len(train_loader) * num_train_epochs
    progress_bar = tqdm(range(total_train_steps), position=0, leave=True)
    print(f"total_train_steps: {total_train_steps}")

    class LossRecorder:
        def __init__(self):
            self.loss_list = []
            self.loss_total: float = 0.0

        def add(self, *, epoch: int, step: int, loss: float) -> None:
            if epoch == 0:
                self.loss_list.append(loss)
            else:
                self.loss_total -= self.loss_list[step]
                self.loss_list[step] = loss
            self.loss_total += loss

        @property
        def moving_average(self) -> float:
            return self.loss_total / len(self.loss_list)

    loss_recorder = LossRecorder()
    model.requires_grad_(True)
    save_on_end = False

    try:
        for epoch in range(num_train_epochs):
            model.train()
            losses = []
            losses2 = []
            for step, input_data in enumerate(train_loader):
                optimizer.zero_grad(set_to_none=True)
                im_emb_arr: torch.Tensor = input_data['im_emb_arrs'].to(accelerator.device).to(dtype=weight_dtype)  # shape: (batch_size, input_size)
                rating: torch.Tensor = input_data['ratings'].to(accelerator.device).to(dtype=weight_dtype)  # shape: (batch_size, 1)

                # randomize the rating
                # rating_std = 0.5
                # rating = rating + torch.randn_like(rating) * rating_std

                # log_utils.debug(f"x.dtype: {x.dtype} | y.dtype: {y.dtype} | model.dtype: {model.dtype}")

                with accelerator.autocast():
                    output = model(im_emb_arr)

                loss = criterion(output, rating)

                accelerator.backward(loss)

                losses.append(loss.detach().item())

                optimizer.step()

                # if step % 1000 == 0:
                #     print('\tEpoch %d | Batch %d | Loss %6.2f' % (epoch, step, loss.item()))
                #     # print(y)

                progress_bar.update(1)

                current_loss = loss.detach().item()
                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                avr_loss: float = loss_recorder.moving_average
                pbar_logs = {
                    'lr': f"{lr_scheduler.get_last_lr()[0]:.3e}",
                    'epoch': epoch,
                    'loss': avr_loss,
                }
                progress_bar.set_postfix(pbar_logs)

            progress_bar.write('epoch %d | avg loss %6.6f' % (epoch, avr_loss))

            # validation
            if accelerator.is_main_process and epoch > 0 and epoch % val_every_n_epochs == 0:
                model.eval()
                with torch.no_grad():
                    losses = []
                    losses2 = []
                    for step, input_data in enumerate(val_loader):
                        # optimizer.zero_grad(set_to_none=True)
                        im_emb_arr = input_data['im_emb_arrs'].to(accelerator.device).to(dtype=weight_dtype)
                        rating = input_data['ratings'].to(accelerator.device).to(dtype=weight_dtype)

                        with accelerator.autocast():
                            output = model(im_emb_arr)
                        loss = criterion(output, rating)
                        lossMAE = criterion2(output, rating)
                        # loss.backward()
                        losses.append(loss.detach().item())
                        losses2.append(lossMAE.detach().item())
                        # optimizer.step()

                        # if step % 1000 == 0:
                        #     print('\tValidation - Epoch %d | Batch %d | MSE Loss %6.2f' % (epoch, step, loss.item()))
                        #     print('\tValidation - Epoch %d | Batch %d | MAE Loss %6.2f' % (epoch, step, lossMAE.item()))

                        # print(y)
                    current_loss = sum(losses)/len(losses)
                    s = [f"validation - epoch {log_utils.stylize(epoch, log_utils.ANSI.YELLOW)}"]
                    s.append(f"avg MSE loss {log_utils.stylize(current_loss, log_utils.ANSI.GREEN, format_spec='.4f')}")
                    s.append(f"avg MAE loss {log_utils.stylize(sum(losses2)/len(losses2), log_utils.ANSI.YELLOW, format_spec='.4f')}")
                    progress_bar.write(' | '.join(s))
                    # progress_bar.write('validation - epoch %d | avg MSE loss %6.4f' % (epoch, sum(losses)/len(losses)))
                    # progress_bar.write('validation - epoch %d | avg MAE loss %6.4f' % (epoch, sum(losses2)/len(losses2)))

                    if save_best_model and current_loss < best_loss:
                        best_loss = current_loss
                        progress_bar.write(f"best MSE val loss ({log_utils.stylize(best_loss, log_utils.ANSI.BOLD, log_utils.ANSI.GREEN)}) so far. saving model...")
                        best_save_path = Path(save_path).parent / f"{Path(save_path).stem}_best-MSE{best_loss:.4f}{Path(save_path).suffix}"
                        train_utils.save_model(model, best_save_path, epoch=epoch)
                        progress_bar.write(f"model saved: `{save_path}`")

            lr_scheduler.step()
            accelerator.wait_for_everyone()
    except KeyboardInterrupt:
        log_utils.warn("KeyboardInterrupt")
        if input(f"save model to {save_path}? [y/n]") == 'y':
            save_on_end = True
    else:
        save_on_end = True

    progress_bar.close()
    model = accelerator.unwrap_model(model)
    accelerator.wait_for_everyone()

    if accelerator.is_main_process and save_on_end:
        log_utils.info("saving model...")
        train_utils.save_model(model, save_path)
        log_utils.info(f"model saved: `{save_path}`")

    del accelerator

    log_utils.success(f"training done. best loss: {best_loss}")

    # inferece test with dummy samples from the val set, sanity check
    # log_utils.info("inference test with dummy samples from the val set, sanity check")
    # model.eval()
    # output = model(x[:5].to(device))
    # log_utils.info(output.size())
    # log_utils.info(output)
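This commit adds the `train(...)` entry point but no script that calls it. A hedged sketch of how it might be invoked, assuming a waifuset-compatible dataset source; every path below is made up for illustration:

from waifu_scorer.train import train

train(
    dataset_source="./data/my_dataset",   # anything waifuset.Dataset accepts (assumed)
    save_path="./models/waifu_scorer.pth",
    rating_func_type="quality",           # ratings derived from waifuset quality tags
    num_train_epochs=50,
    train_batch_size=256,
    cache_to_disk=True,
    cache_path="./cache/clip_embs.h5",    # CLIP embeddings cached as HDF5
    clip_model_name="ViT-L/14",
    input_size=768,
)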
waifu_scorer/train_utils.py
ADDED
@@ -0,0 +1,333 @@
import os
import torch
import h5py
import math
import random
from torch.utils.data import DataLoader
from pathlib import Path
from typing import List, Callable, Tuple
from tqdm import tqdm
from PIL import Image
from waifuset.classes import Dataset, ImageInfo
from waifuset.utils import log_utils
from .utils import encode_images, load_clip_models, quality_rating


class LaionImageInfo:
    def __init__(
        self,
        img_path=None,
        im_emb_arr=None,
        rating=None,
        im_emb_arr_flipped=None,
        num_repeats=1,
    ):
        self.img_path = img_path
        self.im_emb_arr = im_emb_arr
        self.rating = rating
        self.im_emb_arr_flipped = im_emb_arr_flipped
        self.num_repeats = num_repeats


class LaionDataset:
    def __init__(
        self,
        source,
        cache_to_disk=True,
        cache_path=None,
        batch_size=1,
        clip_batch_size=4,
        model2=None,
        preprocess=None,
        input_size=768,
        rating_func: Callable = quality_rating,
        repeating_func: Callable = None,
        shuffle=True,
        flip_aug: bool = True,
        device='cuda'
    ):
        if model2 is None or preprocess is None:
            model2, preprocess = load_clip_models(device)  # RN50x64
        if cache_to_disk and cache_path is None:
            raise ValueError("cache_path must be specified when cache_to_disk is True.")
        self.source = source
        self.cache_to_disk = cache_to_disk
        self.cache_path = Path(cache_path)
        self.model2, self.preprocess = model2, preprocess
        self.input_size = input_size
        self.rating_func = rating_func
        self.batch_size = batch_size
        self.encoder_batch_size = clip_batch_size
        self.shuffle = shuffle
        self.flip_aug = flip_aug
        self.device = device

        dataset: Dataset = Dataset(source, verbose=True)

        self.image_data = []

        for img_key, img_info in tqdm(dataset.items(), desc='prepare dataset'):
            img_path = img_info.image_path
            rating = self.rating_func(img_info)
            laion_image_info = LaionImageInfo(
                img_path=img_path,
                rating=rating,
            )
            self.register_image_info(laion_image_info)

        rating_counter = {}
        for laion_img_info in tqdm(self.image_data, desc='calculating num repeats (1/2)'):
            # to list
            rating: torch.Tensor = laion_img_info.rating
            rating_counter.setdefault(rating, 0)
            rating_counter[rating] += 1

        for laion_img_info in tqdm(self.image_data, desc='calculating num repeats (2/2)'):
            benchmark = 30000
            num_repeats = benchmark / rating_counter[laion_img_info.rating]
            prob = num_repeats - math.floor(num_repeats)
            num_repeats = math.floor(num_repeats) if random.random() < prob else math.ceil(num_repeats)
            laion_img_info.num_repeats = max(1, num_repeats)

        self.cache_embs()
        self.batches = self.make_batches()

    def register_image_info(self, image_info: LaionImageInfo):
        self.image_data.append(image_info)

    def cache_embs(self):
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)

        not_cached = []  # list of (image_info, flipped)
        num_cached = 0

        # load cache
        if self.cache_to_disk:
            pbar = tqdm(total=len(self.image_data), desc='loading cache')

            def load_cached_emb(h5, image_info: LaionImageInfo, flipped=False):
                nonlocal num_cached
                image_key = image_info.img_path.stem
                if flipped:
                    image_key = image_key + '_flipped'
                if image_key in h5:
                    im_emb_arr = torch.from_numpy(f[image_key][:])
                    if im_emb_arr.shape[-1] != self.input_size:
                        raise ValueError(f"Input size mismatched. Except {self.input_size} dim, but got {im_emb_arr.shape[-1]} dim loaded. Please check your cache file.")
                    assert im_emb_arr.device == torch.device('cpu'), "flipped image emb should be on cpu"
                    if flipped:
                        image_info.im_emb_arr_flipped = im_emb_arr
                    else:
                        image_info.im_emb_arr = im_emb_arr
                    num_cached += 1
                else:
                    not_cached.append((image_info, flipped))

            if not is_h5_file(self.cache_path):
                # create cache
                log_utils.info(f"cache file not found, creating new cache file: {self.cache_path}")
                with h5py.File(self.cache_path, 'w') as f:
                    pass
            else:
                log_utils.info(f"loading cache file: {self.cache_path}")
                with h5py.File(self.cache_path, 'r') as f:
                    for image_info in self.image_data:
                        load_cached_emb(f, image_info, flipped=False)
                        if self.flip_aug:
                            load_cached_emb(f, image_info, flipped=True)
                        pbar.update()
            pbar.close()
        else:
            not_cached = [(image_info, False) for image_info in self.image_data]
            if self.flip_aug:
                not_cached += [(image_info, True) for image_info in self.image_data]

        # encode not-cached images
        if len(not_cached) == 0:
            log_utils.info("all images are cached.")
        else:
            log_utils.info(f"number of cached instances: {num_cached}")
            log_utils.info(f"number of not cached instances: {len(not_cached)}")

            batches = [not_cached[i:i + self.encoder_batch_size] for i in range(0, len(not_cached), self.encoder_batch_size)]
            pbar = tqdm(total=len(batches), desc='encoding images')

            def cache_batch_embs(h5, batch: List[Tuple[LaionImageInfo, bool]]):
                try:
                    images = [Image.open(image_info.img_path) if not flipped else Image.open(image_info.img_path).transpose(Image.FLIP_LEFT_RIGHT) for image_info, flipped in batch]
                except:
                    log_utils.error(f"Error occurred when loading one of the images: {[image_info.img_path for image_info, flipped in batch]}")
                    raise
                im_emb_arrs = encode_images(images, self.model2, self.preprocess, device=self.device)  # shape: [batch_size, input_size]
                for i, item in enumerate(batch):
                    image_info, flipped = item
                    im_emb_arr = im_emb_arrs[i]
                    shape_size = len(im_emb_arr.shape)
                    if shape_size == 1:
                        im_emb_arr = im_emb_arr.unsqueeze(0)
                    elif shape_size == 3:
                        im_emb_arr = im_emb_arr.squeeze(1)

                    image_key = image_info.img_path.stem
                    assert im_emb_arr.device == torch.device('cpu'), "flipped image emb should be on cpu"
                    if flipped:
                        image_key = image_key + '_flipped'
                        image_info.im_emb_arr_flipped = im_emb_arr
                    else:
                        image_info.im_emb_arr = im_emb_arr

                    if self.cache_to_disk:
                        if image_key in h5:
                            continue
                        h5.create_dataset(image_key, data=im_emb_arr.cpu().numpy())

            try:
                h5 = h5py.File(self.cache_path, 'a') if self.cache_to_disk else None
                for batch in batches:
                    cache_batch_embs(h5, batch)
                    pbar.update()
            finally:
                if h5:
                    h5.close()
                pbar.close()

    def make_batches(self):
        batches = []
        repeated_image_data = []
        for image_info in self.image_data:
            repeated_image_data += [image_info] * image_info.num_repeats
        log_utils.info(f"number of instances (repeated): {len(repeated_image_data)}")
        for i in range(0, len(repeated_image_data), self.batch_size):
            batch = repeated_image_data[i:i + self.batch_size]
            batches.append(batch)
        if self.shuffle:
            random.shuffle(batches)
        return batches

    def __getitem__(self, index):
        batch = self.batches[index]
        im_emb_arrs = []
        ratings = []
        for image_info in batch:
            flip = self.flip_aug and random.random() > 0.5
            if not flip:
                im_emb_arr = image_info.im_emb_arr
            else:
                im_emb_arr = image_info.im_emb_arr_flipped
            rating = image_info.rating

            im_emb_arrs.append(im_emb_arr)
            ratings.append(rating)

        im_emb_arrs = torch.cat(im_emb_arrs, dim=0)
        ratings = torch.tensor(ratings).unsqueeze(-1)
        sample = dict(
            im_emb_arrs=im_emb_arrs,
            ratings=ratings,
        )
        return sample

    def __len__(self):
        return len(self.batches)


def collate_fn(batch):
    return batch[0]


def get_rating_func(rating_func_type: str):
    if rating_func_type == 'quality':
        from .utils import quality_rating
        rating_func = quality_rating
    else:
        raise ValueError(f"Invalid rating type: {rating_func_type}")
    return rating_func


def prepare_dataloader(
    dataset_source,
    cache_to_disk=True,
    cache_path=None,
    batch_size=1,
    clip_batch_size=4,
    model2=None,
    preprocess=None,
    input_size=768,
    rating_func: Callable = quality_rating,
    shuffle=True,
    flip_aug: bool = True,
    device='cuda',
    persistent_workers=False,
    max_data_loader_n_workers=0,
):
    dataset = LaionDataset(
        dataset_source,
        cache_to_disk=cache_to_disk,
        cache_path=cache_path,
        batch_size=batch_size,
        clip_batch_size=clip_batch_size,
        model2=model2,
        preprocess=preprocess,
        input_size=input_size,
        rating_func=rating_func,
        shuffle=shuffle,
        flip_aug=flip_aug,
        device=device,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=1,  # fix to 1
        shuffle=shuffle,
        num_workers=max_data_loader_n_workers,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    return dataset, dataloader


def is_h5_file(cache_path):
    if not cache_path or not h5py.is_hdf5(cache_path):
        return False
    return True


# def make_train_data(
#     dataset_source,
#     rating_func: Callable = quality_rating,
#     batch_size=1,
#     flip_aug: bool = True,
#     device='cuda'
# ):
#     model2, preprocess = clip.load("ViT-L/14", device=device)  # RN50x64
#     dataset = Dataset.from_source(dataset_source, verbose=True)
#     x_train = []
#     y_train = []
#     batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
#     for batch in tqdm(batches, desc='encoding images', smoothing=1):
#         im_emb_arr = encode_images([d.pil_img for d in batch], model2, preprocess, device=device)  # shape: [batch_size, 768]
#         ratings = torch.tensor([rating_func(data) for data in batch]).unsqueeze(-1).to(device)  # shape: [batch_size, 1]
#         x_train.append(im_emb_arr)
#         y_train.append(ratings)
#     x_train = torch.cat(x_train, dim=0)
#     y_train = torch.cat(y_train, dim=0)
#     return x_train, y_train


def prepare_dtype(mixed_precision: str):
    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    return weight_dtype


def save_model(model, save_path, epoch=None):
    save_path = str(save_path)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if epoch is not None:
        save_path = save_path.replace('.pth', f'_ep{epoch}.pth')
    torch.save(model.state_dict(), save_path)
    return save_path
waifu_scorer/ui.py
ADDED
@@ -0,0 +1,91 @@
import gradio as gr
from argparse import ArgumentParser


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default='./model/v3.pth',
        help='Path or url to the model file',
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default='mlp',
        help='Type of the model',
    )
    parser.add_argument(
        '--fix_model_path',
        action='store_true',
        help='Fix the model path',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device to use',
    )
    parser.add_argument(
        '--share',
        action='store_true',
        help='Share the demo',
    )
    return parser.parse_args()


def ui(args):
    from waifu_scorer.predict import WaifuScorer, load_model
    scorer = WaifuScorer(
        model_path=args.model_path,
        model_type=args.model_type,
        device=args.device,
    )

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label='Image',
                    type='pil',
                    height=512,
                    sources=['upload', 'clipboard'],
                )
            with gr.Column():
                with gr.Row():
                    model_path = gr.Textbox(
                        label='Model Path',
                        value=args.model_path,
                        placeholder='Path or URL to the model file',
                        interactive=not args.fix_model_path,
                    )
                with gr.Row():
                    score = gr.Number(
                        label='Score',
                    )

        def change_model(model_path):
            nonlocal scorer
            scorer.mlp = load_model(model_path, model_type=args.model_type, device=args.device)
            print(f"Model changed to `{model_path}`")
            return gr.update()

        model_path.submit(
            fn=change_model,
            inputs=model_path,
            outputs=model_path,
        )

        image.change(
            fn=lambda image: scorer.predict([image]*2)[0] if image is not None else None,
            inputs=image,
            outputs=score,
        )

    return demo


def launch(args):
    demo = ui(args)
    demo.launch(share=args.share)
waifu_scorer/utils.py
ADDED
@@ -0,0 +1,72 @@
import torch
import clip
from PIL import Image
from typing import List, Union
from . import mlp

QUALITY_TO_RATING = {
    'amazing': 10,
    'best': 8.5,
    'high': 7,
    'normal': 5,
    'low': 2.5,
    'worst': 0,
    'horrible': 0,
}

MODEL_TYPE = {
    'mlp': mlp.MLP,
    'res_mlp': mlp.ResMLP,
}


def quality_rating(img_info):
    quality = (img_info.caption.quality or 'normal') if img_info.caption is not None else 'normal'
    rating = QUALITY_TO_RATING[quality]
    return rating


def get_model_cls(model_type) -> Union[mlp.MLP, None]:
    return MODEL_TYPE.get(model_type, mlp.MLP)


def load_clip_models(name: str = "ViT-L/14", device='cuda'):
    model2, preprocess = clip.load(name, device=device)  # RN50x64
    return model2, preprocess


def load_model(model_path: str = None, model_type=None, input_size=768, batch_norm: bool = True, device: str = 'cuda', dtype=None):
    model_cls = get_model_cls(model_type)
    print(f"Loading model from class `{model_cls}`...")
    model_kwargs = {}
    if model_type in ('large', 'res_large'):
        model_kwargs['batch_norm'] = True
    model = model_cls(input_size, **model_kwargs)
    if model_path:
        try:
            s = torch.load(model_path, map_location=device)
            model.load_state_dict(s)
        except Exception as e:
            print(f"Model type mismatch. Desired model type: `{model_type}` (model class: `{model_cls}`).")
            raise e
    model.to(device)
    if dtype:
        model = model.to(dtype=dtype)
    return model


def normalized(a: torch.Tensor, order=2, dim=-1):
    l2 = a.norm(order, dim, keepdim=True)
    l2[l2 == 0] = 1
    return a / l2


@torch.no_grad()
def encode_images(images: List[Image.Image], model2, preprocess, device='cuda') -> torch.Tensor:
    if isinstance(images, Image.Image):
        images = [images]
    image_tensors = [preprocess(img).unsqueeze(0) for img in images]
    image_batch = torch.cat(image_tensors).to(device)
    image_features = model2.encode_image(image_batch)
    im_emb_arr = normalized(image_features).cpu().float()
    return im_emb_arr