import streamlit as st
import numpy as np
import os
import pickle
import spotipy
import spotipy.util as sp_util

# Directory containing this module, with symlinks resolved by realpath();
# used below to locate the local "ids.pk" credentials file.
dir_path = os.path.dirname(os.path.realpath(__file__))

# current mess: https://github.com/plamere/spotipy/issues/632
def centered_button(func, text, n_columns=7, disabled=False, args=None):
    """Draw a widget inside the middle column of an evenly split row.

    A row of ``n_columns`` equal-weight columns is created and ``func`` is
    invoked inside the column at index ``n_columns // 2``.

    Args:
        func: Callable that renders the widget; receives ``text`` as its
            first positional argument.
        text: Label handed to ``func``.
        n_columns: Number of equal-weight columns in the row.
        disabled: Forwarded as ``disabled=`` only when the string form of
            ``func`` contains ``'button'``; ignored otherwise.
        args: Accepted but never used by this function.

    Returns:
        Whatever ``func`` returns.
    """
    middle = n_columns // 2
    row = st.columns(np.ones(n_columns))
    with row[middle]:
        # Only button-like callables are given the ``disabled`` keyword.
        if 'button' not in str(func):
            return func(text)
        return func(text, disabled=disabled)

# get credentials
def setup_credentials():
    """Load the Spotify client credentials and export them for spotipy.

    The credentials are taken from the ``client_id`` / ``client_secret``
    environment variables when both are set; otherwise they are unpickled
    from ``ids.pk`` in the module directory. They are then written to the
    ``SPOTIPY_CLIENT_ID`` / ``SPOTIPY_CLIENT_SECRET`` environment variables,
    and ``SPOTIPY_REDIRECT_URI`` is set to the fixed redirect URL.

    Returns:
        The mapping holding ``client_id`` and ``client_secret``.
    """
    env = os.environ
    if not ('client_id' in env and 'client_secret' in env):
        # Fallback: local pickle file. NOTE(review): unpickling executes
        # arbitrary code, so ids.pk must remain a trusted, local-only file.
        with open(f"{dir_path}/ids.pk", 'rb') as handle:
            client_info = pickle.load(handle)
    else:
        client_info = {
            'client_id': env['client_id'],
            'client_secret': env['client_secret'],
        }

    env['SPOTIPY_CLIENT_ID'] = client_info['client_id']
    env['SPOTIPY_CLIENT_SECRET'] = client_info['client_secret']
    env['SPOTIPY_REDIRECT_URI'] = 'https://huggingface.co/spaces/ccolas/EmotionPlaylist/'
    return client_info

# Names of the audio features this project treats as relevant.
# NOTE(review): presumably keys of Spotify's audio-features objects -- confirm against callers.
relevant_audio_features = ["danceability", "energy", "loudness", "mode", "valence", "tempo"]


def get_client():
    """Build a Spotify client from a prompted user token.

    A token with the ``playlist-modify-public`` scope is requested through
    ``spotipy.util.prompt_for_user_token`` and used to create the client.

    Returns:
        A ``(client, user_id)`` tuple, where ``user_id`` is the ``'id'``
        field of the client's ``me()`` response.
    """
    access_token = sp_util.prompt_for_user_token(scope="playlist-modify-public")
    client = spotipy.Spotify(auth=access_token)
    return client, client.me()['id']

def add_button(url, text):
    """Write a centered HTML button that links to ``url``.

    Args:
        url: Target of the surrounding ``<a>`` element.
        text: Caption shown on the button.

    Both values are interpolated into the markup as-is (no escaping) and the
    result is rendered with ``unsafe_allow_html=True``.
    """
    markup = f'''
        <center>
        <a style='color:black;' href="{url}">
            <button class='css-1cpxqw2'>
                {text}
            </button>
        </a></center>
        '''
    st.write(markup, unsafe_allow_html=True)

def new_get_client(session):
    """Run the browser-based Spotify OAuth flow for the current session.

    Args:
        session: Mapping-like store handed to ``StreamlitCacheHandler`` to
            hold the token info.

    Returns:
        A ``(sp, user_id, auth_manager)`` tuple. ``sp`` and ``user_id`` are
        only populated when this call exchanges an authorization code found
        in the page's query parameters; otherwise both are ``None``.

    NOTE(review): when the cached token is already valid, ``sp`` and
    ``user_id`` are also returned as ``None`` -- confirm callers expect that.
    """
    cache_handler = StreamlitCacheHandler(session)
    auth_manager = spotipy.oauth2.SpotifyOAuth(
        scope="playlist-modify-public",
        cache_handler=cache_handler,
        show_dialog=True,
    )

    cached_token = cache_handler.get_cached_token()
    if auth_manager.validate_token(cached_token):
        # A usable token is already cached: nothing to do in this call.
        return None, None, auth_manager

    auth_url = auth_manager.get_authorize_url()
    query_params = st.experimental_get_query_params()

    if 'code' not in query_params:
        # Step 1: no authorization code yet, so show the sign-in link.
        add_button(auth_url, 'Log in')
        return None, None, auth_manager

    # Step 2: redirected back from the Spotify auth page with a code.
    auth_manager.get_access_token(query_params['code'])
    sp = spotipy.Spotify(auth_manager=auth_manager)
    user_id = sp.me()['id']
    return sp, user_id, auth_manager


def extract_uris_from_links(links, url_type):
    """Extract Spotify identifiers from newline-separated links.

    For each non-blank line, the text following the last ``'<url_type>/'``
    is kept and any query string (``?...``) is dropped. A line that does not
    contain ``'<url_type>/'`` (e.g. a bare ID) is returned as-is, minus its
    query string.

    Args:
        links: One link (or bare ID) per line.
        url_type: One of ``'playlist'``, ``'artist'`` or ``'user'``.

    Returns:
        List of extracted identifiers, in input order. Blank lines are
        skipped and surrounding whitespace (including ``\\r``) is removed.

    Raises:
        AssertionError: If ``url_type`` is not one of the supported kinds.
    """
    assert url_type in ['playlist', 'artist', 'user']
    marker = f'{url_type}/'
    uris = []
    for line in links.split('\n'):
        url = line.strip()
        if not url:
            # Empty lines (e.g. a trailing newline) carry no identifier.
            continue
        # Splitting on an absent marker returns the whole string, so the same
        # expression handles full links of any supported type and bare IDs.
        uris.append(url.split(marker)[-1].split('?')[0])
    return uris

def wall_of_checkboxes(labels, max_width=10):
    """Create one empty placeholder per label, laid out in rows of columns.

    Each row is split into ``max_width`` equal-weight columns and filled from
    the left; the last row holds whatever labels remain.

    Args:
        labels: Sized collection of labels; only its length is used here.
        max_width: Maximum number of placeholders per row.

    Returns:
        List of ``st.empty()`` placeholders, one per label, in row-major
        order.
    """
    n_labels = len(labels)
    checkboxes = []
    for row_start in range(0, n_labels, max_width):
        columns = st.columns(np.ones(max_width))
        # Remaining labels, capped at a full row. Using ``n_labels % max_width``
        # here would give 0 for an exactly full last row and drop it.
        row_length = min(max_width, n_labels - row_start)
        for j in range(row_length):
            with columns[j]:
                checkboxes.append(st.empty())
    return checkboxes

def find_legit_genre(glabel, legit_genres, verbose=False):
    """Map a free-form genre label onto the closest entry of ``legit_genres``.

    Labels are compared with spaces and hyphens removed. Resolution order:

    1. If ``legit_genres`` is empty, return ``"unknown"``.
    2. Any label containing ``jazz`` maps to ``'jazz'`` and any label
       containing ``ukpop`` maps to ``'pop'``, whether or not those values
       appear in ``legit_genres``.
    3. An exact match against a legit genre wins immediately.
    4. Otherwise the legit genre sharing the longest substring relation with
       the label (either contains it or is contained in it) wins; ties go to
       the earliest entry.

    Args:
        glabel: Genre label to classify.
        legit_genres: Iterable of accepted genre names.
        verbose: When true, print the label and every candidate considered.

    Returns:
        The matching entry of ``legit_genres`` (or ``'jazz'`` / ``'pop'`` for
        the overrides), or ``"unknown"`` when nothing matches.
    """
    def _squash(label):
        # Spaces and hyphens are ignored when comparing labels.
        return label.replace('-', '').replace(' ', '')

    legit_genres = list(legit_genres)
    glabel_formatted = _squash(glabel)
    if verbose:
        print('\n', glabel)

    if not legit_genres:
        # Nothing to match against; the overrides below do not apply either.
        return "unknown"

    # Hard-coded overrides: they do not depend on the candidate, so they are
    # checked once up front instead of on every loop iteration.
    if 'jazz' in glabel_formatted:
        if verbose:
            print('\t', 'jazz')
        return 'jazz'
    if 'ukpop' in glabel_formatted:
        if verbose:
            print('\t', 'pop')
        return 'pop'

    best_match = None
    best_match_score = 0
    for legit_glabel in legit_genres:
        legit_glabel_formatted = _squash(legit_glabel)
        if legit_glabel_formatted == glabel_formatted:
            if verbose:
                print('\t', legit_glabel_formatted)
            return legit_glabel

        # Score a partial match by the length of the contained string.
        if glabel_formatted in legit_glabel_formatted:
            candidate_score = len(glabel_formatted)
        elif legit_glabel_formatted in glabel_formatted:
            candidate_score = len(legit_glabel_formatted)
        else:
            continue

        if verbose:
            print('\t', legit_glabel_formatted)
        # Strict comparison keeps the earliest candidate on ties.
        if candidate_score > best_match_score:
            best_match = legit_glabel
            best_match_score = candidate_score

    return "unknown" if best_match is None else best_match


# def aggregate_genres(genres, legit_genres, verbose=False):
#     genres_output = dict()
#     legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
#     for glabel in genres.keys():
#         if verbose: print('\n', glabel)
#         glabel_formatted = glabel.replace(' ', '').replace('-', '')
#         best_match = None
#         best_match_score = 0
#         for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
#             if 'jazz' in glabel_formatted:
#                 best_match = 'jazz'
#                 if verbose: print('\t', 'pop')
#                 break
#             if 'ukpop' in glabel_formatted:
#                 best_match = 'pop'
#                 if verbose: print('\t', 'pop')
#                 break
#             if legit_glabel_formatted == glabel_formatted:
#                 if verbose: print('\t', legit_glabel_formatted)
#                 best_match = legit_glabel
#                 break
#             elif glabel_formatted in legit_glabel_formatted:
#                 if verbose: print('\t', legit_glabel_formatted)
#                 if len(glabel_formatted) > best_match_score:
#                     best_match = legit_glabel
#                     best_match_score = len(glabel_formatted)
#             elif legit_glabel_formatted in glabel_formatted:
#                 if verbose: print('\t', legit_glabel_formatted)
#                 if len(legit_glabel_formatted) > best_match_score:
#                     best_match = legit_glabel
#                     best_match_score = len(legit_glabel_formatted)
#
#         if best_match is not None:
#             if verbose: print('\t', '-->', best_match)
#             if best_match in genres_output.keys():
#                 genres_output[best_match] += genres[glabel]
#             else:
#                 genres_output[best_match] = genres[glabel]
#         else:
#             if "unknown" in genres_output.keys():
#                 genres_output["unknown"] += genres[glabel]
#             else:
#                 genres_output["unknown"] = genres[glabel]
#     for k in genres_output.keys():
#         genres_output[k] = sorted(set(genres_output[k]))
#     return genres_output

def get_all_playlists_uris_from_users(sp, user_ids):
    """Collect the playlists of several users, without duplicate URIs.

    Playlists are fetched page by page through ``sp.user_playlists`` until a
    page comes back shorter than the requested page size.

    Args:
        sp: Client exposing ``user_playlists(user_id, offset=..., limit=...)``
            that returns a mapping with an ``'items'`` list of playlists,
            each having ``'name'`` and ``'uri'`` keys.
        user_ids: Iterable of user IDs to query.

    Returns:
        ``(all_uris, all_names)``: two parallel lists in first-seen order.
        Names are formatted as ``'<user_id>/<playlist name>'``.
    """
    page_size = 50
    all_uris = []
    all_names = []
    seen_uris = set()  # O(1) duplicate check; the lists keep the order
    for user_id in user_ids:
        print(user_id)
        offset = 0
        while True:
            playlist_list = sp.user_playlists(user_id, offset=offset, limit=page_size)
            items = playlist_list['items']
            for playlist in items:
                name, uri = playlist['name'], playlist['uri']
                if uri not in seen_uris:
                    seen_uris.add(uri)
                    all_uris.append(uri)
                    all_names.append(user_id + '/' + name)
            # A short page is the last one. The comparison must be against the
            # page size: comparing against the growing offset stops too early
            # once offset exceeds the page size.
            if len(items) < page_size:
                break
            offset += page_size
    return all_uris, all_names




class StreamlitCacheHandler(spotipy.cache_handler.CacheHandler):
    """
    Cache handler that keeps the spotipy token info in a session mapping
    (the session framework provided by streamlit) under the
    ``"token_info"`` key.
    """

    def __init__(self, session):
        # Mapping-like object the token info is read from and written to.
        self.session = session

    def get_cached_token(self):
        """Return the token info stored in the session, or ``None`` if absent."""
        try:
            return self.session["token_info"]
        except KeyError:
            print("Token not found in the session")
            return None

    def save_token_to_cache(self, token_info):
        """Store ``token_info`` in the session; failures are printed, not raised."""
        try:
            self.session["token_info"] = token_info
        except Exception as err:
            print(f"Error saving token to cache: {err}")