#@title Prepare the Concepts Library to be used
import requests
import os
import gradio as gr
import wget
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from huggingface_hub import HfApi
from transformers import CLIPTextModel, CLIPTokenizer
import html

from share_btn import community_icon_html, loading_icon_html, share_js

# Hub client used to enumerate the community concept repositories.
api = HfApi()
# Every repository owned by "sd-concepts-library", ordered by likes (direction=-1).
models_list = api.list_models(
    author="sd-concepts-library",
    sort="likes",
    direction=-1,
)
# Filled by the library setup loop further down: one dict per concept that loaded.
models = []

# Half-precision Stable Diffusion v1.4 pipeline, moved to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
  """Register a learned textual-inversion embedding under a new tokenizer token.

  The file at `learned_embeds_path` is loaded with `torch.load` and treated as a
  mapping whose first key is the trained placeholder token and whose value is
  the embedding. The token is added to `tokenizer`, the embedding matrix of
  `text_encoder` is resized to the new vocabulary size, and the embedding is
  written into the new token's row.

  Args:
    learned_embeds_path: Path of the saved embedding file.
    text_encoder: Model exposing `get_input_embeddings()` and
      `resize_token_embeddings(n)`.
    tokenizer: Tokenizer exposing `add_tokens`, `convert_tokens_to_ids` and `len()`.
    token: Token to register. When None or empty, the token stored in the
      file is used instead.

  Returns:
    The token that was actually registered. If the requested token already
    existed, a numeric suffix is inserted before the final character
    (`<cat>` -> `<cat-1>`, `<cat-2>`, ...).
  """
  # NOTE(security): torch.load unpickles the file. Only load embeddings from
  # sources you trust.
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

  # separate token and the embeds
  trained_token = list(loaded_learned_embeds.keys())[0]
  embeds = loaded_learned_embeds[trained_token]

  # cast to dtype of text_encoder (explicitly, so the stored row never relies
  # on an implicit conversion during the assignment below)
  dtype = text_encoder.get_input_embeddings().weight.dtype
  embeds = embeds.to(dtype)

  # add the token in tokenizer; an empty token is treated like a missing one
  base_token = token if token else trained_token
  token = base_token
  suffix = 1
  while tokenizer.add_tokens(token) == 0:
    print(f"The tokenizer already contains the token {token}.")
    # Always derive the candidate from the base token so repeated collisions
    # give <x-1>, <x-2>, ... instead of <x-1>, <x-1-2>, ...
    token = f"{base_token[:-1]}-{suffix}>"
    print(f"Attempting to add the token {token}.")
    suffix += 1

  # resize the token embeddings
  text_encoder.resize_token_embeddings(len(tokenizer))

  # get the id for the token and assign the embeds
  token_id = tokenizer.convert_tokens_to_ids(token)
  text_encoder.get_input_embeddings().weight.data[token_id] = embeds
  return token

# Download every concept of the library and register its token with the pipeline.
# Each concept that loads successfully is appended to `models` as a dict with
# the keys "id", "concept_type", "images" and "token".
print("Setting up the public library")
for model in models_list:
  model_content = {}
  model_id = model.modelId
  model_content["id"] = model_id
  embeds_url = f"https://huggingface.co/{model_id}/resolve/main/learned_embeds.bin"
  os.makedirs(model_id, exist_ok=True)
  if not os.path.exists(f"{model_id}/learned_embeds.bin"):
    try:
      wget.download(embeds_url, out=model_id)
    except Exception as error:
      # Best effort: a repository whose embedding cannot be fetched is skipped.
      print(f"Skipping {model_id}: could not download learned_embeds.bin ({error})")
      continue

  token_identifier = f"https://huggingface.co/{model_id}/raw/main/token_identifier.txt"
  response = requests.get(token_identifier)
  # Only trust the body of a successful response; with None the loader falls
  # back to the token stored inside the embedding file.
  token_name = response.text if response.status_code == 200 else None

  concept_type = f"https://huggingface.co/{model_id}/raw/main/type_of_concept.txt"
  response = requests.get(concept_type)
  # An error page must not end up as the label shown in the gallery.
  concept_name = response.text if response.status_code == 200 else ""
  model_content["concept_type"] = concept_name

  # Fetch up to four preview images; indices that do not exist are ignored.
  images = []
  for i in range(4):
    url = f"https://huggingface.co/{model_id}/resolve/main/concept_images/{i}.jpeg"
    image_download = requests.get(url)
    if image_download.status_code == 200:
      image_path = f"{model_id}/{i}.jpeg"
      with open(image_path, "wb") as image_file:
        image_file.write(image_download.content)
      images.append(image_path)
  model_content["images"] = images

  # if token cannot be loaded, skip it
  try:
    learned_token = load_learned_embed_in_clip(f"{model_id}/learned_embeds.bin", pipe.text_encoder, pipe.tokenizer, token_name)
  except Exception as error:
    print(f"Skipping {model_id}: could not load the learned embedding ({error})")
    continue
  model_content["token"] = learned_token
  models.append(model_content)
  
#@title Run the app to navigate around [the Library](https://huggingface.co/sd-concepts-library)
#@markdown Click the `Running on public URL:` result to run the Gradio app

# Label text used by `checkbox_block` for its checkbox.
SELECT_LABEL = "Select concept"
def assembleHTML(model):
  """Render the concept gallery as one HTML string.

  Args:
    model: Iterable of concept dicts (despite the singular name), each with
      the keys "id", "token", "concept_type" and "images" (a list of image
      paths served under `file/`).

  Returns:
    HTML for a wrapping row holding one card per concept. At most the first
    99 concepts are rendered.
  """
  max_cards = 99
  pieces = ['''
  <div class="flex gr-gap gr-form-gap row gap-4 w-full flex-wrap" id="main_row">
  ''']
  # Iterate the argument itself (not the module-level list) and escape every
  # interpolated value: ids, tokens, concept types and paths originate from
  # remote repositories.
  for shown, entry in enumerate(model):
    if shown == max_cards:
      break
    pieces.append(f'''
    <div class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200 gr-panel">
      <div class="output-markdown gr-prose" style="max-width: 100%;">
        <h3>
          <a href="https://huggingface.co/{html.escape(entry["id"])}" target="_blank">
            <code>{html.escape(entry["token"])}</code>
          </a>
        </h3>
      </div>
      <div id="gallery" class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200">
        <div class="wrap svelte-17ttdjv opacity-0"></div>
        <div class="absolute left-0 top-0 py-1 px-2 rounded-br-lg shadow-sm text-xs text-gray-500 flex items-center pointer-events-none bg-white z-20 border-b border-r border-gray-100 dark:bg-gray-900">
          <span class="mr-2 h-[12px] w-[12px] opacity-80">
            <svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image">
              <rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect>
              <circle cx="8.5" cy="8.5" r="1.5"></circle>
              <polyline points="21 15 16 10 5 21"></polyline>
            </svg>
          </span> {html.escape(entry["concept_type"])}
        </div>
        <div class="overflow-y-auto h-full p-2" style="position: relative;">
          <div class="grid gap-2 grid-cols-2 sm:grid-cols-2 md:grid-cols-2 lg:grid-cols-2 xl:grid-cols-2 2xl:grid-cols-2 svelte-1g9btlg pt-6">
        ''')
    for image in entry["images"]:
      pieces.append(f'''
                <button class="gallery-item svelte-1g9btlg">
                  <img alt="" loading="lazy" class="h-full w-full overflow-hidden object-contain" src="file/{html.escape(image)}">
                </button>
                ''')
    pieces.append('''
              </div>
              <iframe style="display: block; position: absolute; top: 0; left: 0; width: 100%; height: 100%; overflow: hidden; border: 0; opacity: 0; pointer-events: none; z-index: -1;" aria-hidden="true" tabindex="-1" src="about:blank"></iframe>
            </div>
          </div>
        </div>
        ''')
  pieces.append('''
  </div>
  ''')
  return "".join(pieces)
  
def title_block(title, id):
  """Return a Markdown heading showing `title` as code, linked to the Hub repo `id`."""
  heading = f"### [`{title}`](https://huggingface.co/{id})"
  return gr.Markdown(heading)

def image_block(image_list, concept_type):
  """Return a gallery of `image_list` labelled with `concept_type`, styled as a 2-wide grid."""
  gallery = gr.Gallery(label=concept_type, value=image_list, elem_id="gallery")
  return gallery.style(grid=[2], height="auto")

def checkbox_block():
  """Return a checkbox labelled with SELECT_LABEL, styled with container=False."""
  return gr.Checkbox(label=SELECT_LABEL).style(container=False)

def infer(text):
  """Run the pipeline on two copies of `text` and reveal the share widgets.

  Returns a 4-tuple: the list of generated images followed by three
  `gr.update(visible=True)` objects, one per extra output wired to the
  Run button (community icon, loading icon, share button).
  """
  prompts = [text] * 2
  with autocast("cuda"):
    result = pipe(prompts, num_inference_steps=50, guidance_scale=7.5)
  output_images = list(result["sample"])
  reveal = tuple(gr.update(visible=True) for _ in range(3))
  return (output_images, *reveal)

# Identical to the `infer` function, but without the Gradio visibility updates for the share button
def infer_examples(text):
  """Run the pipeline on two copies of `text` and return the generated images as a list."""
  prompts = [text] * 2
  with autocast("cuda"):
    result = pipe(prompts, num_inference_steps=50, guidance_scale=7.5)
  return list(result["sample"])
  
# Custom stylesheet passed to gr.Blocks(css=...).
# The focus-ring rule mirrors Tailwind's ring utilities; its calc() needs an
# explicit `+` between the two lengths, otherwise the declaration is invalid.
css = '''
.gradio-container {font-family: 'IBM Plex Sans', sans-serif}
#top_title{margin-bottom: .5em}
#top_title h2{margin-bottom: 0; text-align: center}
/*#main_row{flex-wrap: wrap; gap: 1em; max-height: 550px; overflow-y: scroll; flex-direction: row}*/
#component-3{height: 760px; overflow: auto}
#component-9{position: sticky;top: 0;align-self: flex-start;}
@media (min-width: 768px){#main_row > div{flex: 1 1 32%; margin-left: 0 !important}}
.gr-prose code::before, .gr-prose code::after {content: "" !important}
::-webkit-scrollbar {width: 10px}
::-webkit-scrollbar-track {background: #f1f1f1}
::-webkit-scrollbar-thumb {background: #888}
::-webkit-scrollbar-thumb:hover {background: #555}
.gr-button {white-space: nowrap}
.gr-button:focus {
  border-color: rgb(147 197 253 / var(--tw-border-opacity));
  outline: none;
  box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
  --tw-border-opacity: 1;
  --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
  --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
  --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
  --tw-ring-opacity: .5;
}
#prompt_input{flex: 1 3 auto; width: auto !important;}
#prompt_area{margin-bottom: .75em}
#prompt_area > div:first-child{flex: 1 3 auto}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
    all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
    all: unset;
}
'''
# Sample prompts offered through gr.Examples; each one uses a concept token in angle brackets.
examples = [
    "a <cat-toy> in <madhubani-art> style",
    "a <line-art> style mecha robot",
    "a piano being played by <bonzi>",
    "Candid photo of <cheburashka>, high resolution photo, trending on artstation, interior design",
]

# Build the Gradio UI: a header, a first column with the concept gallery, and a
# second column with the prompt box, generated images, examples and share button.
with gr.Blocks(css=css) as demo:
  # NOTE(review): `state` is rebound to a plain dict right below, so this
  # gr.Variable is never referenced again. The stylesheet targets auto-numbered
  # ids (#component-3, #component-9), which presumably depend on the order in
  # which components are created — confirm before removing or reordering anything.
  state = gr.Variable({
        'selected': -1
  })
  state = {}
  # Toggles entry `i` in both `checkbox_states` and `state`.
  # NOTE(review): not attached to any event in the code visible here.
  def update_state(i):
        global checkbox_states
        if(checkbox_states[i]):
          checkbox_states[i] = False
          state[i] = False
        else:
          state[i] = True
          checkbox_states[i] = True
  # Page header: logo, title and a short description linking to the training notebook.
  gr.HTML('''
  <div style="text-align: center; max-width: 720px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <svg
                  width="0.65em"
                  height="0.65em"
                  viewBox="0 0 115 115"
                  fill="none"
                  xmlns="http://www.w3.org/2000/svg"
                >
                  <rect width="23" height="23" fill="white"></rect>
                  <rect y="69" width="23" height="23" fill="white"></rect>
                  <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="46" width="23" height="23" fill="white"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" width="23" height="23" fill="black"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="115" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="46" y="46" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="115" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="black"></rect>
                </svg>
                <h1 style="font-weight: 900; margin-bottom: 7px;">
                  Stable Diffusion Conceptualizer
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                Navigate through community created concepts and styles via Stable Diffusion Textual Inversion and pick yours for inference.
                To train your own concepts and contribute to the library <a style="text-decoration: underline" href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb">check out this notebook</a>.
              </p>
            </div>
  ''')
  with gr.Row():
        # First column: the concept gallery, rendered as one HTML blob by assembleHTML.
        with gr.Column():
          gr.Markdown(f"### Navigate the top 100 Textual-Inversion community trained concepts. Use 600+ from [The Library](https://huggingface.co/sd-concepts-library)")
          with gr.Row():
                  # NOTE(review): unused while the per-model loop below stays commented out.
                  image_blocks = []
                  #for i, model in enumerate(models):
                  with gr.Box().style(border=None):
                    gr.HTML(assembleHTML(models))
                      #title_block(model["token"], model["id"])
                      #image_blocks.append(image_block(model["images"], model["concept_type"]))
        # Second column: prompt input, Run button, generated images, hint text, examples.
        with gr.Column():
          with gr.Box():
                  with gr.Row(elem_id="prompt_area").style(mobile_collapse=False, equal_height=True):
                      text = gr.Textbox(
                          label="Enter your prompt", placeholder="Enter your prompt", show_label=False, max_lines=1, elem_id="prompt_input"
                      ).style(
                          border=(True, False, True, True),
                          rounded=(True, False, False, True),
                          container=False,
                          full_width=False,
                      )
                      btn = gr.Button("Run",elem_id="run_btn").style(
                          margin=False,
                          rounded=(False, True, True, False),
                          full_width=False,
                      )  
                  with gr.Row().style():
                      infer_outputs = gr.Gallery(show_label=False, elem_id="generated-gallery").style(grid=[2], height="512px")
                  with gr.Row():
                    gr.HTML("<p style=\"font-size: 95%;margin-top: .75em\">Prompting may not work as you are used to. <code>objects</code> may need the concept added at the end, <code>styles</code> may work better at the beginning. You can navigate on <a href='https://lexica.art'>lexica.art</a> to get inspired on prompts</p>")
                  with gr.Row():
                    # Examples call `infer_examples`, which returns only the images
                    # (a single output), unlike `infer`.
                    gr.Examples(examples=examples, fn=infer_examples, inputs=[text], outputs=infer_outputs, cache_examples=True)
          # Share widgets are created hidden; `infer` returns gr.update(visible=True)
          # for each of them when the Run button is used.
          with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(community_icon_html, visible=False)
            loading_icon = gr.HTML(loading_icon_html, visible=False)
            share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
  # Module-level dict that `update_state` reaches through its `global` statement.
  checkbox_states = {}
  inputs = [text]
  # Run button: generate images and reveal the three share widgets.
  btn.click(
        infer,
        inputs=inputs,
        outputs=[infer_outputs, community_icon, loading_icon, share_button]
    )
  # Share button: no Python callback, only the `share_js` snippet imported from share_btn.
  share_button.click(
      None,
      [],
      [],
      _js=share_js,
  )
# Enable the request queue (max_size=20) and start the app.
demo.queue(max_size=20).launch()