Spaces:
Runtime error
Runtime error
| from huggingface_hub import HfApi, HfFolder | |
| import os | |
# Authenticate with the Hugging Face Hub using the access token held in the
# HF_SECRET environment variable. The token is saved via HfFolder, which is
# presumably what `from_pretrained(..., use_auth_token=True)` in main() reads.
_hf_token = os.environ.get('HF_SECRET')
if _hf_token is None:
    raise KeyError('HF_SECRET environment variable must be set to a Hugging Face access token')
api = HfApi()
# NOTE(review): `HfApi.set_access_token` was deprecated (in favour of `login` /
# `set_git_credential`) and is absent from newer huggingface_hub releases;
# only call it where it still exists so the app can start on either version.
if hasattr(api, 'set_access_token'):
    api.set_access_token(_hf_token)
folder = HfFolder()
folder.save_token(_hf_token)
import argparse
import math
import time
from threading import Lock
from typing import Any, List, Tuple

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from matplotlib import pyplot as plt
from spacy import displacy

from daam import trace
from daam.utils import set_seed, cached_nlp, auto_autocast
def dependency(text):
    """Render spaCy's dependency-parse visualization of *text* and return the markup."""
    parsed = cached_nlp(text)
    render_options = dict(compact=True, distance=100)
    return displacy.render(parsed, style='dep', options=render_options)
def get_tokenizing_mapping(
    prompt: str, tokenizer: Any, end_of_word: str = '</w>'
) -> Tuple[List[List[int]], List[str]]:
    """Group the sub-word tokens of *prompt* into whole words.

    Args:
        prompt: text to tokenize.
        tokenizer: object exposing ``tokenize(str) -> list[str]``.
        end_of_word: suffix that marks the last sub-token of a word
            (e.g. ``'</w>'`` for CLIP-style BPE vocabularies).

    Returns:
        ``(merge_idxs, words)`` where ``merge_idxs[k]`` lists the token
        positions of word ``k`` and ``words[k]`` is that word with the
        end-of-word marker stripped. Positions start at 1 to account for the
        start token ([CLS]) at index 0. Sub-tokens after the last end-of-word
        marker are not included in the result.
    """
    merge_idxs: List[List[int]] = []
    words: List[str] = []
    pending_idxs: List[int] = []
    pending_word = ''

    # start=1 because of the [CLS] token that precedes the prompt's tokens
    for position, token in enumerate(tokenizer.tokenize(prompt), start=1):
        pending_idxs.append(position)
        pending_word += token
        # The marker is a suffix; a plain substring test followed by a fixed
        # [:-4] slice would mangle any token containing it elsewhere.
        if token.endswith(end_of_word):
            merge_idxs.append(pending_idxs)
            words.append(pending_word[:len(pending_word) - len(end_of_word)])
            pending_idxs = []
            pending_word = ''

    return merge_idxs, words
def get_args():
    """Parse the demo's command-line options.

    ``--model`` accepts a short alias (a key of ``hub_ids`` below); in the
    returned namespace it has already been replaced by the full Hugging Face
    model id.
    """
    hub_ids = {
        'v1': 'runwayml/stable-diffusion-v1-5',
        'v2-base': 'stabilityai/stable-diffusion-2-base',
        'v2-large': 'stabilityai/stable-diffusion-2',
        'v2-1-base': 'stabilityai/stable-diffusion-2-1-base',
        'v2-1-large': 'stabilityai/stable-diffusion-2-1',
    }

    cli = argparse.ArgumentParser()
    cli.add_argument('--model', '-m', type=str, default='v2-1-base',
                     choices=list(hub_ids), help="which diffusion model to use")
    cli.add_argument('--seed', '-s', type=int, default=0, help="the random seed")
    # NOTE(review): main() never passes this value to demo.launch(), so it
    # currently has no effect on the port actually used.
    cli.add_argument('--port', '-p', type=int, default=8080, help="the port to launch the demo")
    cli.add_argument('--no-cuda', action='store_true', help="Use CPUs instead of GPUs")

    options = cli.parse_args()
    options.model = hub_ids[options.model]
    return options
def _plot_images(image, image2=None, title=None, title2=None):
    """Return a figure with *image* alone, or *image* and *image2* side by side with titles."""
    if image2 is None:
        fig, ax = plt.subplots()
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig

    fig, axes = plt.subplots(1, 2)
    for ax, img, ttl in zip(axes, (image, image2), (title, title2)):
        ax.imshow(img)
        ax.set_title(ttl)
        ax.set_xticks([])
        ax.set_yticks([])
    return fig


def _plot_heat_maps(heat_map, image, num_cells=4):
    """Overlay every parsed word heat map of *heat_map* on *image*; return the figure.

    The grid has *num_cells* columns and just enough rows for the maps that
    ``parsed_heat_maps()`` actually yields (at least one row), so it can
    neither run out of panels nor leave unused ones behind.
    """
    # Overlays are computed under a float32 autocast region (generation in
    # main() runs under a float16 autocast).
    with torch.cuda.amp.autocast(dtype=torch.float32):
        parsed_maps = list(heat_map.parsed_heat_maps())
        num_rows = max(1, math.ceil(len(parsed_maps) / num_cells))
        # Height scales with the number of rows, as width does with columns.
        figsize = (int(num_cells * 3.5), math.ceil(num_rows * 4.5))
        fig, axs = plt.subplots(num_rows, num_cells, figsize=figsize, squeeze=False)
        axs = axs.flatten()

        for ax, parsed_map in zip(axs, parsed_maps):
            ax.set_xticks([])
            ax.set_yticks([])
            parsed_map.word_heat_map.plot_overlay(image, ax=ax)
            ax.set_title(parsed_map.word_heat_map.word, fontsize=12)

    # Drop the unused trailing panels of the last row.
    for ax in axs[len(parsed_maps):]:
        fig.delaxes(ax)

    return fig


def main():
    """Build and launch the DAAM Gradio demo.

    Parses CLI options with ``get_args``, loads the selected Stable Diffusion
    pipeline onto CUDA (or CPU with ``--no-cuda``), builds the UI and calls
    ``demo.launch()``. An ``OSError`` from ``launch`` triggers
    ``gr.close_all()`` and another attempt; a normal return or a
    ``KeyboardInterrupt`` ends the loop.
    """
    args = get_args()
    plt.switch_backend('agg')  # non-interactive backend; figures are handed to gr.Plot
    device = "cpu" if args.no_cuda else "cuda"
    pipe = StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True).to(device)
    # Serializes generation and plotting in plot(); presumably guards the shared
    # pipeline/trace hooks and pyplot's global state against concurrent requests.
    lock = Lock()

    def update_dropdown(prompt):
        """Return new adjective choices for *prompt* and its dependency-parse markup."""
        tokens = [''] + [x.text for x in cached_nlp(prompt) if x.pos_ == 'ADJ']
        # gr.update works on Gradio 3.x and 4.x; the per-component
        # `gr.Dropdown.update` helper no longer exists in Gradio 4.
        return gr.update(choices=tokens), dependency(prompt)

    def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
        """Generate image(s) for *prompt*; return (image figure, heat-map figure).

        When *choice* is non-empty, every spaCy token equal to *choice* is
        replaced by *replaced_word* ('.' if that is empty). A second image is
        generated from the edited prompt with the same seed: the first trace
        is opened with ``save_heads`` and the second with ``load_heads=True``.
        """
        new_prompt = prompt.replace(',', ', ').replace('.', '. ')

        if choice:
            if not replaced_word:
                replaced_word = '.'
            new_prompt = ' '.join(
                replaced_word if tok.text == choice else tok.text for tok in cached_nlp(prompt)
            )

        with auto_autocast(dtype=torch.float16), lock:
            try:
                plt.close('all')
                plt.clf()
            except Exception:
                # Best-effort cleanup of figures left over from the previous request.
                pass

            seed = int(time.time()) if is_random_seed else args.seed
            gen = set_seed(seed)
            prompt = prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later

            if choice:
                new_prompt = new_prompt.replace(',', ', ').replace('.', '. ')  # hacky fix to address later

                with trace(pipe, save_heads=new_prompt != prompt) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()

                if new_prompt != prompt:
                    gen = set_seed(seed)  # re-seed with the same seed as the first run
                    with trace(pipe, load_heads=True) as tc:
                        out2 = pipe(new_prompt, num_inference_steps=inf_steps, generator=gen)
                        image2 = np.array(out2.images[0]) / 255
            else:
                with trace(pipe, load_heads=False, save_heads=False) as tc:
                    out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
                    image = np.array(out.images[0]) / 255
                    heat_map = tc.compute_global_heat_map()

            # The main image(s); prompts can only differ when `choice` is set.
            if new_prompt == prompt:
                fig = _plot_images(image)
            else:
                fig = _plot_images(image, image2, choice, replaced_word)

            # The per-word heat maps, overlaid on the first generated image.
            fig_soft = _plot_heat_maps(heat_map, out.images[0])

        return fig, fig_soft

    with gr.Blocks(css='scrollbar.css') as demo:
        md = (
            '# DAAM: Attention Maps for Interpreting Stable Diffusion\n'
            'Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention]'
            '(http://arxiv.org/abs/2210.04885).\n'
            'See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.\n'
        )
        gr.Markdown(md)

        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown([
                    'An angry, bald man doing research',
                    'A bear and a moose',
                    'A blue car driving through the city',
                    'Monkey walking with hat',
                    'Doing research at Comcast Applied AI labs',
                    'Professor Jimmy Lin from the modern University of Waterloo',
                    'Yann Lecun teaching machine learning on a green chalkboard',
                    'A brown cat eating yummy cake for her birthday',
                    'A brown fox, a white dog, and a blue wolf in a green field',
                ], label='Examples', value='An angry, bald man doing research')

                text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')

                with gr.Row():
                    doc = cached_nlp('An angry, bald man doing research')
                    tokens = [''] + [x.text for x in doc if x.pos_ == 'ADJ']
                    dropdown2 = gr.Dropdown(tokens, label='Adjective to replace', interactive=True)
                    text2 = gr.Textbox(label='New adjective', value='')

                checkbox = gr.Checkbox(value=False, label='Random seed')
                slider1 = gr.Slider(15, 30, value=25, interactive=True, step=1, label='Inference steps')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
                viz = gr.HTML(dependency('An angry, bald man doing research'), elem_id='viz')

            with gr.Column():
                with gr.Tab('Images'):
                    p0 = gr.Plot()
                with gr.Tab('DAAM Maps'):
                    p1 = gr.Plot()

            text.change(fn=update_dropdown, inputs=[text], outputs=[dropdown2, viz])
            submit_btn.click(
                fn=plot,
                inputs=[text, dropdown2, text2, slider1, checkbox],
                outputs=[p0, p1])
            dropdown.change(lambda prompt: prompt, dropdown, text)

    # NOTE(review): args.port (--port) is parsed but not passed to launch(), so
    # Gradio chooses the port itself.
    while True:
        try:
            demo.launch()
        except OSError:
            # Presumably the server socket is still held; close servers and retry
            # after a short pause instead of busy-looping.
            gr.close_all()
            time.sleep(1)
        except KeyboardInterrupt:
            gr.close_all()
            break
        else:
            # launch() returned normally (the server shut down): do not relaunch.
            break
| if __name__ == '__main__': | |
| main() | |