Spaces: Runtime error
Commit 55d6a83 · Parent: a581500
match multimodal
Browse files:
- __pycache__/safety_checker_improved.cpython-310.pyc +0 -0
- app.py +144 -65
- eigth.mp4 +0 -0
- ninth.mp4 +0 -0
- seventh.mp4 +0 -0
- tenth.mp4 +0 -0
__pycache__/safety_checker_improved.cpython-310.pyc DELETED
Binary file (1.38 kB)
app.py CHANGED
@@ -15,14 +15,12 @@ import matplotlib.pyplot as plt
15 | import matplotlib
16 | import logging
17 |
18 | - from sklearn.linear_model import Ridge
19 |
20 | import os
21 | import imageio
22 | import gradio as gr
23 | import numpy as np
24 | from sklearn.svm import SVC
25 | - from sklearn.inspection import permutation_importance
26 | from sklearn import preprocessing
27 | import pandas as pd
28 | from apscheduler.schedulers.background import BackgroundScheduler

@@ -39,14 +37,13 @@ torch.set_grad_enabled(False)
39 | torch.backends.cuda.matmul.allow_tf32 = True
40 | torch.backends.cudnn.allow_tf32 = True
41 |
42 | - prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
43 |
44 | import spaces
45 | start_time = time.time()
46 |
47 | ####################### Setup Model
48 | - from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
49 | - utils.logging.disable_progress_bar
50 | from transformers import CLIPTextModel
51 | from huggingface_hub import hf_hub_download
52 | from safetensors.torch import load_file

@@ -54,6 +51,7 @@ from PIL import Image
54 | from transformers import CLIPVisionModelWithProjection
55 | import uuid
56 | import av
57 |
58 | def write_video(file_name, images, fps=17):
59 | container = av.open(file_name, mode="w")

@@ -92,6 +90,9 @@ device_map='cuda')
92 | # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
93 | # vae = compile_unet(vae, config=config)
94 |
95 |
96 |
97 | unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')

@@ -99,7 +100,8 @@ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='
99 | device_map='cpu').to(dtype)
100 |
101 | adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
102 | - pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
103 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
104 | pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
105 | pipe.set_adapters(["lcm-lora"], [.9])

@@ -114,7 +116,7 @@ pipe.fuse_lora()
114 |
115 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
116 | # This IP adapter improves outputs substantially.
117 | - pipe.set_ip_adapter_scale(.
118 | pipe.unet.fuse_qkv_projections()
119 | #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
120 |

@@ -122,21 +124,71 @@ pipe.to(device=DEVICE)
122 | #pipe.unet = torch.compile(pipe.unet)
123 | #pipe.vae = torch.compile(pipe.vae)
124 |
125 | - @spaces.GPU()
126 | - def generate_gpu(in_im_embs):
127 | - in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
128 | - output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
129 | - im_emb, _ = pipe.encode_image(
130 | - output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
131 | - )
132 | - im_emb = im_emb.detach().to('cpu').to(torch.float32)
133 | - return output, im_emb
134 |
135 |
136 | -
137 | -
138 | -
139 |
140 | name = str(uuid.uuid4()).replace("-", "")
141 | path = f"/tmp/{name}.mp4"
142 |

@@ -149,19 +201,19 @@ def generate(in_im_embs):
149 | output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
150 |
151 | write_video(path, output.frames[0])
152 | - return path, im_emb
153 |
154 |
155 | #######################
156 |
157 | def get_user_emb(embs, ys):
158 | # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
159 | if len(list(ys)) <= 7:
160 | - aways = [.01*torch.
161 | embs += aways
162 | awal = [0 for i in range(3)]
163 | ys += awal
164 | - print('Fixing only one feedback class available.\n')
165 |
166 | indices = list(range(len(embs)))
167 | # sample only as many negatives as there are positives

@@ -176,21 +228,20 @@ def get_user_emb(embs, ys):
176 | # this ends up adding a rating but losing an embedding, it seems.
177 | # let's take off a rating if so to continue without indexing errors.
178 | if len(ys) > len(embs):
179 | ys.pop(-1)
180 |
181 | feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
182 | #scaler = preprocessing.StandardScaler().fit(feature_embs)
183 | #feature_embs = scaler.transform(feature_embs)
184 | -
185 |
186 | if feature_embs.norm() != 0:
187 | feature_embs = feature_embs / feature_embs.norm()
188 |
189 | - chosen_y = np.array([ys[i] for i in indices])
190 | -
191 | #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
192 | - lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
193 | - coef_ = torch.tensor(lin_class.coef_, dtype=torch.
194 | coef_ = coef_ / coef_.abs().max() * 3
195 |
196 | w = 1# if len(embs) % 2 == 0 else 0

@@ -212,7 +263,8 @@ def pluck_img(user_id, user_emb):
212 | best_sim = sim
213 | best_row = i[1]
214 | img = best_row['paths']
215 | -
216 |
217 |
218 | def background_next_image():

@@ -236,39 +288,48 @@ def background_next_image():
236 | unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
237 | rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
238 |
239 | - # we pop previous ratings if there are >
240 | - if len(rated_from_user) >=
241 | oldest = rated_from_user.iloc[0]['paths']
242 | prevs_df = prevs_df[prevs_df['paths'] != oldest]
243 | - # we don't compute more after
244 | if len(unrated_from_user) >= 10:
245 | continue
246 |
247 | - if len(rated_rows) <
248 | continue
249 |
250 | - embs, ys = pluck_embs_ys(uid)
251 |
252 | user_emb = get_user_emb(embs, ys)
253 | -
254 | if img:
255 | - tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
256 | tmp_df['paths'] = [img]
257 | tmp_df['embeddings'] = [embs]
258 | tmp_df['user:rating'] = [{' ': ' '}]
259 | tmp_df['from_user_id'] = [uid]
260 | prevs_df = pd.concat((prevs_df, tmp_df))
261 | -
262 | # we can free up storage by deleting the image
263 | - if len(prevs_df) >
264 | -
265 | -
266 | -
267 | -
268 | -
269 | -
270 | - # only keep
271 | - prevs_df = prevs_df[prevs_df[
272 |
273 | def pluck_embs_ys(user_id):
274 | rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]

@@ -281,21 +342,21 @@ def pluck_embs_ys(user_id):
281 |
282 | embs = rated_rows['embeddings'].to_list()
283 | ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
284 | -
285 |
286 | def next_image(calibrate_prompts, user_id):
287 | -
288 | with torch.no_grad():
289 | if len(calibrate_prompts) > 0:
290 | cal_video = calibrate_prompts.pop(0)
291 | image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
292 |
293 | - return image, calibrate_prompts
294 | else:
295 | - embs, ys = pluck_embs_ys(user_id)
296 | user_emb = get_user_emb(embs, ys)
297 | - image = pluck_img(user_id, user_emb)
298 | - return image, calibrate_prompts
299 |
300 |
301 |

@@ -307,7 +368,7 @@ def next_image(calibrate_prompts, user_id):
307 |
308 | def start(_, calibrate_prompts, user_id, request: gr.Request):
309 | user_id = int(str(time.time())[-7:].replace('.', ''))
310 | - image, calibrate_prompts = next_image(calibrate_prompts, user_id)
311 | return [
312 | gr.Button(value='Like (L)', interactive=True),
313 | gr.Button(value='Neither (Space)', interactive=True, visible=False),

@@ -326,14 +387,15 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
326 | if choice == 'Like (L)':
327 | choice = 1
328 | elif choice == 'Neither (Space)':
329 | - img, calibrate_prompts = next_image(calibrate_prompts, user_id)
330 | - return img, calibrate_prompts
331 | else:
332 | choice = 0
333 |
334 | # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
335 | # TODO skip allowing rating & just continue
336 | if img == None:
337 | choice = 0
338 |
339 | row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]

@@ -341,8 +403,8 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
341 | if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
342 | prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
343 | prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
344 | - img, calibrate_prompts = next_image(calibrate_prompts, user_id)
345 | - return img, calibrate_prompts
346 |
347 | css = '''.gradio-container{max-width: 700px !important}
348 | #description{text-align: center}

@@ -426,6 +488,8 @@ Explore the latent space without text prompts based on your preferences. Learn m
426 | elem_id="video_output"
427 | )
428 | img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
429 | with gr.Row(equal_height=True):
430 | b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
431 | b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)

@@ -433,17 +497,17 @@
433 | b1.click(
434 | choose,
435 | [img, b1, calibrate_prompts, user_id],
436 | - [img, calibrate_prompts],
437 | )
438 | b2.click(
439 | choose,
440 | [img, b2, calibrate_prompts, user_id],
441 | - [img, calibrate_prompts],
442 | )
443 | b3.click(
444 | choose,
445 | [img, b3, calibrate_prompts, user_id],
446 | - [img, calibrate_prompts],
447 | )
448 | with gr.Row():
449 | b4 = gr.Button(value='Start')

@@ -464,20 +528,28 @@ log = logging.getLogger('log_here')
464 | log.setLevel(logging.ERROR)
465 |
466 | scheduler = BackgroundScheduler()
467 | - scheduler.add_job(func=background_next_image, trigger="interval", seconds=.
468 | scheduler.start()
469 |
470 | #thread = threading.Thread(target=background_next_image,)
471 | #thread.start()
472 |
473 | @spaces.GPU()
474 | def encode_space(x):
475 | im_emb, _ = pipe.encode_image(
476 | image, DEVICE, 1, output_hidden_state
477 | )
478 | -
479 |
480 | - # prep our calibration
481 | for im in [
482 | './first.mp4',
483 | './second.mp4',

@@ -485,16 +557,23 @@ for im in [
485 | './fourth.mp4',
486 | './fifth.mp4',
487 | './sixth.mp4',
488 | ]:
489 | - tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
490 | tmp_df['paths'] = [im]
491 | image = list(imageio.imiter(im))
492 | image = image[len(image)//2]
493 | - im_emb = encode_space(image)
494 |
495 | tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
496 | tmp_df['user:rating'] = [{' ': ' '}]
497 | prevs_df = pd.concat((prevs_df, tmp_df))
498 |
499 |
500 | - demo.launch(share=True)
15 | import matplotlib
16 | import logging
17 |
18 |
19 | import os
20 | import imageio
21 | import gradio as gr
22 | import numpy as np
23 | from sklearn.svm import SVC
24 | from sklearn import preprocessing
25 | import pandas as pd
26 | from apscheduler.schedulers.background import BackgroundScheduler

37 | torch.backends.cuda.matmul.allow_tf32 = True
38 | torch.backends.cudnn.allow_tf32 = True
39 |
40 | + prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
41 |
42 | import spaces
43 | start_time = time.time()
44 |
45 | ####################### Setup Model
46 | + from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
47 | from transformers import CLIPTextModel
48 | from huggingface_hub import hf_hub_download
49 | from safetensors.torch import load_file

51 | from transformers import CLIPVisionModelWithProjection
52 | import uuid
53 | import av
54 | + import torchvision
55 |
56 | def write_video(file_name, images, fps=17):
57 | container = av.open(file_name, mode="w")
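Only the signature of write_video survives in this hunk. For readers unfamiliar with PyAV, a minimal sketch of the same idea — encoding a list of RGB frames into an MP4 — might look like the following; the codec choice and pixel format here are assumptions for illustration, not values read from app.py.

    import av
    import numpy as np

    def write_video_sketch(file_name, images, fps=17):
        # images: list of HxWx3 uint8 RGB frames
        container = av.open(file_name, mode="w")
        stream = container.add_stream("libx264", rate=fps)  # assumed encoder
        stream.width = images[0].shape[1]
        stream.height = images[0].shape[0]
        stream.pix_fmt = "yuv420p"
        for img in images:
            frame = av.VideoFrame.from_ndarray(np.asarray(img), format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)
        for packet in stream.encode():  # flush any buffered frames
            container.mux(packet)
        container.close()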
90 | # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
91 | # vae = compile_unet(vae, config=config)
92 |
93 | + #finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
94 | + #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
95 | + #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
96 |
97 |
98 | unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')

100 | device_map='cpu').to(dtype)
101 |
102 | adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
103 | + pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
104 | + unet=unet, text_encoder=text_encoder)
105 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
106 | pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
107 | pipe.set_adapters(["lcm-lora"], [.9])

116 |
117 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
118 | # This IP adapter improves outputs substantially.
119 | + pipe.set_ip_adapter_scale(.6)
120 | pipe.unet.fuse_qkv_projections()
121 | #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
122 |
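In diffusers, set_ip_adapter_scale weights how strongly the IP-Adapter image embedding conditions the UNet relative to the text prompt: 0 ignores the image prompt, values near 1 let it dominate. A hedged sketch of how one might probe that knob with this pipeline; the comparison values other than .6 are illustrative, and in_im_embs/STEPS are the names used later in generate_gpu, not defined at this point in the file.

    # Illustrative sweep only; app.py fixes the scale at .6 above.
    for scale in (0.2, 0.6, 1.0):
        pipe.set_ip_adapter_scale(scale)
        out = pipe(prompt='the scene', guidance_scale=1,
                   ip_adapter_image_embeds=[in_im_embs],
                   num_inference_steps=STEPS)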
124 | #pipe.unet = torch.compile(pipe.unet)
125 | #pipe.vae = torch.compile(pipe.vae)
126 |
127 |
128 | + #############################################################
129 | +
130 | + from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
131 | +
132 | + quantization_config = BitsAndBytesConfig(load_in_4bit=True)
133 | + pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
134 | + processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
135 |
136 | +
137 | +
138 | + def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
139 | + inputs_embeds = pali.get_input_embeddings()(input_ids)
140 | + selected_image_feature = image_outputs.to(dtype).to(device)
141 | + image_features = pali.multi_modal_projector(selected_image_feature)
142 | +
143 | + if cache_position is None:
144 | + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
145 | + inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
146 | + image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
147 | + )
148 | + return inputs_embeds
149 |
150 | +
151 | +
152 | + def generate_pali(user_emb):
153 | + prompt = 'caption en'
154 | + model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
155 | + # we need to get im_embs taken in here.
156 | + input_len = model_inputs["input_ids"].shape[-1]
157 | + input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
158 | + model_inputs["input_ids"].to(device),
159 | + model_inputs["attention_mask"].to(device))
160 | +
161 | + generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
162 | + decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
163 | + return decoded
164 | +
165 | +
166 | +
167 | +
168 | + #############################################################
169 | +
170 | +
171 | +
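This block is the multimodal piece named in the commit: instead of real image tokens, generate_pali projects a pooled vision embedding into PaliGemma's input space (repeated across the 256 token positions the 224px checkpoint expects) and asks the model to caption that preference vector. A minimal usage sketch, assuming the objects above are already loaded; the zero vector mirrors the fallback used later in background_next_image, and 1152 is the SigLIP hidden size for this checkpoint.

    with torch.no_grad():
        # Caption a "blank" preference embedding (no rated clips yet).
        print(generate_pali(torch.zeros(1, 1152)))
        # In the app, the real input is a scaled get_user_emb() over the
        # stored 'gemb' vectors from pali.vision_tower, computed per user.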
172 | + @spaces.GPU()
173 | + def generate_gpu(in_im_embs, prompt='the scene'):
174 | + with torch.no_grad():
175 | + in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
176 | + output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
177 | + im_emb, _ = pipe.encode_image(
178 | + output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
179 | + )
180 | + im_emb = im_emb.detach().to('cpu').to(torch.float32)
181 | + im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
182 | + im = torch.nn.functional.interpolate(im, (224, 224))
183 | + im = (im - .5) * 2
184 | + gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
185 | + return output, im_emb, gemb
186 | +
187 | +
188 | + def generate(in_im_embs, prompt='the scene'):
189 | + output, im_emb, gemb = generate_gpu(in_im_embs, prompt)
190 | + nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
191 | + print(prompt)
192 | name = str(uuid.uuid4()).replace("-", "")
193 | path = f"/tmp/{name}.mp4"
201 | output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
202 |
203 | write_video(path, output.frames[0])
204 | + return path, im_emb, gemb
205 |
206 |
207 | #######################
208 |
209 | def get_user_emb(embs, ys):
210 | # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
211 | +
212 | if len(list(ys)) <= 7:
213 | + aways = [.01*torch.randn_like(embs[0]) for i in range(3)]
214 | embs += aways
215 | awal = [0 for i in range(3)]
216 | ys += awal
217 |
218 | indices = list(range(len(embs)))
219 | # sample only as many negatives as there are positives

228 | # this ends up adding a rating but losing an embedding, it seems.
229 | # let's take off a rating if so to continue without indexing errors.
230 | if len(ys) > len(embs):
231 | + print('ys are longer than embs; popping latest rating')
232 | ys.pop(-1)
233 |
234 | feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
235 | #scaler = preprocessing.StandardScaler().fit(feature_embs)
236 | #feature_embs = scaler.transform(feature_embs)
237 | + chosen_y = np.array([ys[i] for i in indices])
238 |
239 | if feature_embs.norm() != 0:
240 | feature_embs = feature_embs / feature_embs.norm()
241 |
242 | #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
243 | + lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs.squeeze(), chosen_y)
244 | + coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
245 | coef_ = coef_ / coef_.abs().max() * 3
246 |
247 | w = 1# if len(embs) % 2 == 0 else 0
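The preference model in get_user_emb is just a linear SVM whose weight vector is reused as a direction in embedding space: points further along it look more like the liked class. A self-contained toy sketch of that trick, with made-up sample counts and dimensions:

    import numpy as np
    import torch
    from sklearn.svm import SVC

    # Toy stand-ins for liked/disliked video embeddings (8 samples, 16-dim).
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(8, 16)).astype(np.float32)
    labels = np.array([1, 1, 1, 1, 0, 0, 0, 0])

    clf = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feats, labels)

    # The separating hyperplane's normal doubles as the "user embedding".
    direction = torch.tensor(clf.coef_, dtype=torch.float32)
    direction = direction / direction.abs().max() * 3  # same rescaling as app.py
    print(direction.shape)  # (1, 16)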
263 | best_sim = sim
264 | best_row = i[1]
265 | img = best_row['paths']
266 | + text = best_row.get('text', '')
267 | + return img, text
268 |
269 |
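pluck_img, only partly visible in this hunk, serves whichever queued, not-yet-rated video has the embedding most similar to the user embedding. A rough sketch of that selection step, assuming embeddings are stored as flat tensors alongside their file paths:

    import torch

    def pick_closest(user_emb, rows):
        # rows: iterable of (path, embedding) pairs for this user's unrated videos
        best_sim, best_path = -float('inf'), None
        for path, emb in rows:
            sim = torch.nn.functional.cosine_similarity(
                user_emb.flatten(), emb.flatten(), dim=0)
            if sim > best_sim:
                best_sim, best_path = sim, path
        return best_path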
270 | def background_next_image():
288 | unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
289 | rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
290 |
291 | + # we pop previous ratings if there are > n
292 | + if len(rated_from_user) >= 15:
293 | oldest = rated_from_user.iloc[0]['paths']
294 | prevs_df = prevs_df[prevs_df['paths'] != oldest]
295 | + # we don't compute more after n are in the queue for them
296 | if len(unrated_from_user) >= 10:
297 | continue
298 |
299 | + if len(rated_rows) < 5:
300 | continue
301 |
302 | + embs, ys, gembs = pluck_embs_ys(uid)
303 |
304 | user_emb = get_user_emb(embs, ys)
305 | +
306 | + if len(gembs) > 4:
307 | + user_gem = get_user_emb(gembs, ys) / 4 # TODO scale this correctly; matplotlib, etc.
308 | + text = generate_pali(user_gem)
309 | + else:
310 | + text = generate_pali(torch.zeros(1, 1152))
311 | + img, embs, new_gem = generate(user_emb, text)
312 | +
313 | if img:
314 | + tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
315 | tmp_df['paths'] = [img]
316 | tmp_df['embeddings'] = [embs]
317 | tmp_df['user:rating'] = [{' ': ' '}]
318 | tmp_df['from_user_id'] = [uid]
319 | + tmp_df['text'] = [text]
320 | + tmp_df['gemb'] = [new_gem]
321 | prevs_df = pd.concat((prevs_df, tmp_df))
322 | # we can free up storage by deleting the image
323 | + if len(prevs_df) > 500:
324 | + oldest_path = prevs_df.iloc[6]['paths']
325 | + if os.path.isfile(oldest_path):
326 | + os.remove(oldest_path)
327 | + else:
328 | + # If it fails, inform the user.
329 | + print("Error: %s file not found" % oldest_path)
330 | + # only keep 50 images & embeddings & ips, then remove oldest besides calibrating
331 | + prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
332 | +
333 |
334 | def pluck_embs_ys(user_id):
335 | rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]

342 |
343 | embs = rated_rows['embeddings'].to_list()
344 | ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
345 | + gembs = rated_rows['gemb'].to_list()
346 | + return embs, ys, gembs
347 |
348 | def next_image(calibrate_prompts, user_id):
349 | with torch.no_grad():
350 | if len(calibrate_prompts) > 0:
351 | cal_video = calibrate_prompts.pop(0)
352 | image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
353 |
354 | + return image, calibrate_prompts, ''
355 | else:
356 | + embs, ys, gembs = pluck_embs_ys(user_id)
357 | user_emb = get_user_emb(embs, ys)
358 | + image, text = pluck_img(user_id, user_emb)
359 | + return image, calibrate_prompts, text
360 |
361 |
362 |

368 |
369 | def start(_, calibrate_prompts, user_id, request: gr.Request):
370 | user_id = int(str(time.time())[-7:].replace('.', ''))
371 | + image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
372 | return [
373 | gr.Button(value='Like (L)', interactive=True),
374 | gr.Button(value='Neither (Space)', interactive=True, visible=False),

387 | if choice == 'Like (L)':
388 | choice = 1
389 | elif choice == 'Neither (Space)':
390 | + img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
391 | + return img, calibrate_prompts, text
392 | else:
393 | choice = 0
394 |
395 | # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
396 | # TODO skip allowing rating & just continue
397 | if img == None:
398 | + print('NSFW -- choice is disliked')
399 | choice = 0
400 |
401 | row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]

403 | if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
404 | prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
405 | prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
406 | + img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
407 | + return img, calibrate_prompts, text
408 |
409 | css = '''.gradio-container{max-width: 700px !important}
410 | #description{text-align: center}

488 | elem_id="video_output"
489 | )
490 | img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
491 | + with gr.Row():
492 | + text = gr.Textbox(interactive=False, visible=True, label='Text')
493 | with gr.Row(equal_height=True):
494 | b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
495 | b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)

497 | b1.click(
498 | choose,
499 | [img, b1, calibrate_prompts, user_id],
500 | + [img, calibrate_prompts, text],
501 | )
502 | b2.click(
503 | choose,
504 | [img, b2, calibrate_prompts, user_id],
505 | + [img, calibrate_prompts, text],
506 | )
507 | b3.click(
508 | choose,
509 | [img, b3, calibrate_prompts, user_id],
510 | + [img, calibrate_prompts, text],
511 | )
512 | with gr.Row():
513 | b4 = gr.Button(value='Start')

528 | log.setLevel(logging.ERROR)
529 |
530 | scheduler = BackgroundScheduler()
531 | + scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
532 | scheduler.start()
533 |
534 | #thread = threading.Thread(target=background_next_image,)
535 | #thread.start()
536 |
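The generation loop is driven by APScheduler rather than the commented-out thread: an interval job calls background_next_image roughly twice a second. An isolated sketch of the same pattern; the job body and the max_instances guard are illustrative additions, not part of this commit.

    from apscheduler.schedulers.background import BackgroundScheduler

    def tick():
        # placeholder for background_next_image()
        print('generating...')

    scheduler = BackgroundScheduler()
    scheduler.add_job(func=tick, trigger="interval", seconds=.5, max_instances=1)
    scheduler.start()
    # The Gradio app keeps the process alive; call scheduler.shutdown() on exit.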
537 | + # TODO shouldn't call this before gradio launch, yeah?
538 | @spaces.GPU()
539 | def encode_space(x):
540 | im_emb, _ = pipe.encode_image(
541 | image, DEVICE, 1, output_hidden_state
542 | )
543 | +
544 | +
545 | + im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
546 | + im = torch.nn.functional.interpolate(im, (224, 224))
547 | + im = (im - .5) * 2
548 | + gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
549 | +
550 | + return im_emb.detach().to('cpu').to(torch.float32), gemb

552 | + # prep our calibration videos
553 | for im in [
554 | './first.mp4',
555 | './second.mp4',

557 | './fourth.mp4',
558 | './fifth.mp4',
559 | './sixth.mp4',
560 | + './seventh.mp4',
561 | + './eigth.mp4',
562 | + './ninth.mp4',
563 | + './tenth.mp4',
564 | ]:
565 | + tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
566 | tmp_df['paths'] = [im]
567 | image = list(imageio.imiter(im))
568 | image = image[len(image)//2]
569 | + im_emb, gemb = encode_space(image)
570 |
571 | tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
572 | + tmp_df['gemb'] = [gemb.detach().to('cpu')]
573 | tmp_df['user:rating'] = [{' ': ' '}]
574 | prevs_df = pd.concat((prevs_df, tmp_df))
575 |
576 |
577 | + demo.launch(share=True, server_port=8443)
578 | +
579 | +
eigth.mp4 ADDED
Binary file (47.7 kB)

ninth.mp4 ADDED
Binary file (255 kB)

seventh.mp4 ADDED
Binary file (50 kB)

tenth.mp4 ADDED
Binary file (129 kB)