import subprocess
from PIL import Image,ImageOps,ImageDraw,ImageFilter
import json
import os
import time
import io
from mp_utils import get_pixel_cordinate_list,extract_landmark,get_pixel_cordinate,get_normalized_xyz
from glibvision.draw_utils import points_to_box,box_to_xy,plus_point,calculate_distance

import numpy as np
from glibvision.pil_utils import fill_points,create_color_image,draw_box

import glibvision.pil_utils

from gradio_utils import save_image,save_buffer,clear_old_files ,read_file


import math
import mp_triangles


from glibvision.cv2_utils import pil_to_bgr_image
from glibvision.cv2_utils import create_color_image as cv2_create_color_image
import cv2
#TODO move to CV2

# Performance is unverified: each call warps the full source image and builds full-size masks for a single triangle.
def apply_affine_transformation_to_triangle_add(src_tri, dst_tri, src_img, dst_img):
    """Warp the ``src_tri`` region of ``src_img`` onto ``dst_tri`` and composite it over ``dst_img``.

    Mask-multiplication variant: pixels inside ``dst_tri`` are taken from the
    affinely warped source, all other pixels from ``dst_img``. A new array is
    returned; ``dst_img`` is not modified.

    NOTE(review): a second function with this exact name is defined immediately
    after this one in the same module, so this version is shadowed and never called.
    """
    src_pts = np.float32(src_tri)
    dst_pts = np.float32(dst_tri)
    out_h, out_w = dst_img.shape[:2]

    affine = cv2.getAffineTransform(src_pts, dst_pts)

    # 255 inside the destination triangle, 0 elsewhere.
    triangle_mask = np.zeros((out_h, out_w), dtype=np.uint8)
    cv2.fillPoly(triangle_mask, [np.int32(dst_tri)], 255)
    mask_3d = triangle_mask[:, :, np.newaxis] / 255

    warped = cv2.warpAffine(src_img, affine, (out_w, out_h))

    # 0/1 uint8 selectors: warped pixels inside the triangle, destination pixels outside.
    inside = mask_3d.astype(np.uint8)
    outside = (1 - mask_3d).astype(np.uint8)
    return warped * inside + dst_img * outside

def apply_affine_transformation_to_triangle_add(src_tri, dst_tri, src_img, dst_img):
    """Warp one triangle of ``src_img`` onto ``dst_img`` and return the composite.

    The affine map taking ``src_tri`` to ``dst_tri`` is applied to the whole
    source image. Only the warped pixels inside ``dst_tri`` are kept, and they
    replace the corresponding pixels of ``dst_img``. ``dst_img`` itself is not
    modified; a new image is returned.

    Args:
        src_tri: three (x, y) vertices in ``src_img`` pixel coordinates.
        dst_tri: the three matching (x, y) vertices in ``dst_img`` pixel coordinates.
        src_img: source image array, as accepted by ``cv2.warpAffine``.
        dst_img: destination image array; its height/width define the output size.

    Returns:
        A new image with the same size as ``dst_img``.

    Raises:
        ValueError: if ``src_tri`` or ``dst_tri`` does not have shape (3, 2).
    """
    src_tri_np = np.float32(src_tri)
    dst_tri_np = np.float32(dst_tri)

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if src_tri_np.shape != (3, 2):
        raise ValueError(f"src_tri_np の形状が不正 {src_tri_np.shape}")
    if dst_tri_np.shape != (3, 2):
        raise ValueError(f"dst_tri_np の形状が不正 {dst_tri_np.shape}")

    # Affine (not perspective) transform mapping the source triangle onto the destination one.
    M = cv2.getAffineTransform(src_tri_np, dst_tri_np)

    h_dst, w_dst = dst_img.shape[:2]

    # Warp the entire source image into destination space; the mask below keeps only the triangle.
    transformed = cv2.warpAffine(src_img, M, (w_dst, h_dst))

    # 255 inside the destination triangle (vertices truncated to integer pixels), 0 elsewhere.
    dst_mask = np.zeros((h_dst, w_dst), dtype=np.uint8)
    cv2.fillPoly(dst_mask, [np.int32(dst_tri)], 255)
    transformed = cv2.bitwise_and(transformed, transformed, mask=dst_mask)

    # Clear the triangle area of the destination, then add the warped triangle into the hole.
    dst_background = cv2.bitwise_and(dst_img, dst_img, mask=cv2.bitwise_not(dst_mask))
    return cv2.add(dst_background, transformed)

# TODO move PIL
def process_create_webp(images,duration=100, loop=0,quality=85):
    """Encode a sequence of PIL images as an animated WebP and return its bytes.

    Args:
        images: iterable of PIL images. The first one is saved as the base frame
            and the remaining ones are appended in order.
        duration: per-frame display time, forwarded to Pillow (milliseconds per
            Pillow's WebP writer).
        loop: loop count, forwarded to Pillow.
        quality: WebP quality, forwarded to Pillow.

    Returns:
        The encoded WebP file contents as ``bytes``.

    Raises:
        ValueError: if ``images`` yields no frames.
    """
    frames = list(images)
    if not frames:
        raise ValueError("process_create_webp needs at least one frame")

    with io.BytesIO() as output_buffer:
        frames[0].save(output_buffer,
                       save_all=True,
                       append_images=frames[1:],
                       duration=duration,
                       loop=loop,
                       format='WebP',
                       quality=quality
                       )
        return output_buffer.getvalue()
# TODO move numpy
def rotate_point_euler(point, angles,order="xyz"):
    """Rotate a 3D point by Euler angles.

    Args:
        point: the (x, y, z) point to rotate.
        angles: rotation angles (rx, ry, rz) about the X, Y and Z axes, in radians.
        order: case-insensitive left-to-right order in which the per-axis matrices
            are multiplied. For example, "xyz" gives Rx @ Ry @ Rz, so the Z rotation
            acts on the point first. Recognised values are "xyz", "xzy", "yxz",
            "yzx" and "zxy"; any other value falls back to "zyx" (Rz @ Ry @ Rx).

    Returns:
        A numpy array holding the rotated point (x', y', z').
    """
    rx, ry, rz = angles
    vector = np.array(point)

    cos_x, sin_x = np.cos(rx), np.sin(rx)
    cos_y, sin_y = np.cos(ry), np.sin(ry)
    cos_z, sin_z = np.cos(rz), np.sin(rz)

    axis_matrices = {
        "x": np.array([[1, 0, 0],
                       [0, cos_x, -sin_x],
                       [0, sin_x, cos_x]]),
        "y": np.array([[cos_y, 0, sin_y],
                       [0, 1, 0],
                       [-sin_y, 0, cos_y]]),
        "z": np.array([[cos_z, -sin_z, 0],
                       [sin_z, cos_z, 0],
                       [0, 0, 1]]),
    }

    sequence = order.lower()
    if sequence not in ("xyz", "xzy", "yxz", "yzx", "zxy"):
        sequence = "zyx"  # fallback composition for "zyx" and any unrecognised order

    first, second, third = (axis_matrices[axis] for axis in sequence)
    rotation = first @ second @ third
    return rotation @ vector


def process_face_mesh_rotation(image,draw_type,animation,center_scaleup,animation_direction,rotation_order,euler_x,euler_y,euler_z):
    """Render a face mesh (or a placeholder square) rotated in 3D and save the result.

    Face landmarks are extracted from ``image`` with ``extract_landmark``,
    re-centred on landmark 4, rotated with ``rotate_point_euler`` and drawn
    according to ``draw_type``.

    Args:
        image: PIL image containing a face, or None to render a 4-point
            placeholder square on a 512x512 canvas.
        draw_type: "Dot", "Line", "Line+Fill" or "Image" (texture-mapped).
            Forced to "Dot" when ``image`` is None.
        animation: if truthy, render a sweep of angles and save an animated
            WebP; otherwise render a single frame from the Euler angles below.
        center_scaleup: if truthy and an image is given, draw the face at the
            image centre, scaled by 0.45 / max(distance from landmark 4 to
            landmark 10, distance to landmark 152). Ignored when ``image`` is None.
        animation_direction: "X" or "Y" selects the sweep axis; any other
            value sweeps around Z.
        rotation_order: axis order for ``rotate_point_euler``. Used only for
            single frames; the animation sweep always uses "xyz".
        euler_x, euler_y, euler_z: rotation angles in degrees, as anything
            ``float()`` accepts. Used only for single frames.

    Returns:
        The value returned by ``save_buffer`` (animation) or ``save_image``
        (single frame), presumably the saved file's path.

    Raises:
        ValueError: if an image is given and ``draw_type`` is not supported.
    """
    # Only validated when an image is given: without one, draw_type is forced to "Dot" below.
    if image is not None and draw_type not in ("Dot", "Line", "Line+Fill", "Image"):
        raise ValueError(f"Unsupported draw_type: {draw_type!r}")

    offset_x = 0
    offset_y = 0
    scale_up = 1.0

    if image is None:
        # Placeholder square in normalized coordinates, centred on the origin.
        image_width = 512
        image_height = 512
        points = [(-0.25,-0.25,0),(0.25,-0.25,0),
                  (0.25,0.25,0),(-0.25,0.25,0)]
        normalized_center_point = [0.5,0.5]
    else:
        image_width = image.width
        image_height = image.height
        _, face_landmarker_result = extract_landmark(image)
        # First 468 landmarks, normalized (presumably the vertices mp_triangles indexes into).
        landmark_points = [get_normalized_xyz(face_landmarker_result.face_landmarks,i) for i in range(0,468)]
        # Landmark 4 is the rotation pivot; 10 and 152 are only used for center_scaleup.
        normalized_center_point = landmark_points[4]
        normalized_top_point = landmark_points[10]
        normalized_bottom_point = landmark_points[152]

        offset_x = normalized_center_point[0]
        offset_y = normalized_center_point[1]
        # Re-centre x/y on the pivot so rotations turn around it; z is unchanged.
        points = [[point[0]-offset_x,point[1]-offset_y,point[2]] for point in landmark_points]

    def split_points_xy_z(points,width,height,center_x,center_y):
        # Map centred normalized x/y to pixels (scale_up is read at call time,
        # after it has been finalised below) and keep z separately for depth sorting.
        xys = [[point[0]*width*scale_up+center_x, point[1]*height*scale_up+center_y] for point in points]
        zs = [point[2] for point in points]
        return xys,zs

    def depth_sorted_triangles(depths):
        # Triangles in descending mean-z order, drawn first to last (presumably
        # back-to-front: larger z is farther in MediaPipe's convention; confirm).
        # sorted() returns a new list, so the shared mp_triangles table is never reordered.
        return sorted(mp_triangles.mesh_triangle_indices,
                      key=lambda triangle: sum(depths[index] for index in triangle) / len(triangle),
                      reverse=True)

    def create_triangle_image(points,width,height,center_x,center_y,line_color=(255,255,255),fill_color=None):
        # Draw every mesh triangle as an outline (and optional fill) on a black canvas.
        cordinates,angled_depth = split_points_xy_z(points,width,height,center_x,center_y)
        img = create_color_image(width,height,(0,0,0))
        draw = ImageDraw.Draw(img)
        for triangle in depth_sorted_triangles(angled_depth):
            triangle_cordinates = [cordinates[index] for index in triangle]
            glibvision.pil_utils.image_draw_points(draw,triangle_cordinates,line_color,fill_color)
        return img

    def create_texture_image(image,origin_points,rotated_points,width,height,center_x,center_y):
        # Texture-map each triangle from its original position in `image` to its rotated position.
        cv2_image = pil_to_bgr_image(image)
        cordinates,angled_depth = split_points_xy_z(rotated_points,width,height,center_x,center_y)
        # Source-side vertices: undo the centring offset and scale to source pixels.
        original_cordinates = [[(point[0]+offset_x)*width,(point[1]+offset_y)*height] for point in origin_points]
        cv2_bg_img = cv2_create_color_image(cv2_image,(0,0,0))
        for triangle in depth_sorted_triangles(angled_depth):
            triangle_cordinates = [cordinates[index] for index in triangle]
            origin_triangle_cordinates = [original_cordinates[index] for index in triangle]
            cv2_bg_img = apply_affine_transformation_to_triangle_add(origin_triangle_cordinates,triangle_cordinates,cv2_image,cv2_bg_img)
        # Swap the channel order back for PIL (RGB2BGR and BGR2RGB perform the same swap).
        return Image.fromarray(cv2.cvtColor(cv2_bg_img, cv2.COLOR_RGB2BGR))

    def create_point_image(points,width,height,center_x,center_y):
        # Draw each point as a red dot on a black canvas.
        cordinates,_ = split_points_xy_z(points,width,height,center_x,center_y)
        img = create_color_image(width,height,(0,0,0))
        glibvision.pil_utils.draw_points(img,cordinates,None,None,3,(255,0,0),3)
        return img

    def angled_points(points,angles,order="xyz"):
        # Rotate every point; each result is a 3-element list.
        return [list(rotate_point_euler(point,angles,order)) for point in points]

    # Sweep angles (degrees) for the animation. Chosen from the requested draw_type,
    # i.e. before the no-image fallback to "Dot" below.
    if draw_type == "Image":
        angle_range = range(-90, 90, 30)
    else:
        angle_range = range(0, 360, 10)

    if image is None:
        # The placeholder square has no mesh triangles, so only dots can be drawn.
        draw_type = "Dot"

    if center_scaleup and image is not None:
        top_distance = calculate_distance(normalized_center_point,normalized_top_point)
        bottom_distance = calculate_distance(normalized_center_point,normalized_bottom_point)
        distance = max(top_distance, bottom_distance)
        # 0.45 = half the frame minus a margin.
        scale_up = 0.45 / distance
        face_center_x = int(0.5* image_width)
        face_center_y = int(0.5* image_height)
    else:
        # Native scale, drawn at the pivot's original position. This path is also taken
        # without an image, since the placeholder has no top/bottom landmarks to measure.
        scale_up = 1.0
        face_center_x = int(normalized_center_point[0]* image_width)
        face_center_y = int(normalized_center_point[1]* image_height)

    def render(rotated_points):
        # Draw one frame in the selected style.
        if draw_type == "Dot":
            return create_point_image(rotated_points,image_width,image_height,face_center_x,face_center_y)
        if draw_type == "Line":
            return create_triangle_image(rotated_points,image_width,image_height,face_center_x,face_center_y)
        if draw_type == "Line+Fill":
            return create_triangle_image(rotated_points,image_width,image_height,face_center_x,face_center_y,(128,128,128),(200,200,200))
        return create_texture_image(image,points,rotated_points,image_width,image_height,face_center_x,face_center_y)

    if animation:
        # Texture frames are fewer, so each is shown longer.
        frame_duration = 500 if draw_type == "Image" else 100
        frames = []
        for angle in angle_range:
            radians = math.radians(angle)
            if animation_direction == "X":
                angles = [radians,0,0]
            elif animation_direction == "Y":
                angles = [0,radians,0]
            else:
                angles = [0,0,radians]
            frames.append(render(angled_points(points,angles)))
        path = save_buffer(process_create_webp(frames,frame_duration))
    else:
        angles = [math.radians(float(euler_x)),math.radians(float(euler_y)),math.radians(float(euler_z))]
        path = save_image(render(angled_points(points,angles,rotation_order)))

    return path