import os

import imageio
import numpy as np
import torch
import random

import spaces

import gradio as gr

import torchvision
import torchvision.transforms as T
from einops import rearrange
from huggingface_hub import hf_hub_download
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image

from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
from diffusers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
from onlyflow.models.unet import UNetMotionModel
from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
from tools.optical_flow import get_optical_flow


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
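
# Example usage (illustrative): save a batch of videos shaped (b, c, t, h, w) with values in [0, 1]:
#   save_videos_grid(torch.rand(2, 3, 8, 64, 64), "./samples/grid.mp4", fps=8)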

css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""


class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir                = os.getcwd()
        self.stable_diffusion_dir   = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir      = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir                = os.path.join(self.basedir, "samples")
        os.makedirs(self.savedir, exist_ok=True)


        ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
        self.attention_processor_state_dict = ckpt['attention_processor_state_dict']

        self.tokenizer             = None
        self.text_encoder          = None
        self.vae                   = None
        self.unet                  = None
        self.motion_adapter        = None

    def update_base_model(self, base_model_id, progress=gr.Progress()):

        progress(0, desc="Starting...")

        self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")

        return base_model_id

    def update_motion_module(self, motion_module_id, progress=gr.Progress()):
        self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)

    def animate(
            self,
            id_base_model,
            id_motion_module,
            prompt_textbox_positive,
            prompt_textbox_negative,
            seed_textbox,
            input_video,
            height,
            width,
            flow_scale,
            cfg,
            diffusion_steps,
            temporal_ds,
            ctx_stride
    ):
        #if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
        self.update_base_model(id_base_model)
        self.update_motion_module(id_motion_module)

        self.unet = UNetMotionModel.from_unet2d(
            self.unet,
            motion_adapter=self.motion_adapter
        )

        self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()

        self.flow_encoder = FlowEncoder(
            downscale_factor=8,
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False,
            compression_factor=1,
            temporal_attention_nhead=8,
            positional_embeddings="sinusoidal",
            num_positional_embeddings=16,
            checkpointing=False
        ).eval()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.raft.requires_grad_(False)
        self.flow_encoder.requires_grad_(False)

        self.unet.set_all_attn(
            flow_channels=[320, 640, 1280, 1280],
            add_spatial=False,
            add_temporal=True,
            encoder_only=False,
            query_condition=True,
            key_value_condition=True,
            flow_scale=1.0,
        )

        self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()

        # load the flow encoder weights
        pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
            self.flow_encoder_state_dict,
            strict=False
        )
        assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0

        # load the attention processor weights
        _, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
            self.attention_processor_state_dict,
            strict=False
        )
        assert len(attention_processor_u) == 0

        pipeline = FlowCtrlPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            motion_adapter=self.motion_adapter,
            flow_encoder=self.flow_encoder,
            scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
        )

        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(1, int(1e16))

        return animate_diffusion(seed, pipeline, self.raft, input_video, prompt_textbox_positive, prompt_textbox_negative, width, height, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride)

@spaces.GPU(duration=150)
def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
    savedir = './samples'
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)

    raft_model = raft_model.to(device)
    pipeline = pipeline.to(device)

    pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
    print("Video loaded, shape:", pixel_values.shape)
    if width_slider/height_slider > pixel_values.shape[3]/pixel_values.shape[2]:
        print("Resizing video to fit width cause input video is not wide enough")
        temp_height = int(width_slider * pixel_values.shape[2]/pixel_values.shape[3])
        temp_width = width_slider
    else:
        print("Resizing video to fit height cause input video is not tall enough")
        temp_height = height_slider
        temp_width = int(height_slider * pixel_values.shape[3]/pixel_values.shape[2])
    print("Resizing video to:", temp_height, temp_width)
    pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
    pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
    pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)

    save_sample_path_input = os.path.join(savedir, "input.mp4")
    pixel_values_save = pixel_values[0] * 255
    pixel_values_save = pixel_values_save.cpu()
    pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
    del pixel_values_save

    print("Video loaded, shape:", pixel_values.shape)
    flow = get_optical_flow(
        raft_model,
        (pixel_values * 2) - 1,
        pixel_values.shape[1] - 1,
        encode_chunk_size=16,
    ).to('cpu')

    sample_flow = flow_to_image(rearrange(flow[0], "c f h w -> f c h w"))  # N, 3, H, W
    save_sample_path_flow = os.path.join(savedir, "flow.mp4")
    sample_flow = sample_flow.cpu().to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
    del sample_flow

    original_flow_shape = flow.shape
    print("Optical flow computed, shape:", flow.shape)
    if flow.shape[2] < 16:
        print("Video is too short, padding to 16 frames")
        video_length = 16
        n = 16 - flow.shape[2]
        # create a tensor containing the last frame optical flow repeated n times
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2).to(device)
    elif flow.shape[2] > 16:
        print("Video is too long, enabling windowing")
        print("Enabling model CPU offload")
        pipeline.enable_model_cpu_offload()
        print("Enabling VAE slicing")
        pipeline.enable_vae_slicing()
        print("Enabling VAE tiling")
        pipeline.enable_vae_tiling()

        print("Enabling free noise")
        pipeline.enable_free_noise(
            context_length=16,
            context_stride=context_stride,
        )

        import math

        def find_divisors(n: int):
            """
            Return sorted list of all positive divisors of n.
            Uses a sqrt(n) approach for efficiency.
            """
            divs = set()
            limit = int(math.isqrt(n))
            for i in range(1, limit + 1):
                if n % i == 0:
                    divs.add(i)
                    divs.add(n // i)
            return sorted(divs)
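
        # Example (illustrative): find_divisors(12) -> [1, 2, 3, 4, 6, 12]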

        def multiples_in_range(k: int, min_val: int, max_val: int):
            """
            Return all multiples of k within [min_val, max_val].
            """
            if k == 0:
                return []

            # First multiple of k >= min_val
            start = ((min_val + k - 1) // k) * k
            # Last multiple of k <= max_val
            end = (max_val // k) * k

            return list(range(start, end + 1, k)) if start <= end else []
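
        # Example (illustrative): multiples_in_range(3, 4, 14) -> [6, 9, 12]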

        def adjust_video_length(original_length: int,
                                context_stride: int,
                                chunk_size: int,
                                temporal_split_size: int) -> int:
            """
            Find the minimal video_length >= original_length satisfying:
              1) (video_length - 16) is divisible by context_stride.
              2) EITHER (2*video_length) is divisible by temporal_split_size
                 OR (2*video_length) is divisible by chunk_size
                 (when 2*video_length is not multiple of temporal_split_size).
            """

            # Start at 16 at minimum (in practice original_length is usually > 16).
            candidate = max(original_length, 16)

            # Step candidate in increments of context_stride so that (candidate - 16)
            # stays a multiple of context_stride, checking the second condition each time.
            # First, jump to the correct "starting multiple" of context_stride.
            offset = (candidate - 16) % context_stride
            if offset != 0:
                candidate += (context_stride - offset)  # jump to the next multiple

            while True:
                # Condition: (candidate - 16) is multiple of context_stride (already enforced by stepping)
                # Check second part:
                # - if (2*candidate) % temporal_split_size == 0, we are good
                # - else we require (2*candidate) % chunk_size == 0
                twoL = 2 * candidate
                if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
                    return candidate

                # Go to next valid candidate
                candidate += context_stride
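
        # Worked example (illustrative): adjust_video_length(20, context_stride=12,
        # chunk_size=2, temporal_split_size=16) first bumps 20 to 28 so that
        # (28 - 16) % 12 == 0, then accepts 28 because 2 * 28 is divisible by chunk_size 2.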

        def find_valid_configs(original_video_length: int,
                               width: int,
                               height: int,
                               context_stride: int):
            """
            Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
            subject to the constraints:
              1) chunk_size divides temporal_split_size
              2) chunk_size divides spatial_split_size
              3) chunk_size divides (2 * (width//64) * (height//64))
              4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
              5) context_stride divides (video_length - 16)
              6) 480 <= spatial_split_size <= 512
              7) 1 <= temporal_split_size <= 32
              8) 1 <= chunk_size <= 32

            We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
            """

            factor = 2 * (width // 64) * (height // 64)

            # 1) find all possible chunk_size values as divisors of factor, in [1..32]
            possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]

            # For storing results
            valid_tuples = []

            for chunk_size in possible_chunks:
                # 2) generate all spatial_split_size in [480..512] that are multiples of chunk_size
                spatial_splits = multiples_in_range(chunk_size, 480, 512)

                # 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
                temporal_splits = multiples_in_range(chunk_size, 1, 32)

                for ssp in spatial_splits:
                    for tsp in temporal_splits:
                        # 4) & 5) Adjust video_length minimally to satisfy constraints
                        final_length = adjust_video_length(original_video_length,
                                                           context_stride,
                                                           chunk_size,
                                                           tsp)
                        # Now we have a valid (chunk_size, ssp, tsp, final_length)
                        valid_tuples.append((chunk_size, ssp, tsp, final_length))

            return valid_tuples
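
        # Example (illustrative): for a 512x512 target, factor = 2 * 8 * 8 = 128, so the
        # chunk_size candidates are the divisors of 128 that lie in [1..32]: 1, 2, 4, 8, 16, 32.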

        def find_pareto_optimal(configs):
            """
            Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
            return the Pareto-optimal subset under the criteria:
              - chunk_size: larger is better
              - temporal_split_size: larger is better
              - video_length: smaller is better
            Note: spatial_split_size is not part of the dominance comparison.
            """

            def dominates(A, B):
                cA, sA, tA, lA = A
                cB, sB, tB, lB = B

                # A dominates B if:
                #   cA >= cB, tA >= tB, and lA <= lB,
                #   AND at least one of these is a strict inequality.

                better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
                strictly_better = (cA > cB) or (tA > tB) or (lA < lB)

                return better_or_equal and strictly_better

            pareto = []
            for i, cfg_i in enumerate(configs):
                # Check if cfg_i is dominated by any cfg_j
                is_dominated = False
                for j, cfg_j in enumerate(configs):
                    if i == j:
                        continue
                    if dominates(cfg_j, cfg_i):
                        is_dominated = True
                        break
                if not is_dominated:
                    pareto.append(cfg_i)

            return pareto
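
        # Example (illustrative): given [(2, 480, 16, 32), (1, 480, 8, 32)], the first tuple
        # dominates the second (larger chunk and temporal split sizes, same video length),
        # so only [(2, 480, 16, 32)] is returned.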

        print("Finding valid configurations...")
        valid_configs = find_valid_configs(
            original_video_length=flow.shape[2],
            width=width_slider,
            height=height_slider,
            context_stride=context_stride
        )

        print("Found", len(valid_configs), "valid configurations")
        print("Finding Pareto-optimal configurations...")
        pareto_optimal = find_pareto_optimal(valid_configs)

        print("Found", pareto_optimal)

        # Heuristic ranking: favor larger chunk and temporal split sizes, and penalize
        # padding the video far beyond its original frame count.
        criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10)
        pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True)

        print("Found sorted", pareto_optimal)

        solution = pareto_optimal[0]
        chunk_size, spatial_split_size, temporal_split_size, video_length = solution

        n = video_length - original_flow_shape[2]
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2)

        pipeline.enable_free_noise_split_inference(
            temporal_split_size=temporal_split_size,
            spatial_split_size=spatial_split_size
        )
        pipeline.unet.enable_forward_chunking(chunk_size)

        print("Chunking enabled with chunk size:", chunk_size)
        print("Temporal split size:", temporal_split_size)
        print("Spatial split size:", spatial_split_size)
        print("Context stride:", context_stride)
        print("Temporal downscale:", temporal_ds)
        print("Video length:", video_length)
        print("Flow shape:", flow.shape)
    else:
        print("Video is just right, no padding or windowing needed")
        flow = flow.to(device)
        video_length = flow.shape[2]

    sample_vid = pipeline(
        prompt_textbox,
        negative_prompt=negative_prompt_textbox,
        optical_flow=flow,
        num_inference_steps=diffusion_steps,
        guidance_scale=cfg,
        width=width_slider,
        height=height_slider,
        num_frames=video_length,
        val_scale_factor_temporal=flow_scale,
        generator=generator,
    ).frames[0]

    del flow
    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    save_sample_path_video = os.path.join(savedir, "sample.mp4")
    sample_vid = sample_vid[:original_flow_shape[2]] * 255.
    sample_vid = sample_vid.cpu().numpy()
    sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)

    return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)

controller = AnimateController()


def find_closest_ratio(target_ratio):
    width_list = list(reversed(range(256, 1025, 64)))
    height_list = list(reversed(range(256, 1025, 64)))
    ratio_list = [(h, w, w/h) for h in height_list for w in width_list]
    ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
    ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
    ratio_list.sort(key=lambda x: abs(x[0]*x[1] - 512*512))
    return ratio_list[0][:2]
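
# Example (illustrative): find_closest_ratio(1.0) -> (512, 512), the square resolution whose
# pixel count is closest to 512 * 512.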


def find_dimension(video):
    import av
    container = av.open(video)
    height, width = container.streams.video[0].height, container.streams.video[0].width
    target_ratio = width / height
    return find_closest_ratio(target_ratio)


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # <p style="text-align:center;">OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models</p>
            Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord<br>
            [Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
            """
        )
        gr.Markdown(
            """
            ### Quick Start:

            1. Select the desired `Base Model`.
            2. Select a `Motion Module`. We recommend guoyww/animatediff-motion-adapter-v1-5-3 for the best results.
            3. Provide a `Positive Prompt` and a `Negative Prompt`. Refer to each model's page on the Hugging Face Hub or CivitAI to learn how to write prompts for it.
            4. Upload a video to extract the optical flow from.
            5. Select a `Flow Scale` to modulate the optical flow conditioning from the input video.
            6. Select a `CFG` and `Diffusion Steps` value to control the quality of the generated video and its adherence to the prompt.
            7. Select a `Temporal Downsample` factor to reduce the number of frames taken from the input video.
            8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
            9. If the video is too long, adjust the generation window offset with the `Context Stride` slider.
            10. Click `Generate`, wait for ~1-3 min, and enjoy the result!

            If you hit an error about GPU limits, try again later once your ZeroGPU quota has reset, or use a shorter video.
            Otherwise, you can also duplicate this space and select a custom GPU plan.
            """
        )
        with gr.Row():
            with gr.Column():

                gr.Markdown("# INPUTS")

                with gr.Row(equal_height=True, show_progress=True):
                    base_model = gr.Dropdown(
                        label="Select or type a base model id",
                        choices=[
                        "stable-diffusion-v1-5/stable-diffusion-v1-5",
                        "digiplay/Photon_v1",
                    ],
                        interactive=True,
                        scale=4,
                        allow_custom_value=True,
                        show_label=True
                    )
                    base_model_btn = gr.Button(value="Update", scale=1, size='lg')
                with gr.Row(equal_height=True, show_progress=True):
                    motion_module  = gr.Dropdown(
                        label="Select or type a motion module id",
                        choices=[
                            "guoyww/animatediff-motion-adapter-v1-5-3",
                            "guoyww/animatediff-motion-adapter-v1-5-2"
                        ],
                        interactive=True,
                        scale=4
                    )
                    motion_module_btn = gr.Button(value="Update", scale=1, size='lg')

                base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
                motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])

                prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
                prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, low quality, nsfw, logo")

                flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
                diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
                cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)

                temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)

                input_video = gr.Video(label="Input Video", interactive=True)
                ctx_stride = gr.State(12)

                with gr.Accordion("Advanced", open=False):
                    use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)

                    with gr.Row(equal_height=True):

                        height, width = gr.State(512), gr.State(512)

                        @gr.render(inputs=[use_custom_dim, input_video])
                        def render_custom_dim(use_custom_dim, input_video):
                            if input_video is not None:
                                loc_height, loc_width = find_dimension(input_video)
                            else:
                                loc_height, loc_width = 512, 512
                            slider_width = gr.Slider(label="Width", value=loc_width, minimum=256, maximum=1024,
                                                     step=64, visible=use_custom_dim)
                            slider_height = gr.Slider(label="Height", value=loc_height, minimum=256, maximum=1024,
                                                      step=64, visible=use_custom_dim)

                            slider_width.change(lambda x: x, inputs=[slider_width], outputs=[width])
                            slider_height.change(lambda x: x, inputs=[slider_height], outputs=[height])


                    with gr.Row():
                        @gr.render(inputs=input_video)
                        def render_ctx_stride(input_video):
                            if input_video is not None:
                                import av
                                container = av.open(input_video)
                                num_frames = container.streams.video[0].frames
                                if num_frames > 17:
                                    stride_slider = gr.Slider(label="Context Stride", value=12, minimum=1, maximum=16, step=1)
                                    stride_slider.input(lambda x: x, inputs=[stride_slider], outputs=[ctx_stride])
                                if num_frames > 64:
                                    raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")
                                elif num_frames > 32:
                                    gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")

                    with gr.Row(equal_height=True):
                        seed_textbox = gr.Textbox(label="Seed",  value='-1')

                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(
                            fn=lambda: random.randint(1, int(1e16)),
                            inputs=[],
                            outputs=[seed_textbox]
                        )

                with gr.Row():
                    clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
                    generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')

                    clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])

            with gr.Column():

                gr.Markdown("# OUTPUTS")

                result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
                result_video = gr.Video(label="Generated Animation", interactive=False)

            inputs  = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
            outputs = [result_optical_flow, result_video]

            generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()