import os
import numpy as np
import torch
from einops import rearrange
from imageio import imwrite
from pydantic import validator
import imageio
import tempfile
import gradio as gr

from PIL import Image

from my.utils import (
    tqdm, EventStorage, HeartBeat, EarlyLoopBreak,
    get_event_storage, get_heartbeat, read_stats
)
from my.config import BaseConf, dispatch, optional_load_config
from my.utils.seed import seed_everything

from adapt import ScoreAdapter
from run_img_sampling import SD
from misc import torch_samps_to_imgs
from pose import PoseConfig

from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.render import (
    as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle
)
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

from pytorch3d.renderer import PointsRasterizationSettings

from semantic_coding import semantic_coding, semantic_karlo, semantic_sd
from pc_project import point_e, render_depth_from_cloud
# Module-wide compute device used by fuse_3d and evaluate below.
# NOTE(review): hard-coded to CUDA; there is no CPU fallback in this file.
device_glb = torch.device("cuda")

def tsr_stats(tsr):
    """Summarise a tensor as plain Python scalars.

    Returns a dict with the keys ``"mean"``, ``"std"`` and ``"max"`` (in that
    order), each holding the result of the same-named reduction on ``tsr``
    converted with ``.item()``.
    """
    reductions = ("mean", "std", "max")
    return {name: getattr(tsr, name)().item() for name in reductions}

class SJC_3DFuse(BaseConf):
    """Configuration and Gradio entry point for the 3DFuse demo.

    The fields below are the run configuration. ``run_gradio`` turns them into
    a dict, runs ``semantic_coding`` on the input images, builds the score
    model / voxel grid / pose sampler through their ``make()`` methods, and
    then delegates to ``fuse_3d``. The CLI entry point ``run`` is disabled in
    this version.
    """
    # Name of the attribute on this config whose .make() builds the score
    # model (looked up with getattr in run_gradio).
    family:     str = "sd"
    sd:         SD = SD(
        variant="v1",
        prompt="a comfortable bed",
        scale=100.0,
        dir="./results",
        alpha=0.3
    )
    # Learning rate handed to the Adamax optimizer in fuse_3d.
    lr:         float = 0.05
    # Passed to poser.sample_train and used as the progress-bar total in fuse_3d.
    n_steps:    int = 10000
    vox:        VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=False , bg_texture_hw=4,
        bbox_len=1.0
    )
    pose:       PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    # Emptiness loss in fuse_3d is
    #   emptiness_weight * mean(log(1 + emptiness_scale * ws)),
    # multiplied by emptiness_multiplier once the step index reaches
    # emptiness_step * n_steps.
    emptiness_scale:    int = 10
    # NOTE(review): annotated as int but the default literal 1e4 is a float.
    emptiness_weight:   int = 1e4
    emptiness_step:     float = 0.5
    emptiness_multiplier: float = 20.0

    # The depth loss in fuse_3d is only computed when this is > 0.
    depth_weight:       int = 0

    # Selects which gradient formula fuse_3d uses: (Ds - y) when True,
    # (Ds - zs) when False, both divided by the sampled sigma.
    var_red:     bool = True
    # Joined with `initial` in run_gradio to form the exp_dir given to fuse_3d.
    exp_dir:     str = "./results"
    # The remaining fields are not read directly in this file; they travel in
    # the cfgs dict to semantic_coding (and are swallowed by fuse_3d's
    # **kwargs). Presumably ti_step/pt_step are tuning step counts — TODO
    # confirm against semantic_coding.
    ti_step:     int = 800
    pt_step:     int = 800
    initial:    str = ""
    random_seed:     int = 0
    semantic_model:     str = "Karlo"
    bg_preprocess:     bool = True
    num_initial_image:     int = 4
    @validator("vox")
    def check_vox(cls, vox_cfg, values):
        """Set the voxel config's ``c`` to 4 when ``family`` is ``"sd"``.

        ``values`` holds the previously validated fields; ``family`` is read
        from it by key, so this raises ``KeyError`` if that entry is absent.
        """
        family = values['family']
        if family == "sd":
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """CLI entry point; always raises because this build is demo-only."""
        raise Exception("This version is for huggingface demo, which doesn't support CLI. Please visit https://github.com/KU-CVLAB/3DFuse")
        
    def run_gradio(self, points, images):
            """Generator driving the full demo pipeline for a Gradio UI.

            Yields 3-tuples: one status update before ``semantic_coding``
            runs, then everything ``fuse_3d`` yields.

            ``images`` is forwarded to ``semantic_coding`` and ``points`` to
            ``fuse_3d``; their expected types are not visible in this file.
            """
            cfgs = self.dict()
            initial = cfgs.pop('initial')
            exp_dir=os.path.join(cfgs.pop('exp_dir'),initial)
            
            # Optimization  and pivotal tuning for LoRA
            yield gr.update(value=None), "Tuning for the LoRA layer is starting now. It will take approximately ~10 mins.", gr.update(value=None) 
            # cfgs still contains 'family', 'vox' and 'pose' at this point;
            # only 'initial' and 'exp_dir' have been removed.
            state=semantic_coding(images, cfgs,self.sd,initial)
            # The value returned by semantic_coding becomes the SD config's dir
            # before the model is built.
            self.sd.dir=state
            
            # Load SD with Consistency Injection Module
            family = cfgs.pop("family")
            model = getattr(self, family).make()
            print(model.prompt)
            # The dict copies of vox/pose are discarded; the live config
            # objects on self are used to build the real instances.
            cfgs.pop("vox")
            vox = self.vox.make()
            
            cfgs.pop("pose")
            poser = self.pose.make()
            
            # Score distillation
            # Whatever remains in cfgs is passed as keyword arguments; keys
            # fuse_3d does not name explicitly land in its **kwargs.
            yield from fuse_3d(**cfgs, poser=poser,model=model,vox=vox,exp_dir=exp_dir, points=points, is_gradio=True)


def fuse_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, exp_dir, points, is_gradio, **kwargs
):
    """Optimise a voxel model by score distillation, yielding progress updates.

    This is a generator. For every training pose it renders one view from
    ``vox``, renders a depth map from the point cloud ``points``, asks
    ``model`` to denoise a noised copy of the render conditioned on that depth
    map and a view-dependent prompt, and back-propagates the resulting
    gradient (plus the optional depth loss and the emptiness loss) into the
    voxel parameters. After the loop it calls ``evaluate`` to write a video.

    Args:
        poser: pose sampler; ``H``, ``W`` and ``sample_train`` are used.
        vox: voxel model; ``aabb``, ``to`` and ``opt_params`` are used.
        model: score model; ``samps_centered``, ``data_shape``, ``us``,
            ``prompt``, ``prompts_emb``, ``device``, ``denoise`` and
            ``decode`` are used.
        lr: learning rate for the Adamax optimizer.
        n_steps: number of training poses requested and progress-bar total.
        emptiness_scale, emptiness_weight: the emptiness loss is
            ``emptiness_weight * mean(log(1 + emptiness_scale * ws))``.
        emptiness_step, emptiness_multiplier: once the step index reaches
            ``emptiness_step * n_steps`` the emptiness loss is multiplied by
            ``emptiness_multiplier``.
        depth_weight: weight of the centre-vs-border depth loss; the loss is
            skipped when this is not positive.
        var_red: if true the gradient is ``(Ds - y) / sigma``, otherwise
            ``(Ds - zs) / sigma``.
        exp_dir: currently unused (only referenced by disabled metric code).
        points: point cloud handed to ``render_depth_from_cloud``.
        is_gradio: when true, yields 3-tuples intended for a Gradio UI.
        **kwargs: ignored; lets callers pass a whole config dict.

    Yields:
        With ``is_gradio``: a start notice, periodic
        ``(preview image, progress text, gr.update)`` tuples, and a final
        tuple whose last element carries the video path. Otherwise a single
        ``None`` after training and evaluation have finished.
    """
    del kwargs

    if is_gradio:
        yield gr.update(visible=True), "LoRA layers tuning has just finished. \nScore distillation has started.", gr.update(visible=True)
    assert model.samps_centered()
    # NOTE(review): the returned shape is not used below; the call is kept in
    # case data_shape() has effects that are not visible from this file.
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks_, poses_, prompt_prefixes_, angles_list = poser.sample_train(n_steps, device_glb)

    # Candidate noise levels: model.us without its first 30 and last 10 entries.
    ts = model.us[30:-10]

    fuse = EarlyLoopBreak(5)

    raster_settings = PointsRasterizationSettings(
        image_size=800,
        radius=0.02,
        points_per_pixel=10
    )

    calibration_value = 0.0

    with tqdm(total=n_steps) as pbar:
        # HeartBeat(pbar) as hbeat, \
        #     EventStorage(output_dir=os.path.join(exp_dir,'3d')) as metric:

        for i in range(len(poses_)):
            if fuse.on_break():
                break

            depth_map = render_depth_from_cloud(points, angles_list[i], raster_settings, device_glb, calibration_value)

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks_[i], poses_[i], return_w=True)

            p = f"{prompt_prefixes_[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            score_conds['c'] = score_conds['c'].repeat(bs, 1, 1)
            score_conds['uc'] = score_conds['uc'].repeat(bs, 1, 1)

            opt.zero_grad()

            # The score model is only queried, never trained, so its forward
            # pass runs without autograd.
            with torch.no_grad():
                chosen_σs = np.random.choice(ts, bs, replace=False)
                chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)

                zs = y + chosen_σs * noise

                Ds = model.denoise(zs, chosen_σs, depth_map.unsqueeze(dim=0), **score_conds)

                if var_red:
                    grad = (Ds - y) / chosen_σs
                else:
                    grad = (Ds - zs) / chosen_σs

                grad = grad.mean(0, keepdim=True)

            # Inject the precomputed gradient directly into the render. The
            # graph is retained because the losses below back-propagate
            # through outputs of the same render_one_view call.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                center_depth = depth[7:-7, 7:-7]
                # Border pixel count derived from the actual tensors instead
                # of the former 64*64-50*50 constant, so the mean stays
                # correct for render sizes other than 64x64.
                border_pixels = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / border_pixels
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            # metric.put_scalars(**tsr_stats(y))

            if every(pbar, percent=2):
                with torch.no_grad():
                    decoded = model.decode(y)
                    # vis_routine(metric, y, depth,p,depth_map[0])

                    if is_gradio:
                        yield torch_samps_to_imgs(decoded)[0], f"Progress: {pbar.n}/{pbar.total} \nAfter the generation is complete, the video results will be displayed below.", gr.update(value=None)

            # metric.step()
            pbar.update()

            pbar.set_description(p)
            # hbeat.beat()

        # metric.put_artifact(
        #     "ckpt", ".pt","", lambda fn: torch.save(vox.state_dict(), fn)
        # )

        # with EventStorage("result"):
        out = evaluate(model, vox, poser)

        if is_gradio:
            yield gr.update(visible=False), "Generation complete. Please check the video below.", gr.update(value=out)
        else:
            yield None

        # metric.step()

        # hbeat.done()

@torch.no_grad()
def evaluate(score_model, vox, poser):
    """Render the test poses, decode them and write the frames to an .mp4.

    Args:
        score_model: model whose ``decode`` turns a rendered view into samples
            accepted by ``torch_samps_to_imgs``.
        vox: voxel model; switched to eval mode and moved to ``device_glb``.
        poser: pose sampler; ``H``, ``W`` and ``sample_test(100)`` are used.

    Returns:
        Path of a newly created temporary ``.mp4`` file. The file is created
        with ``delete=False``, so removing it is left to the caller.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(100)

    fuse = EarlyLoopBreak(5)
    # metric = get_event_storage()
    # hbeat = get_heartbeat()

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    frames = []
    for i in tqdm(range(len(poses))):
        if fuse.on_break():
            break

        y, _ = render_one_view(vox, aabb, H, W, K, poses[i])
        y = score_model.decode(y)
        # vis_routine(metric, y, depth,"",None)
        frames.append(torch_samps_to_imgs(y)[0])
        # metric.step()
        # hbeat.beat()

    # metric.flush_history()

    # metric.put_artifact(
    #     "video", ".mp4","",
    #     lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "img")[1])
    # )

    # Only the unique name is needed; close the handle right away so the
    # descriptor is not leaked and the writer can reopen the path by name.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name

    # The context manager closes the writer even if appending a frame fails.
    with imageio.get_writer(out_path, fps=10) as writer:
        for img in frames:
            writer.append_data(img)
    # metric.step()
    return out_path

def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render a single H x W view of ``vox`` for intrinsics ``K`` and ``pose``.

    Rays come from ``rays_from_img``, are intersected with the box ``aabb``
    and rendered with ``render_ray_bundle``. The colour output is reshaped to
    ``(1, c, H, W)`` and the depth output to ``(H, W)``.

    Returns:
        ``(rgbs, depth)``, or ``(rgbs, depth, weights)`` when ``return_w`` is
        true; ``weights`` is passed through from ``render_ray_bundle`` as is.

    Raises:
        AssertionError: if the number of rays differs from ``H * W``.
    """
    origins, directions = rays_from_img(H, W, K, pose)
    origins, directions, near, far = scene_box_filter_(origins, directions, aabb)

    expected_rays = H * W
    assert len(origins) == expected_rays, "for now all pixels must be in"

    origins, directions, near, far = as_torch_tsrs(
        vox.device, origins, directions, near, far
    )
    colours, depth_flat, weights = render_ray_bundle(
        vox, origins, directions, near, far
    )

    # Flat per-ray outputs are folded back into image layout.
    colour_img = rearrange(colours, "(h w) c -> 1 c h w", h=H, w=W)
    depth_img = rearrange(depth_flat, "(h w) 1 -> h w", h=H, w=W)

    return (colour_img, depth_img, weights) if return_w else (colour_img, depth_img)


def scene_box_filter_(ro, rd, aabb):
    """Compute per-ray entry/exit distances for the box ``aabb``.

    The first value returned by ``ray_box_intersect`` is discarded, so no ray
    is dropped: ``ro`` and ``rd`` are returned unchanged together with the
    two distances, each clamped to be non-negative.
    """
    intersection = ray_box_intersect(ro, rd, aabb)
    near, far = intersection[1], intersection[2]
    # Negative distances lie behind the ray origin and must not be rendered.
    near = np.maximum(near, 0)
    far = np.maximum(far, 0)
    return ro, rd, near, far


def vis_routine(metric, y, depth, prompt, depth_map):
    """Store visualisation artifacts for one rendered view on ``metric``.

    Writes through ``metric.put_artifact``: a combined pane from ``nerf_vis``
    ("view"), the image from ``torch_samps_to_imgs`` ("img"), the raw depth as
    ``.npy`` ("depth") and, when ``depth_map`` is given, the point-cloud depth
    map ("PC_depth").

    Args:
        metric: object exposing ``put_artifact(name, ext, tag, save_fn)``.
        y: rendered sample passed to ``nerf_vis`` and ``torch_samps_to_imgs``.
        depth: depth tensor; converted with ``.cpu().numpy()`` before saving.
        prompt: tag attached to the "img" and "PC_depth" artifacts.
        depth_map: optional point-cloud depth tensor, or ``None`` to skip it.
    """
    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]

    depth = depth.cpu().numpy()
    metric.put_artifact("view", ".png", "", lambda fn: imwrite(fn, pane))
    metric.put_artifact("img", ".png", prompt, lambda fn: imwrite(fn, im))
    # Identity test: `!= None` would dispatch to the tensor's own comparison
    # operator instead of simply checking for the missing-value sentinel.
    if depth_map is not None:
        metric.put_artifact("PC_depth", ".png", prompt, lambda fn: imwrite(fn, depth_map.cpu().squeeze()))
    metric.put_artifact("depth", ".npy", "", lambda fn: np.save(fn, depth))


# Script entry point: hand the config class to the project's dispatch helper.
# NOTE(review): SJC_3DFuse.run() always raises in this demo build, so a CLI
# invocation that reaches run() will fail — confirm how dispatch uses it.
if __name__ == "__main__":
    dispatch(SJC_3DFuse)