from generation_utils import *
from utils import WriteTextMidiToFile, get_miditok
from load import LoadModel
from decoder import TextDecoder
from playback import get_music


class GenerateMidiText:
    """Generate MidiText music track by track and bar by bar.

    The piece is kept in ``self.piece_by_track``: a list holding one dict per
    track with the keys ``label``, ``instrument``, ``density``,
    ``temperature`` and ``bars`` (a list of text chunks).

    LOGIC:

    FOR GENERATING FROM SCRATCH:
    - self.generate_one_new_track()
    it calls
        - self.generate_until_track_end()

    FOR GENERATING NEW BARS:
    - self.generate_one_more_bar()
    it calls
        - self.process_prompt_for_next_bar()
        - self.generate_until_track_end()
    """

    def __init__(self, model, tokenizer, piece_by_track=None):
        """Store the model/tokenizer and set the default generation parameters.

        Args:
            model: model exposing ``config.n_positions`` and ``generate()``.
            tokenizer: tokenizer exposing ``encode()``, ``decode()`` and ``vocab``.
            piece_by_track (list | None): existing list of track dicts to
                continue from; a fresh empty list is created when omitted.
        """
        self.model = model
        self.tokenizer = tokenizer
        # default initialization
        self.initialize_default_parameters()
        # A new list per instance: a mutable default argument would be shared
        # by every instance and leak tracks from one piece into the next.
        self.initialize_dictionaries([] if piece_by_track is None else piece_by_track)

    # ------------------------------------------------------------------
    # Setters
    # ------------------------------------------------------------------

    def initialize_default_parameters(self):
        """Set every generation parameter to its default value."""
        self.set_device()
        self.set_attention_length()
        self.generate_until = "TRACK_END"
        self.set_force_sequence_lenth()
        self.set_nb_bars_generated()
        self.set_improvisation_level(0)

    def initialize_dictionaries(self, piece_by_track):
        """Attach the list of track dictionaries describing the piece."""
        self.piece_by_track = piece_by_track

    def set_device(self, device="cpu"):
        """Record the device name to use (stored in ``self.device``)."""
        self.device = device

    def set_attention_length(self):
        """Set ``self.max_length`` from ``model.config.n_positions``."""
        self.max_length = self.model.config.n_positions
        print(
            f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
        )

    def set_force_sequence_lenth(self, force_sequence_length=True):
        """Choose whether generated sequences are forced to the expected bar count."""
        self.force_sequence_length = force_sequence_length

    def set_improvisation_level(self, improvisation_value):
        """Set the ``no_repeat_ngram_size`` passed to ``model.generate``."""
        self.no_repeat_ngram_size = improvisation_value
        print("--------------------")
        print(f"no_repeat_ngram_size set to {improvisation_value}")
        print("--------------------")

    def reset_temperatures(self, track_id, temperature):
        """Overwrite the sampling temperature stored for one track."""
        self.piece_by_track[track_id]["temperature"] = temperature

    def set_nb_bars_generated(self, n_bars=8):  # default is a 8 bar model
        """Set the number of bars the model generates per track."""
        self.model_n_bar = n_bars

    # ------------------------------------------------------------------
    # Generation Tools - Dictionaries
    # ------------------------------------------------------------------

    def initiate_track_dict(self, instr, density, temperature):
        """Append a new, bar-less track dictionary to ``self.piece_by_track``."""
        label = len(self.piece_by_track)
        self.piece_by_track.append(
            {
                "label": f"track_{label}",
                "instrument": instr,
                "density": density,
                "temperature": temperature,
                "bars": [],
            }
        )

    def update_track_dict__add_bars(self, bars, track_id):
        """Split ``bars`` on "BAR_START " and append each chunk to the track.

        The chunk containing "TRACK_START" is stored as is; every other chunk
        gets the "BAR_START " prefix (removed by the split) put back.
        """
        track_bars = self.piece_by_track[track_id]["bars"]
        for chunk in self.striping_track_ends(bars).split("BAR_START "):
            if chunk == "":  # happens if there is one bar only
                continue
            if "TRACK_START" in chunk:
                track_bars.append(chunk)
            else:
                track_bars.append("BAR_START " + chunk)

    def get_all_instr_bars(self, track_id):
        """Return the list of bar chunks stored for ``track_id``."""
        return self.piece_by_track[track_id]["bars"]

    def striping_track_ends(self, text):
        """Remove one trailing "TRACK_END" (and the spaces after it) from ``text``.

        The text is returned unchanged when it does not end with "TRACK_END".
        ``str.rstrip("TRACK_END")`` must not be used here: it strips a *set of
        characters* and would eat the end of a trailing "BAR_END".
        """
        trimmed = text.rstrip(" ")
        if trimmed.endswith("TRACK_END"):
            return trimmed[: -len("TRACK_END")]
        return text

    def get_last_generated_track(self, full_piece):
        """Return the last track of ``full_piece`` as "TRACK_START ... TRACK_END "."""
        last_track_body = self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
        # forcing the space after "TRACK_END"
        return "TRACK_START " + last_track_body + "TRACK_END "

    def get_selected_track_as_text(self, track_id):
        """Concatenate the bars of one track and close it with "TRACK_END "."""
        return "".join(self.piece_by_track[track_id]["bars"]) + "TRACK_END "

    @staticmethod
    def get_newly_generated_text(input_prompt, full_piece):
        """Return the part of ``full_piece`` that follows the prompt."""
        return full_piece[len(input_prompt) :]

    def get_whole_piece_from_bar_dict(self):
        """Rebuild the whole piece text ("PIECE_START " + every track)."""
        return "PIECE_START " + "".join(
            self.get_selected_track_as_text(track_id)
            for track_id in range(len(self.piece_by_track))
        )

    def delete_one_track(self, track):  # TO BE TESTED
        """Remove the track at index ``track`` from the piece."""
        self.piece_by_track.pop(track)

    # ------------------------------------------------------------------
    # Basic generation tools
    # ------------------------------------------------------------------

    def tokenize_input_prompt(self, input_prompt, verbose=True):
        """Tokenizing prompt

        Args:
        - input_prompt (str): prompt to tokenize

        Returns:
        - input_prompt_ids (torch.tensor): tokenized prompt
        """
        if verbose:
            print("Tokenizing input_prompt...")

        return self.tokenizer.encode(input_prompt, return_tensors="pt")

    def generate_sequence_of_token_ids(
        self,
        input_prompt_ids,
        temperature,
        verbose=True,
    ):
        """Sample a sequence of token ids continuing ``input_prompt_ids``.

        Generation is delegated to ``self.model.generate`` with sampling
        enabled; it is bounded by ``self.max_length`` and uses the first id of
        the encoded ``self.generate_until`` token as the end-of-sequence id.
        """
        generated_ids = self.model.generate(
            input_prompt_ids,
            max_length=self.max_length,
            do_sample=True,
            temperature=temperature,
            no_repeat_ngram_size=self.no_repeat_ngram_size,  # default = 0
            eos_token_id=self.tokenizer.encode(self.generate_until)[0],
        )

        if verbose:
            print("Generating a token_id sequence...")

        return generated_ids

    def convert_ids_to_text(self, generated_ids, verbose=True):
        """Decode the first sequence of ``generated_ids`` back to text."""
        generated_text = self.tokenizer.decode(generated_ids[0])
        if verbose:
            print("Converting token sequence to MidiText...")
        return generated_text

    def generate_until_track_end(
        self,
        input_prompt="PIECE_START ",
        instrument=None,
        density=None,
        temperature=None,
        verbose=True,
        expected_length=None,
    ):
        """Generate until the TRACK_END token is reached.

        Args:
            input_prompt (str): text to continue.
            instrument: when given, "TRACK_START INST=<instrument> " is
                appended to the prompt.
            density: when given together with ``instrument``,
                "DENSITY=<density> " is appended as well.
            temperature: sampling temperature (mandatory).
            verbose (bool): print progress messages.
            expected_length (int | None): expected number of generated bars;
                defaults to ``self.model_n_bar``.

        Returns:
            str: full_piece = input_prompt + generated text.

        Raises:
            ValueError: if ``temperature`` is None.
        """
        if temperature is None:
            raise ValueError("Temperature must be defined")

        if expected_length is None:
            expected_length = self.model_n_bar

        if instrument is not None:
            input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
            if density is not None:
                input_prompt = f"{input_prompt}DENSITY={str(density)} "
        elif density is not None:
            print("Density cannot be defined without an input_prompt instrument #TOFIX")

        if verbose:
            print("--------------------")
            print(
                f"Generating {instrument} - Density {density} - temperature {temperature}"
            )

        bar_count_checks = False
        failed = 0
        while not bar_count_checks:  # regenerate until right length
            input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
            generated_tokens = self.generate_sequence_of_token_ids(
                input_prompt_ids, temperature, verbose=verbose
            )
            full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
            generated = self.get_newly_generated_text(input_prompt, full_piece)
            bar_count_checks, bar_count = bar_count_check(generated, expected_length)

            if not self.force_sequence_length:
                # length is not enforced: accept whatever was generated
                bar_count_checks = True
            elif not bar_count_checks:
                # wrong length: let forcing_bar_count try to repair the piece
                full_piece, bar_count_checks = forcing_bar_count(
                    input_prompt,
                    generated,
                    bar_count,
                    expected_length,
                )

            if not bar_count_checks:
                failed += 1
                if failed > 2:
                    # TOFIX give up after three failed attempts and return
                    # the last result so the loop cannot run forever
                    bar_count_checks = True

        return full_piece

    def generate_one_new_track(
        self,
        instrument,
        density,
        temperature,
        input_prompt="PIECE_START ",
    ):
        """Generate one new track, store its bars and return the whole piece."""
        self.initiate_track_dict(instrument, density, temperature)
        full_piece = self.generate_until_track_end(
            input_prompt=input_prompt,
            instrument=instrument,
            density=density,
            temperature=temperature,
        )

        track = self.get_last_generated_track(full_piece)
        # -1: the track dictionary that was just appended
        self.update_track_dict__add_bars(track, -1)
        return self.get_whole_piece_from_bar_dict()

    # ------------------------------------------------------------------
    # Piece generation - Basics
    # ------------------------------------------------------------------

    def generate_piece(self, instrument_list, density_list, temperature_list):
        """Generate a sequence with multiple tracks.

        - instrument_list sets the instruments, in order of generation
        - density_list and temperature_list are paired with instrument_list
        Each track/instrument is generated on a prompt which contains the
        previously generated tracks. This means that the first instrument is
        generated with less bias than the next one, and so on.

        Returns:
            str: the entire piece, as rebuilt after the last generated track.
        """
        generated_piece = "PIECE_START "
        for instrument, density, temperature in zip(
            instrument_list, density_list, temperature_list
        ):
            generated_piece = self.generate_one_new_track(
                instrument,
                density,
                temperature,
                input_prompt=generated_piece,
            )

        self.check_the_piece_for_errors()
        return generated_piece

    # ------------------------------------------------------------------
    # Piece generation - Extra Bars
    # ------------------------------------------------------------------

    # NOTE(review): declared as a staticmethod that still takes the instance
    # explicitly; callers must use ``obj.process_prompt_for_next_bar(obj, idx)``.
    # Kept as is so existing callers keep working.
    @staticmethod
    def process_prompt_for_next_bar(self, track_idx):
        """Processing the prompt for the model to generate one more bar only.
        The prompt contains:
                if not the first bar: the previous, already processed, bars of the track
                the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
                the last (self.model_n_bar)-1 bars of the track
        Args:
            track_idx (int): the index of the track to be processed

        Returns:
            the processed prompt for generating the next bar
        """
        track = self.piece_by_track[track_idx]
        # for bars which are not the bar to prolong
        pre_promt = "PIECE_START "
        for i, othertrack in enumerate(self.piece_by_track):
            if i != track_idx:
                len_diff = len(othertrack["bars"]) - len(track["bars"])
                if len_diff > 0:
                    # if other bars are longer, it mean that this one should catch up
                    pre_promt += othertrack["bars"][0]
                    # NOTE(review): this copies bars of `track`, not of
                    # `othertrack` - confirm this is intended.
                    for bar in track["bars"][-self.model_n_bar :]:
                        pre_promt += bar
                    pre_promt += "TRACK_END "
                elif False:  # len_diff <= 0: # THIS GENERATES EMPTINESS
                    # adding an empty bars at the end of the other tracks if they have not been processed yet
                    pre_promt += othertrack["bars"][0]
                    for bar in track["bars"][-(self.model_n_bar - 1) :]:
                        pre_promt += bar
                    for _ in range(abs(len_diff) + 1):
                        pre_promt += "BAR_START BAR_END "
                    pre_promt += "TRACK_END "

        # for the bar to prolong
        # initialization e.g TRACK_START INST=DRUMS DENSITY=2
        processed_prompt = track["bars"][0]
        for bar in track["bars"][-(self.model_n_bar - 1) :]:
            # adding the "last" bars of the track
            processed_prompt += bar

        processed_prompt += "BAR_START "
        print(
            f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
        )
        return pre_promt + processed_prompt

    def generate_one_more_bar(self, i):
        """Generate one more bar for track ``i`` and store it in the track dict."""
        processed_prompt = self.process_prompt_for_next_bar(self, i)
        prompt_plus_bar = self.generate_until_track_end(
            input_prompt=processed_prompt,
            temperature=self.piece_by_track[i]["temperature"],
            expected_length=1,
            verbose=False,
        )
        added_bar = self.get_newly_generated_bar(prompt_plus_bar)
        self.update_track_dict__add_bars(added_bar, i)

    def get_newly_generated_bar(self, prompt_plus_bar):
        """Return the last bar of ``prompt_plus_bar``, without "TRACK_END"."""
        return "BAR_START " + self.striping_track_ends(
            prompt_plus_bar.split("BAR_START ")[-1]
        )

    def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
        """Add ``n_bars`` bars to every track, or only to ``only_this_track``.

        Args:
            n_bars (int): number of bars to add.
            only_this_track (int | None): index of the single track to extend;
                None extends every track.
            verbose (bool): currently unused, kept for interface compatibility.
        """
        print("================== ")
        print(f"Adding {n_bars} more bars to the piece ")
        for bar_id in range(n_bars):
            print(f"----- added bar #{bar_id+1} --")
            for i, track in enumerate(self.piece_by_track):
                if only_this_track is None or i == only_this_track:
                    print(f"--------- {track['label']}")
                    self.generate_one_more_bar(i)
        self.check_the_piece_for_errors()

    def check_the_piece_for_errors(self, piece: str = None):
        """Report the tokens of ``piece`` that the tokenizer does not know.

        Args:
            piece (str | None): text to check; defaults to the piece rebuilt
                from ``self.piece_by_track``.

        Returns:
            list[tuple[str, int]]: ``(token, position)`` for every token that
            is missing from ``self.tokenizer.vocab`` or equal to "UNK".
        """
        if piece is None:
            piece = self.get_whole_piece_from_bar_dict()

        tokens = piece.split(" ")
        vocab = self.tokenizer.vocab
        errors = [
            (token, position)
            for position, token in enumerate(tokens)
            # empty strings only come from the space separators (e.g. the
            # trailing space of the piece); they are not real tokens
            if token and (token not in vocab or token == "UNK")
        ]
        for token, position in errors:
            print(f"Token not found in the piece at {position}: {token}")
            # clamp the start so a negative index cannot wrap around the list
            print(tokens[max(position - 5, 0) : position + 5])
        return errors

if __name__ == "__main__":
    # Script entry point: generate N_FILES_TO_GENERATE pieces per temperature,
    # write each one to a text file, decode it to MIDI and plot its piano roll.
    import os  # local import: only needed by this entry point

    # worker
    DEVICE = "cpu"

    # define generation parameters
    N_FILES_TO_GENERATE = 2
    Temperatures_to_try = [0.7]

    USE_FAMILIZED_MODEL = True
    force_sequence_length = True

    if USE_FAMILIZED_MODEL:
        # model_repo = "misnaej/the-jam-machine-elec-famil"
        # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"

        # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
        # n_bar_generated = 8

        model_repo = "JammyMachina/improved_4bars-mdl"
        n_bar_generated = 4
        instrument_promt_list = ["4", "DRUMS", "3"]
        # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
        density_list = [3, 2, 2]
        # temperature_list = [0.7, 0.7, 0.75]
    else:
        model_repo = "misnaej/the-jam-machine"
        instrument_promt_list = ["30"]  # , "DRUMS", "0"]
        density_list = [3]  # , 2, 3]
        # temperature_list = [0.7, 0.5, 0.75]

    # define generation directory
    generated_sequence_files_path = define_generation_dir(model_repo)

    # load model and tokenizer
    model, tokenizer = LoadModel(
        model_repo, from_huggingface=True
    ).load_model_and_tokenizer()

    # does the prompt make sense
    check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)

    for temperature in Temperatures_to_try:
        print(f"================= TEMPERATURE {temperature} =======================")
        for _ in range(N_FILES_TO_GENERATE):
            print("========================================")
            # 1 - instantiate and apply the settings defined above
            generate_midi = GenerateMidiText(model, tokenizer)
            generate_midi.set_device(DEVICE)
            generate_midi.set_force_sequence_lenth(force_sequence_length)
            generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
            # 2 - generate the first bars for each instrument, one temperature
            #     per instrument/density pair
            generate_midi.set_improvisation_level(30)
            generate_midi.generate_piece(
                instrument_promt_list,
                density_list,
                [temperature for _ in density_list],
            )
            # 3 - force the model to improvise
            # generate_midi.set_improvisation_level(20)
            # 4 - generate the next 4 bars for each instrument
            # generate_midi.generate_n_more_bars(n_bar_generated)
            # 5 - store the final piece on the instance for the writer below
            generate_midi.generated_piece = (
                generate_midi.get_whole_piece_from_bar_dict()
            )

            # print the generated sequence in terminal
            print("=========================================")
            print(generate_midi.generated_piece)
            print("=========================================")

            # write to JSON file
            filename = WriteTextMidiToFile(
                generate_midi,
                generated_sequence_files_path,
            ).text_midi_to_file()

            # Swap only the final extension for ".mid"; splitting on the first
            # "." would truncate any path that contains another dot.
            midi_filename = os.path.splitext(filename)[0] + ".mid"

            # decode the sequence to MIDI
            decode_tokenizer = get_miditok()
            TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
                generate_midi.generated_piece, filename=midi_filename
            )
            inst_midi, mixed_audio = get_music(midi_filename)
            max_time = get_max_time(inst_midi)
            plot_piano_roll(inst_midi)

            print("Et voilà! Your MIDI file is ready! GO JAM!")