Spaces:
Runtime error
Runtime error
| # TODO unify/merge origin and this | |
| # TODO save & restart from (if it exists) dataframe parquet | |
| import torch | |
| # lol | |
# Device string passed to the model loaders, pipe.to() and encode_space().
DEVICE = 'cuda'
# Number of denoising steps passed to the pipeline as num_inference_steps.
STEPS = 8
# Forwarded positionally to pipe.encode_image().
output_hidden_state = False
# Device/dtype pair used when moving the solved user embedding in get_user_emb().
# NOTE(review): `device` duplicates DEVICE; both are currently "cuda".
device = "cuda"
dtype = torch.bfloat16
| import spaces | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import logging | |
| import os | |
| import imageio | |
| import gradio as gr | |
| import numpy as np | |
| from sklearn.svm import LinearSVC | |
| import pandas as pd | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| import sched | |
| import threading | |
| import random | |
| import time | |
| from PIL import Image | |
| from safety_checker_improved import maybe_nsfw | |
# Inference only: turn autograd off globally and allow TF32 matmul/cuDNN kernels.
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Shared in-memory table with one row per image, as used in this file:
#   paths               -- image file path
#   embeddings          -- image embedding tensor returned by pipe.encode_image
#   user:rating         -- dict mapping user id -> two-element rating list;
#                          {' ': ' '} is the "nobody rated yet" placeholder
#   latest_user_to_rate -- user id written by choose()
#   from_user_id        -- user id whose embedding the image was generated from
#   text                -- prompt text
#   ips, gemb           -- never written in this file
prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
| import spaces | |
| start_time = time.time() | |
# De-duplicated prompts from the second CSV column; non-string cells
# (e.g. NaN for empty rows) are filtered out.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if type(p) is str
]
| ####################### Setup Model | |
| from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image | |
| from transformers import CLIPTextModel | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from PIL import Image | |
| from transformers import CLIPVisionModelWithProjection | |
| import uuid | |
| import av | |
def write_video(file_name, images, fps=16):
    """Encode ``images`` into an H.264 video file with PyAV.

    Args:
        file_name: Output path handed to ``av.open``.
        images: Iterable of frames; each is converted with ``np.array`` and
            rounded/cast to uint8 before being wrapped as an ``rgb24`` frame.
        fps: Frame rate of the output stream.

    The stream is fixed at 512x512, ``yuv420p``, single-threaded encoding.
    """
    container = av.open(file_name, mode="w")
    # try/finally so the output file is closed even if encoding a frame fails.
    try:
        stream = container.add_stream("h264", rate=fps)
        # stream.options = {'preset': 'faster'}
        stream.thread_count = 1
        stream.width = 512
        stream.height = 512
        stream.pix_fmt = "yuv420p"

        for img in images:
            pixels = np.round(np.array(img)).astype(np.uint8)
            frame = av.VideoFrame.from_ndarray(pixels, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush any frames still buffered in the encoder.
        for packet in stream.encode():
            container.mux(packet)
    finally:
        container.close()
def imio_write_video(file_name, images, fps=15):
    """Write ``images`` to ``file_name`` as a video using imageio.

    Args:
        file_name: Output path handed to ``imageio.get_writer``.
        images: Iterable of frames; each is converted with ``np.array``.
        fps: Frame rate passed to the writer.
    """
    writer = imageio.get_writer(file_name, fps=fps)
    # try/finally so the writer is closed even if appending a frame fails.
    try:
        for im in images:
            writer.append_data(np.array(im))
    finally:
        writer.close()
#vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
# vae = compile_unet(vae, config=config)
#finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
#rynmurdock/Sea_Claws

# Base SDXL pipeline with its UNet swapped for the 8-step SDXL-Lightning
# checkpoint (STEPS must stay in sync with the checkpoint's step count).
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_8step_unet.safetensors"
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map=DEVICE).to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))

# Only one image encoder is loaded. A second copy from
# "sdxl_models/image_encoder" would be rebound before ever being used, so it
# would only cost download time and GPU memory.
# NOTE(review): presumably "models/image_encoder" is the encoder matching the
# "vit-h" IP-Adapter weights loaded below -- confirm against the h94/IP-Adapter repo.
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=DEVICE)
pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)

# NOTE(review): the IP-Adapter weights are applied twice -- once through the
# private UNet hook and once through the public load_ip_adapter(). The private
# call looks redundant; kept until that is verified against the installed
# diffusers version.
pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
pipe.register_modules(image_encoder = image_encoder)
pipe.set_ip_adapter_scale(0.8)
#pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)

# Trailing timestep spacing for the Euler scheduler used with the Lightning UNet.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Everything was loaded in float16; the whole pipeline is then cast to `dtype`.
pipe.to(device=DEVICE).to(dtype=dtype)
# pipe.unet.fuse_qkv_projections()
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
#pipe.unet = torch.compile(pipe.unet)
#pipe.vae = torch.compile(pipe.vae)
def generate_gpu(in_im_embs, prompt='the scene'):
    """Run the diffusion pipeline conditioned on an image embedding.

    Args:
        in_im_embs: Embedding tensor; moved to CUDA and given a leading
            dimension before being passed as ``ip_adapter_image_embeds``.
        prompt: Text prompt for the pipeline.

    Returns:
        ``(output, im_emb)``: the raw pipeline output and the float32 CPU
        embedding of the first generated image.
    """
    with torch.no_grad():
        print(prompt)
        batched_embs = in_im_embs.to('cuda').unsqueeze(0)
        output = pipe(
            prompt=prompt,
            guidance_scale=1,
            added_cond_kwargs={},
            ip_adapter_image_embeds=[batched_embs],
            num_inference_steps=STEPS,
        )
        # Re-encode the generated image so it can be stored and compared later.
        first_image = output.images[0]
        im_emb, _ = pipe.encode_image(first_image, 'cuda', 1, output_hidden_state)
        im_emb = im_emb.detach().to('cpu').to(torch.float32)
    return output, im_emb
def generate(in_im_embs, prompt='the scene'):
    """Generate one image, screen it, and save it under ``/tmp``.

    Returns:
        ``(path, im_emb)`` where ``path`` is the saved PNG, or ``None`` when
        ``maybe_nsfw`` flags the image (nothing is written in that case).
    """
    output, im_emb = generate_gpu(in_im_embs, prompt)
    picture = output.images[0]
    flagged = maybe_nsfw(picture)

    # uuid4().hex is the UUID string with the dashes removed.
    path = f"/tmp/{uuid.uuid4().hex}.png"

    if flagged:
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
        return None, im_emb

    picture.save(path)
    return path, im_emb
| ####################### | |
def solver(embs, ys):
    """Fit a linear SVM on rated embeddings and return its weight vector.

    Args:
        embs: 2-D array-like of embeddings, one row per rating (converted
            with ``np.array``).
        ys: Ratings, one per row of ``embs``; a list or tensor. Values are
            mapped with ``y * 2 - 1``, so 0/1 ratings become -1/+1 labels.

    Returns:
        float32 CPU tensor holding ``LinearSVC.coef_``.
    """
    print('ys:', ys, 'EMBS:', embs.shape, embs)
    # as_tensor avoids the copy-construct warning when ys is already a tensor;
    # reshape(-1) hands sklearn a flat label vector instead of an (n, 1)
    # column and also works when there is only a single rating.
    labels = torch.as_tensor(ys).to('cpu', dtype=torch.float32).reshape(-1) * 2 - 1
    svm = LinearSVC(class_weight='balanced').fit(np.array(embs), labels.numpy())
    return torch.as_tensor(svm.coef_).to('cpu', dtype=torch.float32)
def get_user_emb(embs, ys):
    """Build the conditioning embedding for a user from their ratings.

    A linear SVM direction (``solver``) is rescaled to the standard deviation
    of one randomly chosen liked embedding and blended with it (1 part liked
    embedding, 3 parts direction).

    Args:
        embs: List of embedding tensors, one per rating.
        ys: List of numeric ratings; values above 0.5 count as positive.

    Returns:
        Tensor on ``device`` with dtype ``dtype``.

    The inputs are copied first, so the caller's lists are never modified.
    """
    embs = list(embs)
    ys = list(ys)

    pos_indices = [i for i, y in enumerate(ys) if y > .5]
    neg_indices = [i for i, y in enumerate(ys) if y <= .5]
    mini = min(len(pos_indices), len(neg_indices))

    if len(ys) > 20:  # drop earliest of whichever of neg or pos is most abundant
        ind = pos_indices[0] if len(pos_indices) > len(neg_indices) else neg_indices[0]
        ys.pop(ind)
        embs.pop(ind)
        print('Dropping at 20')

    if mini < 1:
        # Without at least one like and one dislike the SVM cannot be fit;
        # fall back to two random 1280-d vectors labelled 0 and 1.
        feature_embs = torch.stack([torch.randn(1280), torch.randn(1280)])
        ys_t = [0, 1]
        print('Not enough ratings.')
    else:
        ys_t = list(ys)
        feature_embs = torch.stack([e.detach().cpu() for e in embs]).squeeze()

    print(np.array(feature_embs).shape, np.array(ys_t).shape)

    sol = solver(feature_embs.squeeze(), ys_t)
    # solver already returns a tensor; convert in place instead of re-wrapping.
    dif = torch.as_tensor(sol).to(device=device, dtype=dtype)

    # could j have a base vector of a black image
    positives = [feature_embs[i] for i, y in enumerate(ys_t) if y > .5]
    latest_pos = random.choice(positives).to(device, dtype)

    dif = (dif / dif.std()) * latest_pos.std()
    return (1 * latest_pos + 3 * dif) / 4
def pluck_img(user_id, user_emb):
    """Return the path of the unrated image most similar to ``user_emb``.

    Blocks, polling ``prevs_df`` every 0.1 s, until at least one row exists
    that ``user_id`` has not rated.
    """
    def _unrated_rows():
        mask = [
            row['user:rating'].get(user_id, 'gone') == 'gone'
            for _, row in prevs_df.iterrows()
        ]
        return prevs_df[mask]

    candidates = _unrated_rows()
    while len(candidates) == 0:
        candidates = _unrated_rows()
        time.sleep(.1)

    # TODO optimize this lol
    target = user_emb.detach().to('cpu')
    best_sim = -100000
    for _, row in candidates.iterrows():
        # TODO sloppy .to but it is 3am.
        sim = torch.cosine_similarity(row['embeddings'].detach().to('cpu'), target)
        if sim > best_sim:
            best_sim = sim
            best_row = row
    return best_row['paths']
def background_next_image():
    """Scheduler job: generate at most one new image per active user.

    For every user id found in ``latest_user_to_rate`` this trims that user's
    oldest rated image, skips users who already have enough unrated images
    queued or too few ratings, then generates one image from the user's
    embedding and appends it to the global ``prevs_df``.
    """
    global prevs_df
    # only let it get N (maybe 3) ahead of the user
    #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
    # Rows whose rating dict is no longer the untouched placeholder.
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
    if len(rated_rows) < 4:
        time.sleep(.1)
        # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
        return
    user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
    for uid in user_id_list:
        rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]
        # we need to intersect not_rated_rows from this user's embed > 7. Just add a new column on which user_id spawned the
        # media.
        unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
        # we pop previous ratings if there are > n
        # NOTE(review): the row is dropped from prevs_df here, but its image file is not deleted.
        if len(rated_from_user) >= 15:
            oldest = rated_from_user.iloc[0]['paths']
            prevs_df = prevs_df[prevs_df['paths'] != oldest]
        # we don't compute more after n are in the queue for them
        if len(unrated_from_user) >= 10:
            continue
        if len(rated_rows) < 5:
            continue
        embs, ys = pluck_embs_ys(uid)
        # Only the second element of each two-element rating is used as the label.
        user_emb = get_user_emb(embs, [y[1] for y in ys])
        # Cycle through prompt_list; only every 7th index uses a real prompt,
        # all other generations use the generic 'an image'.
        global glob_idx
        glob_idx += 1
        if glob_idx >= (len(prompt_list)-1):
            glob_idx = 0
        if glob_idx % 7 == 0:
            text = prompt_list[glob_idx]
        else:
            text = 'an image'
        img, embs = generate(user_emb, text)
        # generate() returns None as the path when the image was flagged as NSFW.
        if img:
            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
            tmp_df['paths'] = [img]
            tmp_df['embeddings'] = [embs]
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
            tmp_df['text'] = [text]
            prevs_df = pd.concat((prevs_df, tmp_df))
            # we can free up storage by deleting the image
            # NOTE(review): nesting of this cleanup under `if img:` is reconstructed
            # from flattened source -- confirm against the original file.
            if len(prevs_df) > 500:
                # NOTE(review): position 6 is evicted, so the first six rows are always
                # kept -- presumably to protect the calibration rows; only five are seeded.
                oldest_path = prevs_df.iloc[6]['paths']
                if os.path.isfile(oldest_path):
                    os.remove(oldest_path)
                else:
                    # If it fails, inform the user.
                    print("Error: %s file not found" % oldest_path)
                # drop the evicted row from the table (the threshold above is 500 rows)
                prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
def pluck_embs_ys(user_id):
    """Collect the embeddings and ratings of every row ``user_id`` has rated.

    Returns:
        ``(embs, ys)``: parallel lists of the ``embeddings`` values and of the
        user's entries from each row's ``user:rating`` dict.
    """
    has_rating = [
        row['user:rating'].get(user_id, None) is not None
        for _, row in prevs_df.iterrows()
    ]
    rated = prevs_df[has_rating]

    embs = rated['embeddings'].to_list()
    ys = [ratings[user_id] for ratings in rated['user:rating'].to_list()]
    return embs, ys
def next_image(calibrate_prompts, user_id):
    """Pick the next image to show and return it with the remaining queue.

    While ``calibrate_prompts`` is non-empty its first entry is popped and
    looked up in ``prevs_df``; afterwards the image is chosen from the user's
    ratings via ``get_user_emb`` and ``pluck_img``.

    Returns:
        ``(image_path, calibrate_prompts)``.
    """
    with torch.no_grad():
        if len(calibrate_prompts) == 0:
            embs, ys = pluck_embs_ys(user_id)
            # Only the second element of each rating is used as the label.
            labels = [rating[1] for rating in ys]
            user_emb = get_user_emb(embs, labels)
            return pluck_img(user_id, user_emb), calibrate_prompts

        cal_video = calibrate_prompts.pop(0)
        matching = prevs_df[prevs_df['paths'] == cal_video]
        return matching['paths'].to_list()[0], calibrate_prompts
# Button labels. `start` assigns them to the buttons and `choose` receives the
# clicked button's label back as `choice`, so each one must be unique.
# NOTE(review): the source showed the same mis-encoded 'π' for like and
# dislike; the thumbs emoji are the presumed originals, and the emoji on the
# two hidden Content/Style buttons is a guess -- confirm.
LIKE_LABEL = '👍'
DISLIKE_LABEL = '👎'
NEITHER_LABEL = 'Neither (Space)'
LIKE_CONTENT_LABEL = '👍 Content'
LIKE_STYLE_LABEL = '👍 Style'

# Label -> two-element rating stored in the row's 'user:rating' dict.
_CHOICE_RATINGS = {
    LIKE_LABEL: [1, 1],
    DISLIKE_LABEL: [0, 0],
    LIKE_STYLE_LABEL: [0, 1],
    LIKE_CONTENT_LABEL: [1, 0],
}


def start(_, calibrate_prompts, user_id, request: gr.Request):
    """Handle the Start button: mint a user id and show the first image.

    Args:
        _: Value of the Start button (unused).
        calibrate_prompts: Queue of calibration image paths for this session.
        user_id: Ignored; a fresh id is generated here.
        request: Gradio request object (unused).

    Returns:
        Updates for the like, neither, dislike, start, content and style
        buttons, followed by the image path, the remaining calibration queue
        and the new user id.
    """
    # NOTE(review): the id is taken from the trailing digits of time.time(),
    # so two sessions can collide -- consider a random id.
    user_id = int(str(time.time())[-7:].replace('.', ''))
    image, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return [
        gr.Button(value=LIKE_LABEL, interactive=True),
        gr.Button(value=NEITHER_LABEL, interactive=True, visible=False),
        gr.Button(value=DISLIKE_LABEL, interactive=True),
        gr.Button(value='Start', interactive=False),
        gr.Button(value=LIKE_CONTENT_LABEL, interactive=True, visible=False),
        gr.Button(value=LIKE_STYLE_LABEL, interactive=True, visible=False),
        image,
        calibrate_prompts,
        user_id,
    ]


def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    """Record the user's rating for ``img`` and return the next image.

    Args:
        img: File path of the image currently shown (``None`` if nothing is shown).
        choice: Label of the clicked button.
        calibrate_prompts: Remaining calibration image paths.
        user_id: Id of the rating user.
        request: Gradio request object (unused).

    Returns:
        ``(next_image_path, calibrate_prompts)``.

    Raises:
        AssertionError: If ``choice`` is not a known button label.
    """
    global prevs_df
    if choice == NEITHER_LABEL:
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts
    if choice not in _CHOICE_RATINGS:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError(f'choice is {choice}')
    # Copy so the stored rating never aliases the shared lookup table.
    choice = list(_CHOICE_RATINGS[choice])

    if img is None:
        # Nothing is displayed, so there is no row to attach a rating to;
        # just move on instead of failing on the path match below.
        print('No image to rate -- skipping')
    else:
        row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
        matched = prevs_df.loc[row_mask, 'user:rating']
        # if it's still in the dataframe, add the choice
        if len(matched) > 0:
            # .iloc is positional: index labels in prevs_df are not unique, and
            # the rating dict is mutated in place, so no write-back is needed.
            matched.iloc[0][user_id] = choice
            prevs_df.loc[row_mask, 'latest_user_to_rate'] = user_id

    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return img, calibrate_prompts
# Page CSS passed to gr.Blocks: constrains the container width, centres the
# description, and defines the `fade-in-out` animation that the script in
# js_head toggles on the buttons.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
0% {
background: var(--bg-color);
}
100% {
background: var(--button-secondary-background-fill);
}
}
'''
# Script passed to gr.Blocks as `head`: maps the A / Space / L keys to clicks
# on the elements with ids `dislike` / `neither` / `like`, and flashes a
# background colour on those elements when they are clicked.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
if (event.key === 'a' || event.key === 'A') {
// Trigger click on 'dislike' if 'A' is pressed
document.getElementById('dislike').click();
} else if (event.key === ' ' || event.keyCode === 32) {
// Trigger click on 'neither' if Spacebar is pressed
document.getElementById('neither').click();
} else if (event.key === 'l' || event.key === 'L') {
// Trigger click on 'like' if 'L' is pressed
document.getElementById('like').click();
}
});
function fadeInOut(button, color) {
button.style.setProperty('--bg-color', color);
button.classList.remove('fade-in-out');
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
button.classList.add('fade-in-out');
button.addEventListener('animationend', () => {
button.classList.remove('fade-in-out'); // Reset the animation state
}, {once: true});
}
document.body.addEventListener('click', function(event) {
const target = event.target;
if (target.id === 'dislike') {
fadeInOut(target, '#ff1717');
} else if (target.id === 'like') {
fadeInOut(target, '#006500');
} else if (target.id === 'neither') {
fadeInOut(target, '#cccccc');
}
});
</script>
'''
# UI layout and event wiring. All rating buttons start disabled; clicking
# Start runs `start`, which enables them and shows the first image.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Blue Tigers
### Generative Recommenders for Exploration of Video
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    # Per-session state: the user's id and the queue of calibration images.
    user_id = gr.State()
    # calibration videos -- this is a misnomer now :D
    calibrate_prompts = gr.State([
        './first.png',
        './second.png',
        './sixth.png',
        './fifth.png',
        './fourth.png',
    ])

    def l():
        return None

    with gr.Row(elem_id='output-image'):
        img = gr.Image(
            label='Lightning',
            # autoplay=True,
            interactive=False,
            # height=512,
            # width=512,
            #include_audio=False,
            elem_id="video_output",
            type='filepath',
        )
        #img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')

    # The elem_ids `dislike` / `neither` / `like` are what the keyboard
    # shortcuts in js_head look up, so they must not change.
    # NOTE(review): the like/dislike labels were mis-encoded as an identical
    # 'π' in the source; the thumbs emoji are the presumed originals.
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='👎', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
        b1 = gr.Button(value='👍', interactive=False, elem_id="like")
    with gr.Row(equal_height=True):
        b6 = gr.Button(value='👍 Style', interactive=False, elem_id="dislike like", visible=False)
        b5 = gr.Button(value='👍 Content', interactive=False, elem_id="like dislike", visible=False)

    # Every rating button calls `choose` with its own label as the `choice` input.
    for rating_button in (b1, b2, b3, b5, b6):
        rating_button.click(
            choose,
            [img, rating_button, calibrate_prompts, user_id],
            [img, calibrate_prompts],
        )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(
            start,
            [b4, calibrate_prompts, user_id],
            [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id],
        )
    with gr.Row():
        # Closing tags are written `</div>`; `</ div>` is not a valid end tag.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the AnimateLCM model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')
| # TODO quiet logging | |
# Run background_next_image on a background scheduler every 0.2 seconds.
scheduler = BackgroundScheduler()
scheduler.add_job(background_next_image, trigger="interval", seconds=.2)
scheduler.start()
| #thread = threading.Thread(target=background_next_image,) | |
| #thread.start() | |
| # TODO shouldn't call this before gradio launch, yeah? | |
def encode_space(x):
    """Embed image ``x`` with the pipeline's image encoder.

    Args:
        x: Image to encode, in whatever form ``pipe.encode_image`` accepts.

    Returns:
        The image embedding as a detached float32 CPU tensor.
    """
    # Encode the argument itself rather than a module-level `image` global.
    im_emb, _ = pipe.encode_image(
        x, DEVICE, 1, output_hidden_state
    )
    return im_emb.detach().to('cpu').to(torch.float32)
| # prep our calibration videos | |
# Seed prevs_df with the calibration images every new user rates first.
calibration_items = [  # TODO more movement
    ('./first.png', 'describe the scene: a sketch'),
    ('./second.png', 'describe the scene: omens in the suburbs'),
    ('./sixth.png', 'describe the scene: geometric abstract art of a windmill'),
    ('./fifth.png', 'describe the scene: memento mori'),
    ('./fourth.png', 'describe the scene: a green plate with anespresso'),
]
for im, txt in calibration_items:
    # Use the middle frame of whatever imageio yields for the file.
    image = list(imageio.imiter(im))
    image = image[len(image)//2]
    im_emb = encode_space(image)

    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
    tmp_df['paths'] = [im]
    tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
    # Placeholder rating dict meaning "nobody has rated this row yet".
    tmp_df['user:rating'] = [{' ': ' '}]
    tmp_df['text'] = [txt]
    prevs_df = pd.concat((prevs_df, tmp_df))

# Prompt cursor advanced by background_next_image.
glob_idx = 0

demo.launch(share=True)