Commit · 87ae0b7
1 Parent(s): 2d7a385
Initial commit

Files changed:
- app.py +245 -0
- assets/asciilogo.txt +11 -0
- requirements.txt +11 -0
- source/languagemodel.py +288 -0
- source/utilities.py +331 -0
app.py
ADDED
@@ -0,0 +1,245 @@
import streamlit as st
from source.languagemodel import LanguageModel
from source.utilities import (
    convert_tokens_to_songdata,
    convert_songdata_to_notesequence,
    convert_songdata_to_pianoroll,
    convert_notesequence_to_wave,
    convert_notesequence_to_midi
)

# Define the MIDI instruments.
midi_instruments = {
    "Harpsichord": 6,
    "Church Organ": 19,
    "Piano": 0,
}

# Load the model once and cache it.
@st.cache_resource
def load_model():
    model = LanguageModel("TristanBehrens/bach-garland-mambaplus")
    return model
model = load_model()


# Initialize the session state if it doesn't exist yet.
if "token_sequence" not in st.session_state:
    st.session_state.token_sequence = "GARLAND_START"
    st.session_state.song_data = None
    st.session_state.piano_roll = None
    st.session_state.wave = None
    st.session_state.note_sequence = None
    st.session_state.midi_file_content = None
    st.session_state.temperature = 0.1
    st.session_state.bpm = 100
    st.session_state.instrument = "Piano"


# Define the main function.
def main():

    columns = st.columns([0.7, 0.3])

    # Set up the Streamlit application.
    column = columns.pop(0)
    with column:

        # Change the color of links (a-tags) to #FF4B4B.
        st.markdown("<style>a:link { color: #FF4B4B; } a:visited { color: #FF4B4B; }</style>", unsafe_allow_html=True)

        # Add a title.
        st.title("Garland Composer")
        linkedin_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        x_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        st.write(f"By Dr. Tristan Behrens. Find me on [LinkedIn]({linkedin_url}) and [X]({x_url}).")
        hf_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        st.write(f"Model available on [Hugging Face]({hf_url}).")

    # Add a picture.
    column = columns.pop(0)
    with column:
        st.write(" ")
        st.write(" ")
        st.write(" ")
        st.image("garland.jpg", use_column_width=True)

    # Add a horizontal line.
    st.markdown("---")

    # Create three columns for the controls.
    columns = st.columns(3)

    # Add a slider to control the temperature.
    state_temperature = st.session_state.temperature
    with columns.pop(0):
        temperature = st.slider("Temperature", 0.0, 1.0, state_temperature)
        st.session_state.temperature = temperature

    # Add a slider to control the BPM.
    state_bpm = st.session_state.bpm
    with columns.pop(0):
        bpm = st.slider("BPM", 80, 120, state_bpm, 5)
        st.session_state.bpm = bpm

    # Dropdown for the instrument.
    state_instrument = st.session_state.instrument
    with columns.pop(0):
        instrument = st.selectbox("Instrument", list(midi_instruments.keys()), index=list(midi_instruments.keys()).index(state_instrument))
        st.session_state.instrument = instrument

    # Get the token sequence from the session state.
    token_sequence = st.session_state.token_sequence

    # Columns for the buttons.
    columns = st.columns(5)

    # Add a button to generate the next bar.
    column = columns.pop(0)
    with column:
        if st.button("Add a bar", use_container_width=True):
            token_sequence = extend_sequence(model, token_sequence, temperature)
            refresh(token_sequence, bpm, instrument)

    # Add a button to auto-compose until the model signals the end.
    column = columns.pop(0)
    with column:
        if st.button("Auto compose", use_container_width=True):
            token_sequence = auto_compose(model, token_sequence, temperature)
            refresh(token_sequence, bpm, instrument)

    # Add a button to remove the last bar.
    column = columns.pop(0)
    with column:
        if st.button("Remove last", use_container_width=True):
            token_sequence = shortened_sequence(token_sequence)
            refresh(token_sequence, bpm, instrument)

    # Add a button to reset the sequence.
    column = columns.pop(0)
    if token_sequence != "GARLAND_START":
        with column:
            if st.button("Reset", use_container_width=True):
                token_sequence = "GARLAND_START"
                refresh(token_sequence, bpm, instrument)

    # Provide a download button for the MIDI file.
    column = columns.pop(0)
    if "midi_file_content" in st.session_state and st.session_state.midi_file_content is not None:
        with column:
            midi_file_content = st.session_state.midi_file_content
            if st.download_button(
                label="Download MIDI",
                data=midi_file_content,
                file_name="music.mid",
                mime="audio/midi",
                use_container_width=True
            ):
                pass

    # Add a horizontal line.
    st.markdown("---")

    # Display the piano roll.
    if "piano_roll" in st.session_state and st.session_state.piano_roll is not None:
        st.image(st.session_state.piano_roll)

    # Display an audio player.
    if "wave" in st.session_state and st.session_state.wave is not None:
        st.audio(st.session_state.wave, format="audio/wav", sample_rate=44100, autoplay=True)

    # Add a horizontal line.
    st.markdown("---")

    # Tell the user whether the model considers the piece finished.
    if token_sequence.endswith("GARLAND_END"):
        st.write("The AI believes that the music is finished.")
    else:
        st.write("The AI believes that the music is not finished.")


def auto_compose(model, token_sequence, temperature):

    # Keep adding bars until the model emits GARLAND_END or the iteration budget runs out.
    max_iterations = 100
    for _ in range(max_iterations):
        token_sequence = extend_sequence(model, token_sequence, temperature)
        if token_sequence.endswith("GARLAND_END"):
            break
    return token_sequence


def extend_sequence(model, token_sequence, temperature):

    # Replace the last GARLAND_END token with NEXT.
    if token_sequence.endswith("GARLAND_END"):
        token_sequence = token_sequence.replace("GARLAND_END", "NEXT")

    # The maximum length of the generated music.
    max_length = 16_384

    # When to stop the generation.
    end_tokens = ["NEXT", "GARLAND_END"]

    # Compose the music iteratively, bar by bar.
    output_dict = model.generate(
        prompt=token_sequence,
        temperature=temperature,
        max_length=max_length,
        end_tokens=end_tokens,
        forbidden_tokens=["[PAD]", "[EOS]"],
        return_structured_output=True
    )
    output = output_dict["output"]
    return output


def shortened_sequence(token_sequence):

    # Find the position of the next-to-last NEXT (or GARLAND_END) token and cut there.
    next_tokens = token_sequence.split()
    next_positions = [i for i, x in enumerate(next_tokens) if x == "NEXT" or x == "GARLAND_END"]
    if len(next_positions) <= 1:
        token_sequence = "GARLAND_START"
    else:
        next_position = next_positions[-2]
        token_sequence = " ".join(next_tokens[:next_position + 1])
    return token_sequence


def refresh(token_sequence="GARLAND_START", bpm=120, instrument="Piano"):

    # Get the token sequence into the session state.
    st.session_state.token_sequence = token_sequence

    # Convert to song data.
    song_data = convert_tokens_to_songdata(token_sequence)
    song_data["bpm"] = bpm
    st.session_state.song_data = song_data

    # Set the instrument.
    for track in song_data["tracks"]:
        track["instrument"] = midi_instruments[instrument]

    # Convert to piano roll.
    piano_roll = convert_songdata_to_pianoroll(song_data)
    st.session_state.piano_roll = piano_roll

    # Convert to note sequence.
    note_sequence = convert_songdata_to_notesequence(song_data)
    st.session_state.note_sequence = note_sequence

    # Synthesize the note sequence to audio.
    wave = convert_notesequence_to_wave(note_sequence)
    st.session_state.wave = wave

    # Get the MIDI file content.
    midi_file_content = convert_notesequence_to_midi(note_sequence)
    st.session_state.midi_file_content = midi_file_content

    # Rerun the app.
    st.rerun()


if __name__ == "__main__":
    main()
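
Note: the app keeps the whole token sequence in st.session_state and re-derives every artifact (piano roll, audio, MIDI) from it inside refresh(). The same loop can be driven without Streamlit; below is a minimal headless sketch (a hypothetical script, not part of this commit, assuming the checkpoint downloads successfully):

# compose_headless.py (hypothetical) - drive the generate-and-convert loop without the UI.
from source.languagemodel import LanguageModel
from source.utilities import (
    convert_tokens_to_songdata,
    convert_songdata_to_notesequence,
    convert_notesequence_to_midi,
)

model = LanguageModel("TristanBehrens/bach-garland-mambaplus")

# Grow the piece by four bars, like clicking "Add a bar" four times.
tokens = "GARLAND_START"
for _ in range(4):
    if tokens.endswith("GARLAND_END"):
        tokens = tokens.replace("GARLAND_END", "NEXT")
    tokens = model.generate(
        prompt=tokens,
        temperature=0.1,
        max_length=16_384,
        end_tokens=["NEXT", "GARLAND_END"],
        forbidden_tokens=["[PAD]", "[EOS]"],
    )

# Render the result to a MIDI file, exactly as the download button does.
song_data = convert_tokens_to_songdata(tokens)
song_data["bpm"] = 100
midi_bytes = convert_notesequence_to_midi(convert_songdata_to_notesequence(song_data))
if midi_bytes is not None:
    with open("music.mid", "wb") as f:
        f.write(midi_bytes)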
assets/asciilogo.txt
ADDED
@@ -0,0 +1,11 @@
▄█    █▄     ▄████████  ▄█        ▄█  ▀█████████▄     ▄████████ ███    █▄  ███▄▄▄▄   ███▄▄▄▄      ▄████████
███    ███   ███    ███ ███       ███    ███    ███   ███    ███ ███    ███ ███▀▀▀██▄ ███▀▀▀██▄   ███    ███
███    ███   ███    █▀  ███       ███▌   ███    ███   ███    ███ ███    ███ ███   ███ ███   ███   ███    ███
▄███▄▄▄▄███▄▄ ▄███▄▄▄    ███       ███▌  ▄███▄▄▄██▀   ▄███▄▄▄▄██▀ ███    ███ ███   ███ ███   ███   ███    ███
▀▀███▀▀▀▀███▀ ▀▀███▀▀▀    ███       ███▌ ▀▀███▀▀▀██▄  ▀▀███▀▀▀▀▀   ███    ███ ███   ███ ███   ███ ▀███████████
███    ███   ███    █▄  ███       ███    ███    ██▄ ▀███████████ ███    ███ ███   ███ ███   ███          ███
███    ███   ███    ███ ███▌    ▄ ███    ███    ███   ███    ███ ███    ███ ███   ███ ███   ███    ███    ███
███    █▀    ██████████ █████▄▄██ █▀   ▄█████████▀    ███    ███ ████████▀   ▀█   █▀   ▀█   █▀    ███    █▀
▀         ███    ███

By Dr. Tristan Behrens
requirements.txt
ADDED
@@ -0,0 +1,11 @@
dacite==1.8.1
colorama==0.4.6
omegaconf==2.3.0
streamlit==1.38.0
note_seq==0.0.5
pyfluidsynth==1.3.2
torch==2.2.0
transformers==4.44.0
mamba-ssm==2.2.2
einops==0.8.0
mambapy==1.2.0
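
Note: pyfluidsynth is only a Python binding; the fluidsynth library itself is a system package (on a Space typically listed in packages.txt, which is not part of this commit). Without it, convert_notesequence_to_wave in source/utilities.py falls back to note_seq.synthesize. A quick availability check (hypothetical snippet):

# The pyfluidsynth pin provides the "fluidsynth" module; importing it fails when
# the system libfluidsynth is missing (an assumption worth verifying locally).
try:
    import fluidsynth
    print("fluidsynth available: audio will be rendered with the soundfont synthesizer")
except Exception as error:
    print(f"fluidsynth unavailable ({error}); note_seq.synthesize will be used instead")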
source/languagemodel.py
ADDED
@@ -0,0 +1,288 @@
# Helibrunna - A HuggingFace compatible xLSTM trainer.
# Copyright (c) 2024 Dr. Tristan Behrens
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import glob
from omegaconf import OmegaConf
from transformers import PreTrainedTokenizerFast
import torch
from safetensors.torch import load_file
import time
from .utilities import display_logo, model_from_config


class LanguageModel:

    def __init__(self, model_path_or_repo, config_overrides={}, mask_special_tokens=True, device="auto"):
        """
        Initializes the LanguageModel object.
        Args:
            model_path_or_repo (str): The path to the model or the repository ID.
            config_overrides (dict, optional): Values merged over the loaded model config. Defaults to {}.
            mask_special_tokens (bool, optional): Whether special tokens are masked during generation. Defaults to True.
            device (str, optional): "cpu", "cuda", "mps", or "auto". Defaults to "auto".
        Raises:
            ValueError: If the model checkpoint, tokenizer, config, or weights are not found.
            RuntimeError: If the model could not be downloaded.
        Returns:
            None
        """

        # Set the mask_special_tokens flag.
        self.mask_special_tokens = mask_special_tokens

        # Use the requested device if one was given explicitly.
        if device != "auto":

            # Check if CUDA is available.
            if not torch.cuda.is_available() and device == "cuda":
                raise ValueError("CUDA is not available on this system.")

            # Check if MPS is available.
            if not torch.backends.mps.is_available() and device == "mps":
                raise ValueError("MPS is not available on this system.")

            # Set the device.
            self.device = device

        # Otherwise resolve the device automatically.
        else:

            # Default to CPU, use CUDA if it is available.
            self.device = "cpu"
            if torch.cuda.is_available():
                self.device = "cuda"

            # See if MPS is available.
            # Note: This is disabled for now. It's not working as expected. It is very slow.
            #if torch.backends.mps.is_available():
            #    self.device = "mps"

        # Display the logo.
        display_logo()

        # Download the model if it doesn't exist locally. Or at least try to.
        if not os.path.exists(model_path_or_repo):
            from huggingface_hub import snapshot_download
            try:
                model_path = snapshot_download(repo_id=model_path_or_repo)
                tokenizer_path = model_path
            except Exception as e:
                raise RuntimeError(f"Failed to download the model: {e}")

        # Use a local model.
        else:
            # Set the model path and tokenizer path.
            model_path = None
            tokenizer_path = model_path_or_repo

            # Find all the checkpoint folders (folders that start with "checkpoint-") and use the last one.
            checkpoint_folders = glob.glob(os.path.join(model_path_or_repo, "checkpoint-*"))
            for checkpoint_folder in checkpoint_folders:
                if checkpoint_folder.endswith("-last"):
                    model_path = checkpoint_folder
                    break
            if model_path is None:
                raise ValueError("No model checkpoint found.")

            # Find the tokenizer folder.
            if os.path.exists(os.path.join(model_path_or_repo, "tokenizer.json")):
                tokenizer_path = model_path_or_repo
            if not os.path.exists(tokenizer_path):
                raise ValueError("Tokenizer not found.")

        # Load the config.
        config_path = os.path.join(model_path, "config.yaml")
        if not os.path.exists(config_path):
            raise ValueError(f"Config not found at {config_path}")
        model_config = OmegaConf.load(config_path)

        # Override the config.
        if config_overrides != {} and config_overrides is not None:
            model_config = OmegaConf.merge(model_config, config_overrides)
            import json
            print(json.dumps(OmegaConf.to_container(model_config), indent=4))

        # Create the model from the config.
        model = model_from_config(model_config, device=self.device)
        model.to(self.device)
        self.config = model_config

        # Load the weights from the checkpoint.
        weights_path = os.path.join(model_path, "model.safetensors")
        if not os.path.exists(weights_path):
            raise ValueError(f"Weights not found at {weights_path}")
        state_dict = load_file(weights_path)

        # TODO: Permute the last two dimensions of these parameters: xlstm_block_stack.blocks.2.xlstm.slstm_cell._recurrent_kernel_
        # Check if we have an xLSTM model and if CUDA is not available.
        if not torch.cuda.is_available() and model_config.get("type", "xLSTMLMModel") == "xLSTMLMModel":
            print(state_dict.keys())
            endings = ["xlstm.slstm_cell._recurrent_kernel_"]
            for key, values in state_dict.items():
                for ending in endings:
                    if key.endswith(ending):
                        print(key)
                        print(values.shape)

                        # Option: Permute the last two dimensions.
                        values = values.permute(0, 2, 1)

                        # Option: View the tensor.
                        #new_shape = (values.shape[0], values.shape[2], values.shape[1])
                        #values = values.view(new_shape)

                        print(values.shape)
                        state_dict[key] = values
                        break

        # Load the weights into the model.
        model.load_state_dict(state_dict)
        self.model = model

        # Load the tokenizer.
        tokenizer_path = os.path.join(tokenizer_path, "tokenizer.json")
        if not os.path.exists(tokenizer_path):
            raise ValueError(f"Tokenizer not found at {tokenizer_path}")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
        self.tokenizer = tokenizer


    def generate(
        self,
        prompt: str,
        temperature: float = 1.0,
        max_length: int = 100,
        end_tokens: list[str] = [],
        forbidden_tokens: list[str] = [],
        return_structured_output: bool = False
    ):
        """
        Generates a continuation for a given prompt using the language model.
        Args:
            prompt (str): The prompt to generate a continuation for.
            temperature (float, optional): The temperature value for controlling the randomness of the generated output.
                Higher values (e.g., 1.0) make the output more random, while lower values (e.g., 0.5) make it more deterministic.
                Defaults to 1.0.
            max_length (int, optional): The maximum length of the generated output. Defaults to 100.
            end_tokens (list[str], optional): A list of end tokens that, if encountered, will stop the generation process.
                Defaults to an empty list.
            forbidden_tokens (list[str], optional): A list of tokens that must not be sampled. Defaults to an empty list.
            return_structured_output (bool, optional): If True, returns a dictionary with the generated output, elapsed time,
                and tokens per second. If False, returns only the generated output as a string. Defaults to False.
        Returns:
            str or dict: The generated output as a string if return_structured_output is False.
                A dictionary with the generated output, elapsed time, and tokens per second if return_structured_output is True.
        """

        # Tokenize the prompt.
        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        assert inputs.shape[0] == 1

        # Determine the end token ids.
        end_token_ids = []
        for end_token in end_tokens:
            assert end_token in self.tokenizer.vocab
            end_token_ids.append(self.tokenizer(end_token).input_ids[0])

        # Determine the ids of the forbidden tokens so they can be masked out.
        ids_to_mask = []
        for forbidden_token in forbidden_tokens:
            assert forbidden_token in self.tokenizer.vocab
            ids_to_mask.extend(self.tokenizer(forbidden_token).input_ids)

        # Generate the continuation.
        start_time = time.time()
        tokens_count = 0
        while inputs.shape[1] < max_length:

            # Stop if the maximum context length is reached.
            if inputs.shape[1] >= self.config.context_length:
                print("Warning: The maximum context length has been reached.")
                break

            # Run the model on the current sequence.
            outputs = self.model(inputs.to(device=self.device))
            assert outputs.shape[0] == 1

            # Mask the special tokens and the forbidden tokens.
            outputs[:, :, self.tokenizer.all_special_ids] = float("-inf")
            if ids_to_mask:
                outputs[:, :, ids_to_mask] = float("-inf")

            # Use the temperature to sample from the distribution.
            outputs = outputs / temperature
            outputs = torch.nn.functional.softmax(outputs, dim=-1)
            outputs = torch.multinomial(outputs[0, -1], num_samples=1)

            # Append the sampled token to the inputs.
            inputs = torch.cat([inputs, outputs.unsqueeze(0)], dim=1)

            # Increment the tokens count.
            tokens_count += 1

            # Check if an end token has been reached.
            if outputs[0] in end_token_ids:
                break

        # Compute the elapsed time and tokens per second.
        elapsed_time = time.time() - start_time
        tokens_per_second = tokens_count / elapsed_time

        # Decode the output.
        output = self.tokenizer.decode(inputs[0].tolist())

        # Return the output.
        if not return_structured_output:
            return output

        # Return the structured output.
        else:
            return {
                "output": output,
                "elapsed_time": elapsed_time,
                "tokens_per_second": tokens_per_second
            }

    def summary(self):
        """
        Prints a summary of the model. Makes the model architecture readable. Includes the number of parameters.
        """

        # Print the model.
        print(self.model)

        # Get the number of parameters.
        number_of_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters: {number_of_parameters:_}")
        sizes = ["", "K", "M", "B", "T"]
        size_index = 0
        while number_of_parameters > 1000:
            number_of_parameters /= 1000
            size_index += 1
        print(f"Number of parameters: {number_of_parameters:.2f}{sizes[size_index]}")

        # Compute the total size of the model in bytes (4 bytes per parameter) and make it human readable.
        number_of_parameters = sum(p.numel() for p in self.model.parameters())
        total_size = number_of_parameters * 4
        sizes = ["B", "KB", "MB", "GB", "TB"]
        size_index = 0
        while total_size > 1024:
            total_size /= 1024
            size_index += 1
        print(f"Total size of the model: {total_size:.2f}{sizes[size_index]} for 32-bit float precision.")

        # Print on which device the model is running.
        print(f"Device: {self.device}")
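
Note: the class hides checkpoint resolution, device selection, weight loading, and tokenization behind the constructor and generate(). A minimal usage sketch (assumes the Hugging Face repo is reachable; the prompt token is taken from app.py):

from source.languagemodel import LanguageModel

# Downloads the checkpoint on first use and picks CUDA automatically when present.
model = LanguageModel("TristanBehrens/bach-garland-mambaplus")
model.summary()

# The structured output also reports generation throughput.
result = model.generate(
    prompt="GARLAND_START",
    temperature=0.5,
    max_length=512,
    end_tokens=["NEXT", "GARLAND_END"],
    forbidden_tokens=["[PAD]", "[EOS]"],
    return_structured_output=True,
)
print(f"{result['tokens_per_second']:.1f} tokens per second")
print(result["output"][:200])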
source/utilities.py
ADDED
@@ -0,0 +1,331 @@
import copy
import note_seq
from PIL import Image
import tempfile
import os
import colorama
import colorsys
from omegaconf import DictConfig, OmegaConf
import torch
from typing import List, Tuple, Dict
from dacite import from_dict
from collections.abc import MutableMapping
import sys


# NOTE: Imported from helibrunna.
def display_logo():
    """
    Display the logo by printing it line by line with a cyberpunk color scheme.

    Raises:
        FileNotFoundError: If the logo file is missing.
    """

    # Get the path of this script and use it to find the logo.
    script_path = os.path.dirname(os.path.realpath(__file__))
    search_path = os.path.dirname(script_path)

    # Load the logo.
    logo_path = os.path.join(search_path, "assets", "asciilogo.txt")
    if not os.path.exists(logo_path):
        raise FileNotFoundError("The logo file is missing.")
    with open(logo_path, "r") as f:
        logo = f.read()

    # Print the logo line by line. Use colorama to colorize the output. Use a cyberpunk color scheme.
    for line_index, line in enumerate(logo.split("\n")):
        color = colorama.Fore.GREEN
        style = colorama.Style.BRIGHT if line_index % 2 == 0 else colorama.Style.NORMAL
        print(color + style + line)
    print(colorama.Style.RESET_ALL)


# NOTE: Imported from helibrunna.
def model_from_config(model_config: DictConfig, device: str) -> torch.nn.Module:
    """
    Create a model based on the provided model configuration.

    Args:
        model_config (DictConfig): The configuration for the model.
        device (str): The device to move the model to.

    Returns:
        The created model.

    Raises:
        ValueError: If the model type is unknown.
    """

    # Get the model type from the configuration.
    model_type = model_config.get("type", "xLSTMLMModel")

    # Create the xLSTMLMModel.
    if model_type == "xLSTMLMModel":
        print("Creating xLSTMLMModel...")
        from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

        # If there is no GPU, use the vanilla backend.
        if not torch.cuda.is_available():
            #model_config.backend = "vanilla"
            model_config.slstm_block.slstm.backend = "vanilla"
            model_config.mlstm_block.mlstm.backend = "vanilla"
        model_config_object = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(model_config))

        # Create the model.
        model = xLSTMLMModel(model_config_object)
        model.reset_parameters()

    # Create the GPT2LMModel.
    elif model_type == "gpt2":
        print("Creating GPT2LMModel...")
        from .models.gpttwo import GPT2LMModel, GPT2LMModelConfig
        model_config_object = from_dict(GPT2LMModelConfig, OmegaConf.to_container(model_config))
        model = GPT2LMModel(model_config_object)

    # Create the Mamba LM.
    elif model_type == "mamba":
        print("Creating Mamba LM...")
        from mambapy.lm import LM, MambaConfig
        model_config_object = from_dict(MambaConfig, OmegaConf.to_container(model_config))
        model = LM(model_config_object, model_config.vocab_size)

    # Create the Transformer.
    elif model_type == "transformer":
        from .models.transformer import TransformerConfig, Transformer
        model_config_object = from_dict(TransformerConfig, OmegaConf.to_container(model_config))
        model = Transformer(model_config_object)

    # Create a Pharia instance.
    elif model_type == "pharia":
        from .models.pharia import PhariaConfig, PhariaModel
        model_config_object = from_dict(PhariaConfig, OmegaConf.to_container(model_config))
        model = PhariaModel(model_config_object)

    # Fail on unknown model types.
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Move the model to the device.
    model.to(device)
    return model


def convert_tokens_to_songdata(tokens):

    # Accept both a token string and a list of tokens.
    if isinstance(tokens, str):
        tokens = tokens.split()

    song_data = {}
    song_data["tracks"] = []

    # Walk through the tokens and build up tracks, bars, and notes.
    current_track_index = 0
    current_timestep = 0
    for token in tokens:
        if token == "GARLAND_START":
            pass
        elif token == "BAR_START":
            if current_track_index == len(song_data["tracks"]):
                song_data["tracks"] += [{"bars": [], "instrument": "0"}]
            bar_data = {"notes": []}
            song_data["tracks"][current_track_index]["bars"] += [bar_data]
            current_timestep = 0
        elif token.startswith("INST="):
            instrument = token.split("=")[1]
            song_data["tracks"][current_track_index]["instrument"] = instrument
        elif token.startswith("DENSITY="):
            pass
        elif token.startswith("NOTE_ON="):
            note_pitch = int(token.split("=")[1])
            note_data = {
                "note": note_pitch,
                "start": current_timestep,
                "end": current_timestep,
                "velocity": 80
            }
            song_data["tracks"][current_track_index]["bars"][-1]["notes"] += [note_data]
        elif token.startswith("TIME_DELTA="):
            current_timestep += int(token.split("=")[1])
        elif token.startswith("NOTE_OFF="):
            note_pitch = int(token.split("=")[1])
            for note_data in song_data["tracks"][current_track_index]["bars"][-1]["notes"]:
                if note_data["note"] == note_pitch and note_data["start"] == note_data["end"]:
                    note_data["end"] = current_timestep
                    break
        elif token == "BAR_END":
            current_track_index += 1
        elif token == "NEXT":
            current_track_index = 0
        elif token == "GARLAND_END":
            pass
        elif token == "[PAD]":
            pass
        elif token == "[EOS]":
            pass
        else:
            raise Exception(f"Unknown token: {token}")

    assert isinstance(song_data, dict)
    return song_data


def convert_songdata_to_notesequence(song_data: dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True):

    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"

    # Clone the song data.
    song_data = copy.deepcopy(song_data)

    # Sort the tracks by instrument.
    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"
    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])
    song_data["tracks"] = tracks

    # Remove tracks that are not enabled.
    if remove_disabled_tracks:
        song_data["tracks"] = [t for t in song_data["tracks"] if t.get("enabled", True)]

    # Create an empty note sequence.
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()

    # Add the tempo.
    bpm = song_data["bpm"] if "bpm" in song_data else 120
    note_sequence.tempos.add().qpm = bpm

    # Compute some lengths.
    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
    bar_length_seconds = 4 * step_length_seconds * quantize_steps_per_quarter

    # Get the instruments.
    instruments = list(set([t["instrument"] for t in song_data["tracks"]]))

    # Add the tracks.
    for track_index, track_data in enumerate(song_data["tracks"]):
        instrument = track_data["instrument"]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            bar_start_time = bar_index * bar_length_seconds
            for note_data in bar_data["notes"]:
                assert "note" in note_data
                assert "start" in note_data
                assert "end" in note_data
                note = note_sequence.notes.add()
                #note.instrument = instrument TODO
                note.pitch = note_data["note"]
                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
                if "velocity" in note_data:
                    note.velocity = note_data["velocity"]
                else:
                    note.velocity = 80
                note.instrument = track_index
                if instrument == "drums":
                    note.is_drum = True
                else:
                    note.is_drum = False
                    note.program = int(instrument)

    return note_sequence


def convert_songdata_to_pianoroll(song_data):

    # The bars are 4/4 and the quantization is 8 steps per quarter, i.e. 32 steps per bar.
    # We render a grid: the height covers the used pitch range, the width is 32 pixels per bar.

    # Determine the number of bars; all tracks must have the same count.
    lengths = [len(track["bars"]) for track in song_data["tracks"]]
    if lengths == []:
        return None
    assert len(set(lengths)) == 1, f"Unequal number of bars: {lengths}"
    num_bars = lengths[0]

    # Get the note extremes.
    min_note = 128
    max_note = 0
    for track_data in song_data["tracks"]:
        for bar_data in track_data["bars"]:
            for note_data in bar_data["notes"]:
                min_note = min(min_note, note_data["note"])
                max_note = max(max_note, note_data["note"])

    # The width depends on the bars.
    width = 32 * num_bars

    # The height depends on the notes.
    height = 1 + max_note - min_note

    # Create the image.
    image = Image.new("RGB", (width, height), (14, 17, 23))

    # Define some colors. Derive hue-rotated variants of the base color, one per track.
    base_color = (255, 75, 75)
    adjustments = [1.2, 1.0, 0.8, 0.6]
    colors = []
    for adjustment in adjustments:
        rgb = [float(c) / 255.0 for c in base_color]
        hsv = colorsys.rgb_to_hsv(*rgb)

        # Rotate the hue.
        offset = (adjustment - 1.0) * 0.1
        hsv = (hsv[0] + offset, hsv[1], hsv[2])
        rgb = colorsys.hsv_to_rgb(*hsv)
        rgb = tuple([int(255.0 * c) for c in rgb])
        colors += [rgb]

    # Draw the notes, one color per track.
    for track_index, track_data in enumerate(song_data["tracks"]):
        color = colors[track_index % len(colors)]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            x = bar_index * 32

            for note_data in bar_data["notes"]:
                y = max_note - note_data["note"]
                assert y >= 0 and y < height, f"Invalid y: {y}, note {note_data['note']}, min_note: {min_note}, max_note: {max_note}, height: {height}"
                for i in range(note_data["start"], note_data["end"]):
                    image.putpixel((x + i, y), color)

    # Resize the image. Use nearest neighbor for pixel art.
    factor = 4
    image = image.resize((width * factor, height * factor), Image.NEAREST)

    return image


def convert_notesequence_to_wave(note_sequence):

    if len(note_sequence.notes) == 0:
        return None

    # Prefer fluidsynth; fall back to note_seq's plain synthesizer if it is unavailable.
    try:
        synthesizer = note_seq.fluidsynth
        wave = synthesizer(note_sequence, sample_rate=44100)
        return wave
    except Exception as e:
        synthesizer = note_seq.synthesize
        wave = synthesizer(note_sequence)
        return wave


def convert_notesequence_to_midi(note_sequence, filename="output.mid"):

    if len(note_sequence.notes) == 0:
        return None

    # Write the sequence to a temporary MIDI file and return the file content.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        filename = temp_file.name
        note_seq.sequence_proto_to_midi_file(note_sequence, filename)
        with open(filename, "rb") as file:
            content = file.read()
    return content
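
Note: the token format interleaves tracks bar by bar. BAR_START opens a bar for the current track, INST= sets its program, NOTE_ON/NOTE_OFF events are separated by TIME_DELTA steps (8 per quarter note, so 32 per 4/4 bar), BAR_END advances to the next track, and NEXT starts the next round of bars. A hand-written sequence illustrating convert_tokens_to_songdata (token values chosen for illustration, not sampled from the model):

from source.utilities import convert_tokens_to_songdata

# One track, one 4/4 bar: a C major triad held for half a bar (16 steps).
tokens = (
    "GARLAND_START "
    "BAR_START INST=0 "
    "NOTE_ON=60 NOTE_ON=64 NOTE_ON=67 "
    "TIME_DELTA=16 "
    "NOTE_OFF=60 NOTE_OFF=64 NOTE_OFF=67 "
    "BAR_END "
    "GARLAND_END"
)
song_data = convert_tokens_to_songdata(tokens)
assert len(song_data["tracks"]) == 1
assert song_data["tracks"][0]["bars"][0]["notes"][0] == {
    "note": 60, "start": 0, "end": 16, "velocity": 80
}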