In [3]:
import pickle
import pandas as pd
import numpy as np
import psutil
from tslearn import metrics
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def pickle_read(filename):
    """Load and return the single object pickled in ``filename``.

    Warning: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only call this on files you trust.
    """
    with open(filename, mode="rb") as handle:
        return pickle.load(handle)

In [4]:
def convert_points(pts):
    """Translate keypoint ids into row indices of a pose column.

    Ids 0-14 are already row indices and pass through unchanged. The foot
    keypoints 19, 21, 22 and 24 (keys of ``joints``) are stored in rows
    15, 16, 17 and 18 respectively, so they are remapped. Any other value is
    returned as-is.

    Args:
        pts: iterable of keypoint ids.

    Returns:
        A new list with the remapped ids, in input order.
    """
    # Checked in order with ==, first match wins.
    remapped_pairs = ((19, 15), (21, 16), (22, 17), (24, 18))

    def _remap(point):
        for source_id, row_index in remapped_pairs:
            if point == source_id:
                return row_index
        return point

    return [_remap(point) for point in pts]

In [10]:
# Keypoint id -> joint name. Ids 0-14 are used directly as row indices; the foot
# ids 19, 21, 22, 24 are remapped onto rows 15-18 by convert_points().
# Insertion order matters: plotting() uses list(joints.values()) as per-row hover text.
joints = {
    0: 'Nose',
    1: 'Neck',
    2: 'RShoulder',
    3: 'RElbow',
    4: 'RWrist',
    5: 'LShoulder',
    6: 'LElbow',
    7: 'LWrist',
    8: 'MidHip',
    9: 'RHip',
    10: 'RKnee',
    11: 'RAnkle',
    12: 'LHip',
    13: 'LKnee',
    14: 'LAnkle',
    19: 'LBigToe',
    21: 'LHeel',
    22: 'RBigToe',
    24: 'RHeel',
}

# Keypoint id -> ids it is connected to by a bone. Each list also contains the
# key itself (e.g. 4: [4]); plotting() skips those p1 == p2 entries, so they
# draw nothing.
links = {
    0: [0, 1],          # Nose - Neck
    1: [1, 2, 5, 8],    # Neck - RShoulder, LShoulder, MidHip
    2: [2, 3],          # RShoulder - RElbow
    3: [3, 4],          # RElbow - RWrist
    4: [4],
    5: [5, 6],          # LShoulder - LElbow
    6: [6, 7],          # LElbow - LWrist
    7: [7],
    8: [8, 9, 12],      # MidHip - RHip, LHip
    9: [9, 10],         # RHip - RKnee
    10: [10, 11],       # RKnee - RAnkle
    11: [11, 22, 24],   # RAnkle - RBigToe, RHeel
    12: [12, 13],       # LHip - LKnee
    13: [13, 14],       # LKnee - LAnkle
    14: [14, 19, 21],   # LAnkle - LBigToe, LHeel
    19: [19],
    21: [21],
    22: [22],
    24: [24],
}

def _joint_xyz(seq, frame):
    """Return the x, y and z coordinate lists of every joint at time step ``frame``.

    Rows ``0..n-1`` of ``seq`` hold x, rows ``n..2n-1`` hold y and rows
    ``2n..3n-1`` hold z, where ``n = len(joints)`` (19). Row order follows the
    order of ``joints``.
    """
    n_joints = len(joints)
    column = seq[:, frame]
    xs = list(column[0:n_joints])
    ys = list(column[n_joints:2 * n_joints])
    zs = list(column[2 * n_joints:3 * n_joints])
    return xs, ys, zs


def _bone_segments(xs, ys, zs):
    """Build Plotly line coordinates for every bone listed in ``links``.

    Each bone contributes its two endpoints followed by ``None``, which makes
    Plotly break the polyline between bones. Self-entries (p1 == p2) are skipped.
    Keypoint ids are mapped to row indices with ``convert_points``.
    """
    seg_x, seg_y, seg_z = [], [], []
    for p1, targets in links.items():
        for p2 in targets:
            if p1 == p2:
                continue
            for row in convert_points([p1, p2]):
                seg_x.append(xs[row])
                seg_y.append(ys[row])
                seg_z.append(zs[row])
            seg_x.append(None)
            seg_y.append(None)
            seg_z.append(None)
    return seg_x, seg_y, seg_z


def _frame_traces(seqs, colors, frame):
    """Return [markers, bones] traces for each sequence at time step ``frame``, in order."""
    traces = []
    for seq, color in zip(seqs, colors):
        xs, ys, zs = _joint_xyz(seq, frame)
        seg_x, seg_y, seg_z = _bone_segments(xs, ys, zs)
        traces.append(go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', name='Joints',
                                   hoverinfo='text', hovertext=list(joints.values()),
                                   marker=dict(color=color, size=5)))
        traces.append(go.Scatter3d(x=seg_x, y=seg_y, z=seg_z, mode='lines',
                                   name='Links', hoverinfo='text', line=dict(color=color, width=5)))
    return traces


def plotting(seqs, colors=("red", "green"), html_path=None):
    """Show an animated 3-D Plotly figure overlaying one or more skeleton sequences.

    Each sequence is drawn as joint markers plus bone segments in its own colour.
    One animation frame is built per time step, and Play/Pause buttons are added.

    Args:
        seqs: non-empty sequence of 2-D numpy arrays indexed as ``seq[row, frame]``.
            Rows 0-18 are read as x, 19-37 as y and 38-56 as z, one row per joint
            in ``joints`` order. This assumes shape (57, n_frames); TODO confirm
            against the producer of the results pickle.
        colors: one Plotly colour per sequence. The default draws the first
            sequence in red and the second in green.
        html_path: if given, the figure is also written to this HTML file.

    Raises:
        ValueError: if ``seqs`` is empty, has more entries than ``colors``,
            or contains no frames.
    """
    if len(seqs) == 0:
        raise ValueError("plotting() needs at least one sequence")
    if len(seqs) > len(colors):
        raise ValueError(f"got {len(seqs)} sequences but only {len(colors)} colors")

    # Animate only the time steps present in every sequence; indexing past the
    # end of a shorter sequence would raise IndexError.
    n_frames = min(seq.shape[1] for seq in seqs)
    if n_frames == 0:
        raise ValueError("sequences contain no frames")

    frames = [go.Frame(data=_frame_traces(seqs, colors, n)) for n in range(n_frames)]
    fig = go.Figure(data=_frame_traces(seqs, colors, 0), frames=frames)

    fig.update_layout(scene = dict(xaxis = dict(range=[-0.33,0.33],
                                                backgroundcolor="rgb(200, 200, 230)",
                                                showgrid=False,
                                                zeroline=False,
                                                showticklabels=False,
                                                showbackground=True,
                                                title='',
                                                visible=False), 
                                   yaxis = dict(range=[-0.33,0.33],
                                                backgroundcolor="rgb(230, 200,230)",
                                                showgrid=False,
                                                zeroline=False,
                                                showticklabels=False,
                                                showbackground=False,
                                                title='',
                                                visible=True), 
                                   zaxis = dict(range=[-0.33,0.33],
                                                backgroundcolor="rgb(230, 230,200)",
                                                showgrid=False,
                                                zeroline=False,
                                                showticklabels=False,
                                                showbackground=True,
                                                title='',
                                                visible=False),
                                   aspectmode='cube'),
                      updatemenus = [dict(type="buttons", 
                                          buttons=[dict(label="Play", method="animate",
                                                        args=[None, {"frame": {"duration": 33, "redraw": True},
                                                                     "transition": {"duration": 0}}]),
                                                  dict(label="Pause", method="animate",
                                                        args=[[None], {"frame": {"duration": 0, "redraw": False},
                                                                      "mode": "immediate",
                                                                      "transition": {"duration": 0}}])])],
                     showlegend=False,
                     height=500, width=500)

    fig.show()
    if html_path is not None:
        fig.write_html(html_path)

In [11]:
# Unpickling can execute code stored in the file; only load result files you trust.
test = pickle_read(filename='../PoseCorrection/Results/results.pickle')

In [13]:
# Overlay the 'poses_original' (first, red) and 'poses_corrected' (second, green)
# sequences stored under test['SQUAT_C'].
seqs = [test['SQUAT_C'][key].numpy() for key in ('poses_original', 'poses_corrected')]
plotting(seqs)