import gradio as gr
import whisper
import torch
import os
from diffusers import StableDiffusionPipeline
from typing import Literal, Optional

def get_device() -> Literal['cuda', 'cpu']:
  return "cuda" if torch.cuda.is_available() else "cpu"

def get_token() -> Optional[str]:
  # May be None when the env var is unset; a valid token is needed for gated models.
  return os.environ.get("HUGGING_FACE_TOKEN")
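
# NOTE: generate_images below rebuilds the Stable Diffusion pipeline on every
# click, re-reading several GB of weights each time. A common optimization
# (an assumption, not part of the original script) is to load it once at
# import time and reuse it:
#
#   PIPE = StableDiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4", use_auth_token=get_token()
#   ).to(get_device())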

def generate_images(prompt: str, scale: float, iterations: int, seed: int, num_images: int) -> list:
  AUTH_TOKEN = get_token()
  device = get_device()

  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                 use_auth_token=AUTH_TOKEN)

  pipe.to(device)
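  # Assumption, not in the original script: on CUDA you can also pass
  # torch_dtype=torch.float16 to from_pretrained above to roughly halve VRAM use.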
  # Gradio sliders deliver floats, so cast before handing values to torch/diffusers.
  generator = torch.Generator(device).manual_seed(int(seed))
  prompt = [prompt] * int(num_images)
  # guidance_scale steers how closely images follow the prompt; more
  # num_inference_steps means slower but usually cleaner results.
  images = pipe(prompt, num_inference_steps=int(iterations), guidance_scale=scale, generator=generator).images
  
  output_files_names = []
  for idx, image in enumerate(images):
    filename = f"output{idx}.png"
    image.save(filename)
    output_files_names.append(filename)

  return output_files_names
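
# Example call outside the UI (hypothetical values):
#   files = generate_images("a watercolor fox", scale=7.5, iterations=25, seed=42, num_images=2)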


def transcribe_audio(model_selected: str, audio_input: str) -> tuple:
  # gr.Audio(type="filepath") passes the recording as a path string, not a file object.
  model = whisper.load_model(model_selected)
  audio_input = whisper.load_audio(audio_input)
  audio_input = whisper.pad_or_trim(audio_input)
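  # Whisper consumes fixed 30-second windows: pad_or_trim pads or cuts the
  # waveform, and log_mel_spectrogram below converts it into the model's input.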
  translation_output = ""

  mel = whisper.log_mel_spectrogram(audio_input).to(model.device)

  # Language is auto-detected since DecodingOptions.language is left unset;
  # fp16=False keeps decoding in float32 so it also works on CPU.
  transcript_options = whisper.DecodingOptions(task="transcribe", fp16=False)
  transcription = whisper.decode(model, mel, transcript_options)
  prompt_for_sd = transcription.text

  # Whisper's "translate" task always translates into English, so for
  # non-English audio the English translation becomes the Stable Diffusion prompt.
  if transcription.language != "en":
    translation_options = whisper.DecodingOptions(task="translate", fp16=False)
    translation = whisper.decode(model, mel, translation_options)
    translation_output = translation.text
    prompt_for_sd = translation_output

  return transcription.text, translation_output, str(transcription.language).upper(), prompt_for_sd
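
# Example call outside the UI (hypothetical file path):
#   text, translation, lang, prompt = transcribe_audio("base", "prompt.wav")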

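# The UI below is written against the gradio 3.x API (gr.Audio(source=...),
# Gallery().style(...)); gradio 4.x renamed or removed these arguments.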
with gr.Blocks() as demo:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 90%; margin: 0 auto;">
              <div>
                <h1>Whisper App</h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 100%">
                Record an audio prompt, transcribe it with OpenAI Whisper, and generate images with Stable Diffusion!
              </p>
            </div>
        """
    )
    with gr.Row():
        with gr.Accordion(label="Whisper model selection"):
            with gr.Row():
                model_selection_radio = gr.Radio(['base', 'small', 'medium', 'large'], value='medium', interactive=True, label="Model")
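    # Larger Whisper checkpoints transcribe (and translate) more accurately,
    # but they download and run noticeably slower.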
    with gr.Tab("Record Prompt"):
        with gr.Row():
            recorded_audio_input = gr.Audio(source="microphone", type="filepath", label="Record your prompt to feed to Stable Diffusion!")
            audio_transcribe_btn = gr.Button("Launch Whisper")
        with gr.Row():
            transcribed_output_box = gr.TextArea(interactive=False, label="Transcription", placeholder="Transcription will appear here")
            translated_output_box = gr.TextArea(interactive=True, label="Translated prompt")
            detected_language_box = gr.Textbox(interactive=False, label="Detected Language")
    with gr.Tab("Stable Diffusion"):
        with gr.Row():
            prompt_box = gr.TextArea(interactive=False, label="Prompt")
        with gr.Row():
            guidance_slider = gr.Slider(2, 15, value=7, label='Guidance Scale', interactive=True)
            iterations_slider = gr.Slider(10, 100, value=25, step=1, label='Number of Iterations', interactive=True)
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
                interactive=True)
            num_images_slider = gr.Slider(2, 8, value=2, label="Number of Images", interactive=True)
        with gr.Row():
            images_gallery = gr.Gallery(label="Generated Images").style(grid=[2])
        with gr.Row():
            generate_image_btn = gr.Button("Generate Images")
    # Wire the buttons to the two handlers defined above.
    audio_transcribe_btn.click(transcribe_audio,
                              inputs=[
                                        model_selection_radio,
                                        recorded_audio_input
                              ],
                              outputs=[transcribed_output_box,
                                        translated_output_box,
                                        detected_language_box,
                                        prompt_box
                                      ]
                              )
    generate_image_btn.click(generate_images,
                              inputs=[
                                    prompt_box,
                                    guidance_slider,
                                    iterations_slider,
                                    seed_slider,
                                    num_images_slider
                              ],
                              outputs=images_gallery
    )

# enable_queue is the gradio 3.x flag for queuing concurrent requests; newer
# gradio versions use demo.queue().launch(debug=True) instead.
demo.launch(enable_queue=True, debug=True)