Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						55d6a83
	
1
								Parent(s):
							
							a581500
								
match multimodlar
Browse files- __pycache__/safety_checker_improved.cpython-310.pyc +0 -0
- app.py +144 -65
- eigth.mp4 +0 -0
- ninth.mp4 +0 -0
- seventh.mp4 +0 -0
- tenth.mp4 +0 -0
    	
        __pycache__/safety_checker_improved.cpython-310.pyc
    DELETED
    
    | Binary file (1.38 kB) | 
|  | 
    	
        app.py
    CHANGED
    
    | @@ -15,14 +15,12 @@ import matplotlib.pyplot as plt | |
| 15 | 
             
            import matplotlib
         | 
| 16 | 
             
            import logging
         | 
| 17 |  | 
| 18 | 
            -
            from sklearn.linear_model import Ridge
         | 
| 19 |  | 
| 20 | 
             
            import os
         | 
| 21 | 
             
            import imageio
         | 
| 22 | 
             
            import gradio as gr
         | 
| 23 | 
             
            import numpy as np
         | 
| 24 | 
             
            from sklearn.svm import SVC
         | 
| 25 | 
            -
            from sklearn.inspection import permutation_importance
         | 
| 26 | 
             
            from sklearn import preprocessing
         | 
| 27 | 
             
            import pandas as pd
         | 
| 28 | 
             
            from apscheduler.schedulers.background import BackgroundScheduler
         | 
| @@ -39,14 +37,13 @@ torch.set_grad_enabled(False) | |
| 39 | 
             
            torch.backends.cuda.matmul.allow_tf32 = True
         | 
| 40 | 
             
            torch.backends.cudnn.allow_tf32 = True
         | 
| 41 |  | 
| 42 | 
            -
            prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
         | 
| 43 |  | 
| 44 | 
             
            import spaces
         | 
| 45 | 
             
            start_time = time.time()
         | 
| 46 |  | 
| 47 | 
             
            ####################### Setup Model
         | 
| 48 | 
            -
            from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL | 
| 49 | 
            -
            utils.logging.disable_progress_bar
         | 
| 50 | 
             
            from transformers import CLIPTextModel
         | 
| 51 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 52 | 
             
            from safetensors.torch import load_file
         | 
| @@ -54,6 +51,7 @@ from PIL import Image | |
| 54 | 
             
            from transformers import CLIPVisionModelWithProjection
         | 
| 55 | 
             
            import uuid
         | 
| 56 | 
             
            import av
         | 
|  | |
| 57 |  | 
| 58 | 
             
            def write_video(file_name, images, fps=17):
         | 
| 59 | 
             
                container = av.open(file_name, mode="w")
         | 
| @@ -92,6 +90,9 @@ device_map='cuda') | |
| 92 | 
             
            # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
         | 
| 93 | 
             
            # vae = compile_unet(vae, config=config)
         | 
| 94 |  | 
|  | |
|  | |
|  | |
| 95 |  | 
| 96 |  | 
| 97 | 
             
            unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
         | 
| @@ -99,7 +100,8 @@ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder=' | |
| 99 | 
             
            device_map='cpu').to(dtype)
         | 
| 100 |  | 
| 101 | 
             
            adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
         | 
| 102 | 
            -
            pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, | 
|  | |
| 103 | 
             
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
         | 
| 104 | 
             
            pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
         | 
| 105 | 
             
            pipe.set_adapters(["lcm-lora"], [.9])
         | 
| @@ -114,7 +116,7 @@ pipe.fuse_lora() | |
| 114 |  | 
| 115 | 
             
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
         | 
| 116 | 
             
            # This IP adapter improves outputs substantially.
         | 
| 117 | 
            -
            pipe.set_ip_adapter_scale(. | 
| 118 | 
             
            pipe.unet.fuse_qkv_projections()
         | 
| 119 | 
             
            #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
         | 
| 120 |  | 
| @@ -122,21 +124,71 @@ pipe.to(device=DEVICE) | |
| 122 | 
             
            #pipe.unet = torch.compile(pipe.unet)
         | 
| 123 | 
             
            #pipe.vae = torch.compile(pipe.vae)
         | 
| 124 |  | 
| 125 | 
            -
            @spaces.GPU()
         | 
| 126 | 
            -
            def generate_gpu(in_im_embs):
         | 
| 127 | 
            -
                in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
         | 
| 128 | 
            -
                output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
         | 
| 129 | 
            -
                im_emb, _ = pipe.encode_image(
         | 
| 130 | 
            -
                            output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
         | 
| 131 | 
            -
                        )
         | 
| 132 | 
            -
                im_emb = im_emb.detach().to('cpu').to(torch.float32)
         | 
| 133 | 
            -
                return output, im_emb
         | 
| 134 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 135 |  | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 139 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 140 | 
             
                name = str(uuid.uuid4()).replace("-", "")
         | 
| 141 | 
             
                path = f"/tmp/{name}.mp4"
         | 
| 142 |  | 
| @@ -149,19 +201,19 @@ def generate(in_im_embs): | |
| 149 | 
             
                output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
         | 
| 150 |  | 
| 151 | 
             
                write_video(path, output.frames[0])
         | 
| 152 | 
            -
                return path, im_emb
         | 
| 153 |  | 
| 154 |  | 
| 155 | 
             
            #######################
         | 
| 156 |  | 
| 157 | 
             
            def get_user_emb(embs, ys):
         | 
| 158 | 
             
                # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
         | 
|  | |
| 159 | 
             
                if len(list(ys)) <= 7:
         | 
| 160 | 
            -
                    aways = [.01*torch. | 
| 161 | 
             
                    embs += aways
         | 
| 162 | 
             
                    awal = [0 for i in range(3)]
         | 
| 163 | 
             
                    ys += awal
         | 
| 164 | 
            -
                    print('Fixing only one feedback class available.\n')
         | 
| 165 |  | 
| 166 | 
             
                indices = list(range(len(embs)))
         | 
| 167 | 
             
                # sample only as many negatives as there are positives
         | 
| @@ -176,21 +228,20 @@ def get_user_emb(embs, ys): | |
| 176 | 
             
                # this ends up adding a rating but losing an embedding, it seems.
         | 
| 177 | 
             
                # let's take off a rating if so to continue without indexing errors.
         | 
| 178 | 
             
                if len(ys) > len(embs):
         | 
|  | |
| 179 | 
             
                    ys.pop(-1)
         | 
| 180 |  | 
| 181 | 
             
                feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
         | 
| 182 | 
             
                #scaler = preprocessing.StandardScaler().fit(feature_embs)
         | 
| 183 | 
             
                #feature_embs = scaler.transform(feature_embs)
         | 
| 184 | 
            -
                
         | 
| 185 |  | 
| 186 | 
             
                if feature_embs.norm() != 0:
         | 
| 187 | 
             
                    feature_embs = feature_embs / feature_embs.norm()
         | 
| 188 |  | 
| 189 | 
            -
                chosen_y = np.array([ys[i] for i in indices])
         | 
| 190 | 
            -
                
         | 
| 191 | 
             
                #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
         | 
| 192 | 
            -
                lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
         | 
| 193 | 
            -
                coef_ = torch.tensor(lin_class.coef_, dtype=torch. | 
| 194 | 
             
                coef_ = coef_ / coef_.abs().max() * 3
         | 
| 195 |  | 
| 196 | 
             
                w = 1# if len(embs) % 2 == 0 else 0
         | 
| @@ -212,7 +263,8 @@ def pluck_img(user_id, user_emb): | |
| 212 | 
             
                        best_sim = sim
         | 
| 213 | 
             
                        best_row = i[1]
         | 
| 214 | 
             
                img = best_row['paths']
         | 
| 215 | 
            -
                 | 
|  | |
| 216 |  | 
| 217 |  | 
| 218 | 
             
            def background_next_image():
         | 
| @@ -236,39 +288,48 @@ def background_next_image(): | |
| 236 | 
             
                        unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
         | 
| 237 | 
             
                        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
         | 
| 238 |  | 
| 239 | 
            -
                        # we pop previous ratings if there are >  | 
| 240 | 
            -
                        if len(rated_from_user) >=  | 
| 241 | 
             
                            oldest = rated_from_user.iloc[0]['paths']
         | 
| 242 | 
             
                            prevs_df = prevs_df[prevs_df['paths'] != oldest]
         | 
| 243 | 
            -
                        # we don't compute more after  | 
| 244 | 
             
                        if len(unrated_from_user) >= 10:
         | 
| 245 | 
             
                            continue
         | 
| 246 |  | 
| 247 | 
            -
                        if len(rated_rows) <  | 
| 248 | 
             
                            continue
         | 
| 249 |  | 
| 250 | 
            -
                        embs, ys = pluck_embs_ys(uid)
         | 
| 251 |  | 
| 252 | 
             
                        user_emb = get_user_emb(embs, ys)
         | 
| 253 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 254 | 
             
                        if img:
         | 
| 255 | 
            -
                            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
         | 
| 256 | 
             
                            tmp_df['paths'] = [img]
         | 
| 257 | 
             
                            tmp_df['embeddings'] = [embs]
         | 
| 258 | 
             
                            tmp_df['user:rating'] = [{' ': ' '}]
         | 
| 259 | 
             
                            tmp_df['from_user_id'] = [uid]
         | 
|  | |
|  | |
| 260 | 
             
                            prevs_df = pd.concat((prevs_df, tmp_df))
         | 
| 261 | 
            -
                            
         | 
| 262 | 
             
                            # we can free up storage by deleting the image
         | 
| 263 | 
            -
                            if len(prevs_df) >  | 
| 264 | 
            -
                                 | 
| 265 | 
            -
                                 | 
| 266 | 
            -
             | 
| 267 | 
            -
                                 | 
| 268 | 
            -
             | 
| 269 | 
            -
                                     | 
| 270 | 
            -
                                # only keep  | 
| 271 | 
            -
                                prevs_df = prevs_df[prevs_df[ | 
|  | |
| 272 |  | 
| 273 | 
             
            def pluck_embs_ys(user_id):
         | 
| 274 | 
             
                rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
         | 
| @@ -281,21 +342,21 @@ def pluck_embs_ys(user_id): | |
| 281 |  | 
| 282 | 
             
                embs = rated_rows['embeddings'].to_list()
         | 
| 283 | 
             
                ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
         | 
| 284 | 
            -
                 | 
|  | |
| 285 |  | 
| 286 | 
             
            def next_image(calibrate_prompts, user_id):
         | 
| 287 | 
            -
                
         | 
| 288 | 
             
                with torch.no_grad():
         | 
| 289 | 
             
                    if len(calibrate_prompts) > 0:
         | 
| 290 | 
             
                        cal_video = calibrate_prompts.pop(0)
         | 
| 291 | 
             
                        image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
         | 
| 292 |  | 
| 293 | 
            -
                        return image, calibrate_prompts
         | 
| 294 | 
             
                    else:
         | 
| 295 | 
            -
                        embs, ys = pluck_embs_ys(user_id)
         | 
| 296 | 
             
                        user_emb = get_user_emb(embs, ys)
         | 
| 297 | 
            -
                        image = pluck_img(user_id, user_emb)
         | 
| 298 | 
            -
                        return image, calibrate_prompts
         | 
| 299 |  | 
| 300 |  | 
| 301 |  | 
| @@ -307,7 +368,7 @@ def next_image(calibrate_prompts, user_id): | |
| 307 |  | 
| 308 | 
             
            def start(_, calibrate_prompts, user_id, request: gr.Request):
         | 
| 309 | 
             
                user_id = int(str(time.time())[-7:].replace('.', ''))
         | 
| 310 | 
            -
                image, calibrate_prompts = next_image(calibrate_prompts, user_id)
         | 
| 311 | 
             
                return [
         | 
| 312 | 
             
                        gr.Button(value='Like (L)', interactive=True), 
         | 
| 313 | 
             
                        gr.Button(value='Neither (Space)', interactive=True, visible=False), 
         | 
| @@ -326,14 +387,15 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request): | |
| 326 | 
             
                if choice == 'Like (L)':
         | 
| 327 | 
             
                    choice = 1
         | 
| 328 | 
             
                elif choice == 'Neither (Space)':
         | 
| 329 | 
            -
                    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
         | 
| 330 | 
            -
                    return img, calibrate_prompts
         | 
| 331 | 
             
                else:
         | 
| 332 | 
             
                    choice = 0
         | 
| 333 |  | 
| 334 | 
             
                # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
         | 
| 335 | 
             
                # TODO skip allowing rating & just continue
         | 
| 336 | 
             
                if img == None:
         | 
|  | |
| 337 | 
             
                    choice = 0
         | 
| 338 |  | 
| 339 | 
             
                row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
         | 
| @@ -341,8 +403,8 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request): | |
| 341 | 
             
                if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
         | 
| 342 | 
             
                    prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
         | 
| 343 | 
             
                    prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
         | 
| 344 | 
            -
                img, calibrate_prompts = next_image(calibrate_prompts, user_id)
         | 
| 345 | 
            -
                return img, calibrate_prompts
         | 
| 346 |  | 
| 347 | 
             
            css = '''.gradio-container{max-width: 700px !important}
         | 
| 348 | 
             
            #description{text-align: center}
         | 
| @@ -426,6 +488,8 @@ Explore the latent space without text prompts based on your preferences. Learn m | |
| 426 | 
             
                    elem_id="video_output"
         | 
| 427 | 
             
                   )
         | 
| 428 | 
             
                    img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
         | 
|  | |
|  | |
| 429 | 
             
                with gr.Row(equal_height=True):
         | 
| 430 | 
             
                    b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
         | 
| 431 | 
             
                    b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
         | 
| @@ -433,17 +497,17 @@ Explore the latent space without text prompts based on your preferences. Learn m | |
| 433 | 
             
                    b1.click(
         | 
| 434 | 
             
                    choose, 
         | 
| 435 | 
             
                    [img, b1, calibrate_prompts, user_id],
         | 
| 436 | 
            -
                    [img, calibrate_prompts],
         | 
| 437 | 
             
                    )
         | 
| 438 | 
             
                    b2.click(
         | 
| 439 | 
             
                    choose, 
         | 
| 440 | 
             
                    [img, b2, calibrate_prompts, user_id],
         | 
| 441 | 
            -
                    [img, calibrate_prompts],
         | 
| 442 | 
             
                    )
         | 
| 443 | 
             
                    b3.click(
         | 
| 444 | 
             
                    choose, 
         | 
| 445 | 
             
                    [img, b3, calibrate_prompts, user_id],
         | 
| 446 | 
            -
                    [img, calibrate_prompts],
         | 
| 447 | 
             
                    )
         | 
| 448 | 
             
                with gr.Row():
         | 
| 449 | 
             
                    b4 = gr.Button(value='Start')
         | 
| @@ -464,20 +528,28 @@ log = logging.getLogger('log_here') | |
| 464 | 
             
            log.setLevel(logging.ERROR)
         | 
| 465 |  | 
| 466 | 
             
            scheduler = BackgroundScheduler()
         | 
| 467 | 
            -
            scheduler.add_job(func=background_next_image, trigger="interval", seconds=. | 
| 468 | 
             
            scheduler.start()
         | 
| 469 |  | 
| 470 | 
             
            #thread = threading.Thread(target=background_next_image,)
         | 
| 471 | 
             
            #thread.start()
         | 
| 472 |  | 
|  | |
| 473 | 
             
            @spaces.GPU()
         | 
| 474 | 
             
            def encode_space(x):
         | 
| 475 | 
             
                im_emb, _ = pipe.encode_image(
         | 
| 476 | 
             
                            image, DEVICE, 1, output_hidden_state
         | 
| 477 | 
             
                        )
         | 
| 478 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 479 |  | 
| 480 | 
            -
            # prep our calibration  | 
| 481 | 
             
            for im in [
         | 
| 482 | 
             
                './first.mp4',
         | 
| 483 | 
             
                './second.mp4',
         | 
| @@ -485,16 +557,23 @@ for im in [ | |
| 485 | 
             
                './fourth.mp4',
         | 
| 486 | 
             
                './fifth.mp4',
         | 
| 487 | 
             
                './sixth.mp4',
         | 
|  | |
|  | |
|  | |
|  | |
| 488 | 
             
                ]:
         | 
| 489 | 
            -
                tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
         | 
| 490 | 
             
                tmp_df['paths'] = [im]
         | 
| 491 | 
             
                image = list(imageio.imiter(im))
         | 
| 492 | 
             
                image = image[len(image)//2]
         | 
| 493 | 
            -
                im_emb = encode_space(image)
         | 
| 494 |  | 
| 495 | 
             
                tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
         | 
|  | |
| 496 | 
             
                tmp_df['user:rating'] = [{' ': ' '}]
         | 
| 497 | 
             
                prevs_df = pd.concat((prevs_df, tmp_df))
         | 
| 498 |  | 
| 499 |  | 
| 500 | 
            -
            demo.launch(share=True)
         | 
|  | |
|  | 
|  | |
| 15 | 
             
            import matplotlib
         | 
| 16 | 
             
            import logging
         | 
| 17 |  | 
|  | |
| 18 |  | 
| 19 | 
             
            import os
         | 
| 20 | 
             
            import imageio
         | 
| 21 | 
             
            import gradio as gr
         | 
| 22 | 
             
            import numpy as np
         | 
| 23 | 
             
            from sklearn.svm import SVC
         | 
|  | |
| 24 | 
             
            from sklearn import preprocessing
         | 
| 25 | 
             
            import pandas as pd
         | 
| 26 | 
             
            from apscheduler.schedulers.background import BackgroundScheduler
         | 
|  | |
| 37 | 
             
            torch.backends.cuda.matmul.allow_tf32 = True
         | 
| 38 | 
             
            torch.backends.cudnn.allow_tf32 = True
         | 
| 39 |  | 
| 40 | 
            +
            prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
         | 
| 41 |  | 
| 42 | 
             
            import spaces
         | 
| 43 | 
             
            start_time = time.time()
         | 
| 44 |  | 
| 45 | 
             
            ####################### Setup Model
         | 
| 46 | 
            +
            from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
         | 
|  | |
| 47 | 
             
            from transformers import CLIPTextModel
         | 
| 48 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 49 | 
             
            from safetensors.torch import load_file
         | 
|  | |
| 51 | 
             
            from transformers import CLIPVisionModelWithProjection
         | 
| 52 | 
             
            import uuid
         | 
| 53 | 
             
            import av
         | 
| 54 | 
            +
            import torchvision
         | 
| 55 |  | 
| 56 | 
             
            def write_video(file_name, images, fps=17):
         | 
| 57 | 
             
                container = av.open(file_name, mode="w")
         | 
|  | |
| 90 | 
             
            # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
         | 
| 91 | 
             
            # vae = compile_unet(vae, config=config)
         | 
| 92 |  | 
| 93 | 
            +
            #finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
         | 
| 94 | 
            +
            #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
         | 
| 95 | 
            +
            #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
         | 
| 96 |  | 
| 97 |  | 
| 98 | 
             
            unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
         | 
|  | |
| 100 | 
             
            device_map='cpu').to(dtype)
         | 
| 101 |  | 
| 102 | 
             
            adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
         | 
| 103 | 
            +
            pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,     
         | 
| 104 | 
            +
                                                        unet=unet, text_encoder=text_encoder)
         | 
| 105 | 
             
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
         | 
| 106 | 
             
            pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
         | 
| 107 | 
             
            pipe.set_adapters(["lcm-lora"], [.9])
         | 
|  | |
| 116 |  | 
| 117 | 
             
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
         | 
| 118 | 
             
            # This IP adapter improves outputs substantially.
         | 
| 119 | 
            +
            pipe.set_ip_adapter_scale(.6)
         | 
| 120 | 
             
            pipe.unet.fuse_qkv_projections()
         | 
| 121 | 
             
            #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
         | 
| 122 |  | 
|  | |
| 124 | 
             
            #pipe.unet = torch.compile(pipe.unet)
         | 
| 125 | 
             
            #pipe.vae = torch.compile(pipe.vae)
         | 
| 126 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 127 |  | 
| 128 | 
            +
            #############################################################
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
         | 
| 133 | 
            +
            pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
         | 
| 134 | 
            +
            processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
         | 
| 135 |  | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
         | 
| 139 | 
            +
                inputs_embeds = pali.get_input_embeddings()(input_ids)
         | 
| 140 | 
            +
                selected_image_feature = image_outputs.to(dtype).to(device)
         | 
| 141 | 
            +
                image_features = pali.multi_modal_projector(selected_image_feature)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                if cache_position is None:
         | 
| 144 | 
            +
                    cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
         | 
| 145 | 
            +
                inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
         | 
| 146 | 
            +
                    image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
         | 
| 147 | 
            +
                )
         | 
| 148 | 
            +
                return inputs_embeds
         | 
| 149 |  | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
            def generate_pali(user_emb):
         | 
| 153 | 
            +
                prompt = 'caption en'
         | 
| 154 | 
            +
                model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
         | 
| 155 | 
            +
                # we need to get im_embs taken in here.
         | 
| 156 | 
            +
                input_len = model_inputs["input_ids"].shape[-1]
         | 
| 157 | 
            +
                input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1), 
         | 
| 158 | 
            +
                                    model_inputs["input_ids"].to(device), 
         | 
| 159 | 
            +
                                    model_inputs["attention_mask"].to(device))
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
         | 
| 162 | 
            +
                decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
         | 
| 163 | 
            +
                return decoded
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
             | 
| 168 | 
            +
            #############################################################
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            @spaces.GPU()
            def generate_gpu(in_im_embs, prompt='the scene'):
                """Run the video pipeline on the GPU and embed its middle frame twice.

                Args:
                    in_im_embs: tensor handed to the pipeline as the IP-adapter image
                        embedding; it is moved to CUDA and given two leading singleton
                        dimensions before use.
                    prompt: text prompt forwarded to ``pipe``.

                Returns:
                    A tuple ``(output, im_emb, gemb)``:
                    ``output`` is the raw result of ``pipe``; ``im_emb`` is the first
                    value returned by ``pipe.encode_image`` for the middle frame, on CPU
                    as float32; ``gemb`` is ``pali.vision_tower``'s ``last_hidden_state``
                    for the same frame, averaged over dim 1, on CPU as float32.
                """
                with torch.no_grad():
                    cond_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
                    output = pipe(
                        prompt=prompt,
                        guidance_scale=1,
                        added_cond_kwargs={},
                        ip_adapter_image_embeds=[cond_embs],
                        num_inference_steps=STEPS,
                    )

                    # Both embeddings are taken from the frame in the middle of the clip.
                    clip = output.frames[0]
                    middle_frame = clip[len(clip) // 2]

                    im_emb, _ = pipe.encode_image(middle_frame, 'cuda', 1, output_hidden_state)
                    im_emb = im_emb.detach().to('cpu').to(torch.float32)

                    # Resize to 224x224 and rescale for the vision tower
                    # (assumes ToTensor yields values in [0, 1], giving [-1, 1] here -- TODO confirm).
                    pixels = torchvision.transforms.ToTensor()(middle_frame).unsqueeze(0)
                    pixels = torch.nn.functional.interpolate(pixels, (224, 224))
                    pixels = (pixels - .5) * 2
                    tower_out = pali.vision_tower(pixels.to(device).to(dtype))
                    gemb = tower_out.last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
                return output, im_emb, gemb
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def generate(in_im_embs, prompt='the scene'):
         | 
| 189 | 
            +
                output, im_emb, gemb = generate_gpu(in_im_embs, prompt)
         | 
| 190 | 
            +
                nsfw =maybe_nsfw(output.frames[0][len(output.frames[0])//2])
         | 
| 191 | 
            +
                print(prompt)
         | 
| 192 | 
             
                name = str(uuid.uuid4()).replace("-", "")
         | 
| 193 | 
             
                path = f"/tmp/{name}.mp4"
         | 
| 194 |  | 
|  | |
| 201 | 
             
                output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
         | 
| 202 |  | 
| 203 | 
             
                write_video(path, output.frames[0])
         | 
| 204 | 
            +
                return path, im_emb, gemb
         | 
| 205 |  | 
| 206 |  | 
| 207 | 
             
            #######################
         | 
| 208 |  | 
| 209 | 
             
            def get_user_emb(embs, ys):
         | 
| 210 | 
             
                # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
         | 
| 211 | 
            +
                
         | 
| 212 | 
             
                if len(list(ys)) <= 7:
         | 
| 213 | 
            +
                    aways = [.01*torch.randn_like(embs[0]) for i in range(3)]
         | 
| 214 | 
             
                    embs += aways
         | 
| 215 | 
             
                    awal = [0 for i in range(3)]
         | 
| 216 | 
             
                    ys += awal
         | 
|  | |
| 217 |  | 
| 218 | 
             
                indices = list(range(len(embs)))
         | 
| 219 | 
             
                # sample only as many negatives as there are positives
         | 
|  | |
| 228 | 
             
                # this ends up adding a rating but losing an embedding, it seems.
         | 
| 229 | 
             
                # let's take off a rating if so to continue without indexing errors.
         | 
| 230 | 
             
                if len(ys) > len(embs):
         | 
| 231 | 
            +
                    print('ys are longer than embs; popping latest rating')
         | 
| 232 | 
             
                    ys.pop(-1)
         | 
| 233 |  | 
| 234 | 
             
                feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
         | 
| 235 | 
             
                #scaler = preprocessing.StandardScaler().fit(feature_embs)
         | 
| 236 | 
             
                #feature_embs = scaler.transform(feature_embs)
         | 
| 237 | 
            +
                chosen_y = np.array([ys[i] for i in indices])
         | 
| 238 |  | 
| 239 | 
             
                if feature_embs.norm() != 0:
         | 
| 240 | 
             
                    feature_embs = feature_embs / feature_embs.norm()
         | 
| 241 |  | 
|  | |
|  | |
| 242 | 
             
                #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
         | 
| 243 | 
            +
                lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs.squeeze(), chosen_y)
         | 
| 244 | 
            +
                coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
         | 
| 245 | 
             
                coef_ = coef_ / coef_.abs().max() * 3
         | 
| 246 |  | 
| 247 | 
             
                w = 1# if len(embs) % 2 == 0 else 0
         | 
|  | |
| 263 | 
             
                        best_sim = sim
         | 
| 264 | 
             
                        best_row = i[1]
         | 
| 265 | 
             
                img = best_row['paths']
         | 
| 266 | 
            +
                text = best_row.get('text', '')
         | 
| 267 | 
            +
                return img, text
         | 
| 268 |  | 
| 269 |  | 
| 270 | 
             
            def background_next_image():
         | 
|  | |
| 288 | 
             
                        unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
         | 
| 289 | 
             
                        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
         | 
| 290 |  | 
| 291 | 
            +
                        # we pop previous ratings if there are > n
         | 
| 292 | 
            +
                        if len(rated_from_user) >= 15:
         | 
| 293 | 
             
                            oldest = rated_from_user.iloc[0]['paths']
         | 
| 294 | 
             
                            prevs_df = prevs_df[prevs_df['paths'] != oldest]
         | 
| 295 | 
            +
                        # we don't compute more after n are in the queue for them
         | 
| 296 | 
             
                        if len(unrated_from_user) >= 10:
         | 
| 297 | 
             
                            continue
         | 
| 298 |  | 
| 299 | 
            +
                        if len(rated_rows) < 5:
         | 
| 300 | 
             
                            continue
         | 
| 301 |  | 
| 302 | 
            +
                        embs, ys, gembs = pluck_embs_ys(uid)
         | 
| 303 |  | 
| 304 | 
             
                        user_emb = get_user_emb(embs, ys)
         | 
| 305 | 
            +
                                    
         | 
| 306 | 
            +
                        if len(gembs) > 4:
         | 
| 307 | 
            +
                            user_gem = get_user_emb(gembs, ys) / 4 # TODO scale this correctly; matplotlib, etc.
         | 
| 308 | 
            +
                            text = generate_pali(user_gem)
         | 
| 309 | 
            +
                        else:
         | 
| 310 | 
            +
                            text = generate_pali(torch.zeros(1, 1152))
         | 
| 311 | 
            +
                        img, embs, new_gem = generate(user_emb, text)
         | 
| 312 | 
            +
                        
         | 
| 313 | 
             
                        if img:
         | 
| 314 | 
            +
                            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
         | 
| 315 | 
             
                            tmp_df['paths'] = [img]
         | 
| 316 | 
             
                            tmp_df['embeddings'] = [embs]
         | 
| 317 | 
             
                            tmp_df['user:rating'] = [{' ': ' '}]
         | 
| 318 | 
             
                            tmp_df['from_user_id'] = [uid]
         | 
| 319 | 
            +
                            tmp_df['text'] = [text]
         | 
| 320 | 
            +
                            tmp_df['gemb'] = [new_gem]
         | 
| 321 | 
             
                            prevs_df = pd.concat((prevs_df, tmp_df))
         | 
|  | |
| 322 | 
             
                            # we can free up storage by deleting the image
         | 
| 323 | 
            +
                            if len(prevs_df) > 500:
         | 
| 324 | 
            +
                                oldest_path = prevs_df.iloc[6]['paths']
         | 
| 325 | 
            +
                                if os.path.isfile(oldest_path):
         | 
| 326 | 
            +
                                    os.remove(oldest_path)
         | 
| 327 | 
            +
                                else:
         | 
| 328 | 
            +
                                    # If it fails, inform the user.
         | 
| 329 | 
            +
                                    print("Error: %s file not found" % oldest_path)
         | 
| 330 | 
            +
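                                # cap the table at 500 rows (images, embeddings & ips): drop the row at index 6
                                # NOTE(review): index 6 presumably assumed six calibration videos; the calibration list now looks longer -- confirm it is not a calibration row.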
         | 
| 331 | 
            +
                                prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
         | 
| 332 | 
            +
                
         | 
| 333 |  | 
| 334 | 
             
            def pluck_embs_ys(user_id):
         | 
| 335 | 
             
                rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
         | 
|  | |
| 342 |  | 
| 343 | 
             
                embs = rated_rows['embeddings'].to_list()
         | 
| 344 | 
             
                ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
         | 
| 345 | 
            +
                gembs = rated_rows['gemb'].to_list()
         | 
| 346 | 
            +
                return embs, ys, gembs
         | 
| 347 |  | 
| 348 | 
             
            def next_image(calibrate_prompts, user_id):
                """Pick the next video to show to ``user_id``.

                While ``calibrate_prompts`` still has entries, the first one is popped
                (mutating the list in place) and its path is looked up in ``prevs_df``;
                no text accompanies calibration videos. Once the list is empty, a user
                embedding is fitted from the user's ratings and ``pluck_img`` selects
                the video and its text.

                Returns:
                    ``(image, calibrate_prompts, text)`` -- the chosen path, the
                    (possibly shortened) calibration list, and the text for the video
                    (``''`` during calibration).
                """
                with torch.no_grad():
                    if len(calibrate_prompts) == 0:
                        # Calibration finished: personalise from this user's ratings.
                        embs, ys, gembs = pluck_embs_ys(user_id)
                        user_emb = get_user_emb(embs, ys)
                        image, text = pluck_img(user_id, user_emb)
                        return image, calibrate_prompts, text

                    # Still calibrating: serve the next fixed video in order.
                    cal_video = calibrate_prompts.pop(0)
                    matching_paths = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()
                    image = matching_paths[0]
                    return image, calibrate_prompts, ''
         | 
| 360 |  | 
| 361 |  | 
| 362 |  | 
|  | |
| 368 |  | 
| 369 | 
             
            def start(_, calibrate_prompts, user_id, request: gr.Request):
         | 
| 370 | 
             
                user_id = int(str(time.time())[-7:].replace('.', ''))
         | 
| 371 | 
            +
                image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
         | 
| 372 | 
             
                return [
         | 
| 373 | 
             
                        gr.Button(value='Like (L)', interactive=True), 
         | 
| 374 | 
             
                        gr.Button(value='Neither (Space)', interactive=True, visible=False), 
         | 
|  | |
| 387 | 
             
                if choice == 'Like (L)':
         | 
| 388 | 
             
                    choice = 1
         | 
| 389 | 
             
                elif choice == 'Neither (Space)':
         | 
| 390 | 
            +
                    img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
         | 
| 391 | 
            +
                    return img, calibrate_prompts, text
         | 
| 392 | 
             
                else:
         | 
| 393 | 
             
                    choice = 0
         | 
| 394 |  | 
| 395 | 
             
                # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
         | 
| 396 | 
             
                # TODO skip allowing rating & just continue
         | 
| 397 | 
             
                if img == None:
         | 
| 398 | 
            +
                    print('NSFW -- choice is disliked')
         | 
| 399 | 
             
                    choice = 0
         | 
| 400 |  | 
| 401 | 
             
                row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
         | 
|  | |
| 403 | 
             
                if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
         | 
| 404 | 
             
                    prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
         | 
| 405 | 
             
                    prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
         | 
| 406 | 
            +
                img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
         | 
| 407 | 
            +
                return img, calibrate_prompts, text
         | 
| 408 |  | 
| 409 | 
             
            css = '''.gradio-container{max-width: 700px !important}
         | 
| 410 | 
             
            #description{text-align: center}
         | 
|  | |
| 488 | 
             
                    elem_id="video_output"
         | 
| 489 | 
             
                   )
         | 
| 490 | 
             
                    img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
         | 
| 491 | 
            +
                with gr.Row():
         | 
| 492 | 
            +
                    text = gr.Textbox(interactive=False, visible=True, label='Text')
         | 
| 493 | 
             
                with gr.Row(equal_height=True):
         | 
| 494 | 
             
                    b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
         | 
| 495 | 
             
                    b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
         | 
|  | |
| 497 | 
             
                    b1.click(
         | 
| 498 | 
             
                    choose, 
         | 
| 499 | 
             
                    [img, b1, calibrate_prompts, user_id],
         | 
| 500 | 
            +
                    [img, calibrate_prompts, text],
         | 
| 501 | 
             
                    )
         | 
| 502 | 
             
                    b2.click(
         | 
| 503 | 
             
                    choose, 
         | 
| 504 | 
             
                    [img, b2, calibrate_prompts, user_id],
         | 
| 505 | 
            +
                    [img, calibrate_prompts, text],
         | 
| 506 | 
             
                    )
         | 
| 507 | 
             
                    b3.click(
         | 
| 508 | 
             
                    choose, 
         | 
| 509 | 
             
                    [img, b3, calibrate_prompts, user_id],
         | 
| 510 | 
            +
                    [img, calibrate_prompts, text],
         | 
| 511 | 
             
                    )
         | 
| 512 | 
             
                with gr.Row():
         | 
| 513 | 
             
                    b4 = gr.Button(value='Start')
         | 
|  | |
| 528 | 
             
            log.setLevel(logging.ERROR)
         | 
| 529 |  | 
| 530 | 
             
            scheduler = BackgroundScheduler()
         | 
| 531 | 
            +
            scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
         | 
| 532 | 
             
            scheduler.start()
         | 
| 533 |  | 
| 534 | 
             
            #thread = threading.Thread(target=background_next_image,)
         | 
| 535 | 
             
            #thread.start()
         | 
| 536 |  | 
| 537 | 
            +
            # TODO shouldn't call this before gradio launch, yeah?
         | 
| 538 | 
             
            @spaces.GPU()
            def encode_space(x):
                """Embed a single frame with both the pipeline image encoder and the vision tower.

                Args:
                    x: the frame to embed; it is passed to ``pipe.encode_image`` and to
                        ``torchvision.transforms.ToTensor`` (presumably an HxWxC array
                        or PIL image -- confirm against callers).

                Returns:
                    A tuple ``(im_emb, gemb)``: ``im_emb`` is the first value returned
                    by ``pipe.encode_image`` on CPU as float32; ``gemb`` is
                    ``pali.vision_tower``'s ``last_hidden_state`` averaged over dim 1,
                    on CPU as float32.
                """
                # Encode the argument itself. The previous version read the module-level
                # ``image`` here, so the two embeddings could describe different frames
                # whenever the caller's variable was not named ``image``.
                im_emb, _ = pipe.encode_image(
                    x, DEVICE, 1, output_hidden_state
                )

                # Resize to 224x224 and rescale for the vision tower
                # (assumes ToTensor yields values in [0, 1], giving [-1, 1] here -- TODO confirm).
                im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
                im = torch.nn.functional.interpolate(im, (224, 224))
                im = (im - .5) * 2
                gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)

                return im_emb.detach().to('cpu').to(torch.float32), gemb
         | 
| 551 |  | 
| 552 | 
            +
            # prep our calibration videos
         | 
| 553 | 
             
            for im in [
         | 
| 554 | 
             
                './first.mp4',
         | 
| 555 | 
             
                './second.mp4',
         | 
|  | |
| 557 | 
             
                './fourth.mp4',
         | 
| 558 | 
             
                './fifth.mp4',
         | 
| 559 | 
             
                './sixth.mp4',
         | 
| 560 | 
            +
                './seventh.mp4',
         | 
| 561 | 
            +
                './eigth.mp4',
         | 
| 562 | 
            +
                './ninth.mp4',
         | 
| 563 | 
            +
                './tenth.mp4',
         | 
| 564 | 
             
                ]:
         | 
| 565 | 
            +
                tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
         | 
| 566 | 
             
                tmp_df['paths'] = [im]
         | 
| 567 | 
             
                image = list(imageio.imiter(im))
         | 
| 568 | 
             
                image = image[len(image)//2]
         | 
| 569 | 
            +
                im_emb, gemb = encode_space(image)
         | 
| 570 |  | 
| 571 | 
             
                tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
         | 
| 572 | 
            +
                tmp_df['gemb'] = [gemb.detach().to('cpu')]
         | 
| 573 | 
             
                tmp_df['user:rating'] = [{' ': ' '}]
         | 
| 574 | 
             
                prevs_df = pd.concat((prevs_df, tmp_df))
         | 
| 575 |  | 
| 576 |  | 
| 577 | 
            +
            demo.launch(share=True, server_port=8443)
         | 
| 578 | 
            +
             | 
| 579 | 
            +
             | 
    	
        eigth.mp4
    ADDED
    
    | Binary file (47.7 kB). View file | 
|  | 
    	
        ninth.mp4
    ADDED
    
    | Binary file (255 kB). View file | 
|  | 
    	
        seventh.mp4
    ADDED
    
    | Binary file (50 kB). View file | 
|  | 
    	
        tenth.mp4
    ADDED
    
    | Binary file (129 kB). View file | 
|  |