# NOTE: extraction residue from the file viewer this script was copied from,
# kept as comments so it is not mistaken for code.
# File size: 12,050 Bytes
# Revision marker shown by the viewer: 9260206
# (The viewer's line-number gutter, 1-239, was dropped.)
import asyncio
import os

from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
from telegram.ext import (
    Application,
    CallbackQueryHandler,
    CommandHandler,
    ContextTypes,
    MessageHandler,
    filters,
)

from main import song_cover_pipeline  # Your song generation pipeline
from webui import download_online_model  # Your model download function
# Filesystem locations: the folder containing this script, and the
# 'song_output' folder inside it, which is created at import time if missing.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
output_dir = os.path.join(BASE_DIR, 'song_output')
if not os.path.isdir(output_dir):
    os.makedirs(output_dir, exist_ok=True)
# Start command shows the main menu
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Reply to /start with the main menu.

    The menu is an inline keyboard with three buttons whose callback data
    ('generate', 'download_model', 'help') is handled by the callback query
    handler.

    Args:
        update: Incoming Telegram update that triggered the command.
        context: Handler context (unused here).
    """
    # Use effective_message rather than update.message: the command handler
    # can also be triggered by an edited message, in which case
    # update.message is None and reply_text would raise AttributeError.
    message = update.effective_message
    if message is None:
        return

    keyboard = [
        [InlineKeyboardButton("Generate Song", callback_data='generate')],
        [InlineKeyboardButton("Download Model", callback_data='download_model')],
        [InlineKeyboardButton("Help", callback_data='help')],
    ]
    await message.reply_text(
        "Welcome to AICoverGen! Choose an option below:",
        reply_markup=InlineKeyboardMarkup(keyboard),
    )
# Helpers used by the callback query handler
def _download_menu_markup():
    """Return the inline keyboard offering the three download-model actions."""
    return InlineKeyboardMarkup([
        [InlineKeyboardButton("Set URL", callback_data='dl_set_url')],
        [InlineKeyboardButton("Set Model Name", callback_data='dl_set_model_name')],
        [InlineKeyboardButton("Submit Download", callback_data='dl_submit')],
    ])


def _keep_files_markup(keep_files):
    """Return the keyboard with the Keep Files toggle (showing its state) and Submit."""
    label = f"Keep Files: {'On' if keep_files else 'Off'}"
    return InlineKeyboardMarkup([
        [InlineKeyboardButton(label, callback_data='toggle_keep_files')],
        [InlineKeyboardButton("Submit", callback_data='submit_generate')],
    ])


async def _run_blocking(func, *args):
    """Run a synchronous callable in the default executor and return its result.

    Keeps long-running synchronous work from freezing the asyncio event loop
    that the bot runs on.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


async def _submit_generate(update, context, query):
    """Validate the collected inputs, run the cover pipeline and send the result."""
    user_data = context.user_data

    # Make sure all required parameters are provided
    if user_data.get('song_input_method') == 'link':
        song_input = user_data.get('song_link')
        missing_input_text = "You haven't provided the YouTube link yet."
    else:
        song_input = user_data.get('audio_file')
        missing_input_text = "You haven't uploaded the audio file yet."
    if not song_input:
        # Also covers the case where no input method was chosen at all, so
        # the pipeline is never called with None as its input.
        await query.edit_message_text(missing_input_text)
        return
    model_name = user_data.get('model_name')
    if not model_name:
        await query.edit_message_text("You haven't provided the model name yet.")
        return
    if 'pitch' not in user_data:
        await query.edit_message_text("You haven't selected the pitch yet.")
        return

    await query.edit_message_text("Generating your song, please wait...")
    pitch = user_data['pitch']
    keep_files = user_data.get('keep_files', False)
    is_webui = False
    chat_id = update.effective_chat.id

    # Call the song generation pipeline (adjust as needed for your implementation)
    failure_text = "An error occurred during song generation."
    try:
        song_output = await _run_blocking(
            song_cover_pipeline, song_input, model_name, pitch, keep_files, is_webui
        )
    except Exception as exc:  # handler boundary: report instead of failing silently
        song_output = None
        failure_text = f"An error occurred during song generation. Error: {exc}"

    try:
        if song_output and os.path.exists(song_output):
            # Close the file before deleting it; an open handle leaks and can
            # make the removal fail on some platforms.
            with open(song_output, 'rb') as audio_fh:
                await context.bot.send_audio(chat_id=chat_id, audio=audio_fh)
            os.remove(song_output)
            await context.bot.send_message(chat_id=chat_id, text="Song generated successfully!")
        else:
            await context.bot.send_message(chat_id=chat_id, text=failure_text)
    finally:
        # Always reset the conversation, even if delivery failed.
        user_data.clear()


async def _submit_download(update, context, query):
    """Validate the download settings, fetch the model and report the outcome."""
    user_data = context.user_data
    model_url = user_data.get('download_url')
    model_name = user_data.get('download_model_name')
    if not model_url or not model_name:
        await query.edit_message_text("Both URL and model name must be set before submitting.")
        # Editing the message above drops its inline keyboard, so send the
        # menu again; otherwise the user has no button left to continue with.
        await context.bot.send_message(
            chat_id=update.effective_chat.id,
            text="Download Model - set options using the buttons below:",
            reply_markup=_download_menu_markup(),
        )
        return

    try:
        await _run_blocking(download_online_model, model_url, model_name)
    except Exception as e:  # handler boundary: surface the failure to the user
        result_text = f"Failed to download model. Error: {str(e)}"
    else:
        result_text = f"Model '{model_name}' downloaded successfully from {model_url}!"
    await query.edit_message_text(result_text)
    user_data.clear()


# Callback query handler for all button presses
async def handle_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Dispatch an inline-keyboard button press.

    Main-menu buttons ('generate', 'download_model', 'help') reset the
    per-user state and open the matching flow. All other callback data is
    interpreted according to the mode stored in ``context.user_data['mode']``;
    callback data that does not match the current mode is ignored.

    Args:
        update: Incoming update carrying the callback query.
        context: Handler context; ``context.user_data`` holds the per-user
            conversation state (mode, current step and collected inputs).
    """
    query = update.callback_query
    await query.answer()
    data = query.data
    user_data = context.user_data

    # Main menu selections
    if data == 'generate':
        user_data.clear()
        user_data['mode'] = 'generate'
        user_data['keep_files'] = False  # default value
        # Ask which input method to use
        keyboard = [
            [InlineKeyboardButton("Input YouTube Link", callback_data='gen_input_link')],
            [InlineKeyboardButton("Upload Audio", callback_data='gen_input_audio')],
        ]
        await query.edit_message_text(
            "Choose input method for your song:",
            reply_markup=InlineKeyboardMarkup(keyboard),
        )
        return
    if data == 'download_model':
        user_data.clear()
        user_data['mode'] = 'download_model'
        await query.edit_message_text(
            "Download Model - set options using the buttons below:",
            reply_markup=_download_menu_markup(),
        )
        return
    if data == 'help':
        help_text = (
            "To generate a song:\n"
            "1. Click 'Generate Song'.\n"
            "2. Choose input method: YouTube Link or Upload Audio.\n"
            "3. Provide the song input (link or audio file).\n"
            "4. Enter the model name.\n"
            "5. Choose pitch: Female (1) or Male (-1).\n"
            "6. Toggle Keep Files if desired.\n"
            "7. Submit to generate the song.\n\n"
            "To download a model:\n"
            "1. Click 'Download Model'.\n"
            "2. Set the URL and Model Name using the buttons.\n"
            "3. Submit to download the model."
        )
        await query.edit_message_text(help_text)
        return

    mode = user_data.get('mode')

    # --- GENERATE MODE CALLBACKS ---
    if mode == 'generate':
        if data == 'gen_input_link':
            user_data['song_input_method'] = 'link'
            user_data['generate_step'] = 'song_input'
            await query.edit_message_text("Please send the YouTube link for the song.")
            return
        if data == 'gen_input_audio':
            user_data['song_input_method'] = 'audio'
            user_data['generate_step'] = 'audio_input'
            await query.edit_message_text("Please upload the audio file for the song.")
            return
        if data.startswith('pitch_'):
            # Expecting callback data like 'pitch_1' or 'pitch_-1'
            user_data['pitch'] = int(data.split('_', 1)[1])
            # After setting pitch, offer toggle for keep_files and a submit button
            await query.edit_message_text(
                "Pitch set. You can toggle Keep Files or submit.",
                reply_markup=_keep_files_markup(user_data.get('keep_files')),
            )
            return
        if data == 'toggle_keep_files':
            # Toggle the boolean value
            user_data['keep_files'] = not user_data.get('keep_files', False)
            await query.edit_message_text(
                "Keep Files toggled. You can toggle again or submit.",
                reply_markup=_keep_files_markup(user_data['keep_files']),
            )
            return
        if data == 'submit_generate':
            await _submit_generate(update, context, query)
            return

    # --- DOWNLOAD MODE CALLBACKS ---
    if mode == 'download_model':
        if data == 'dl_set_url':
            user_data['download_step'] = 'url'
            await query.edit_message_text("Please send the download URL for the model.")
            return
        if data == 'dl_set_model_name':
            user_data['download_step'] = 'model_name'
            await query.edit_message_text("Please enter the custom model name.")
            return
        if data == 'dl_submit':
            await _submit_download(update, context, query)
            return
# Message handler for receiving text or file inputs
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Consume a user message as the input the current flow is waiting for.

    What the message means depends on ``context.user_data``:

    * mode 'generate': 'generate_step' selects between the YouTube link,
      the uploaded audio file and the model name.
    * mode 'download_model': 'download_step' selects between the download
      URL and the custom model name.

    Anything else is answered with a hint to use the buttons.

    Args:
        update: Incoming update; it may be any kind of message because the
            handler is registered for all messages.
        context: Handler context holding the per-user conversation state.
    """
    # The handler receives every update type, so update.message may be None
    # (e.g. edited messages) and the message may carry no text (photos,
    # stickers, ...). Normalise both up front instead of crashing on .strip().
    message = update.effective_message
    if message is None:
        return
    text = (message.text or "").strip()

    mode = context.user_data.get('mode')
    if mode == 'generate':
        step = context.user_data.get('generate_step')
        if step == 'song_input':
            # Expecting a YouTube link. Reject anything that is not a URL so
            # arbitrary text never reaches the pipeline.
            # NOTE(review): the pipeline may treat non-URL input as a local
            # file path - confirm against main.song_cover_pipeline.
            if not text.startswith(("http://", "https://")):
                await message.reply_text("Please send a valid YouTube link.")
                return
            context.user_data['song_link'] = text
            context.user_data['generate_step'] = 'model_name'
            await message.reply_text("YouTube link set. Please enter the model name.")
            return
        if step == 'audio_input':
            # Check for an audio file upload
            audio = message.audio
            if audio:
                telegram_file = await audio.get_file()
                file_path = os.path.join(output_dir, f"{audio.file_id}.mp3")
                await telegram_file.download_to_drive(file_path)
                context.user_data['audio_file'] = file_path
                context.user_data['generate_step'] = 'model_name'
                await message.reply_text("Audio file received. Please enter the model name.")
            else:
                await message.reply_text("Please upload a valid audio file.")
            return
        if step == 'model_name':
            if not text:
                await message.reply_text("Please enter the model name.")
                return
            context.user_data['model_name'] = text
            # Now ask for pitch via inline keyboard
            keyboard = [
                [InlineKeyboardButton("Female (1)", callback_data='pitch_1')],
                [InlineKeyboardButton("Male (-1)", callback_data='pitch_-1')],
            ]
            await message.reply_text(
                "Please choose the pitch:",
                reply_markup=InlineKeyboardMarkup(keyboard),
            )
            context.user_data['generate_step'] = None  # move on to waiting for pitch
            return
        await message.reply_text("Please use the provided buttons to navigate the process.")
    elif mode == 'download_model':
        download_step = context.user_data.get('download_step')
        # The menu message loses its inline keyboard when the callback handler
        # edits it into a prompt, so every confirmation below carries the menu
        # again. Without it the user cannot set the other option or submit.
        options_markup = InlineKeyboardMarkup([
            [InlineKeyboardButton("Set URL", callback_data='dl_set_url')],
            [InlineKeyboardButton("Set Model Name", callback_data='dl_set_model_name')],
            [InlineKeyboardButton("Submit Download", callback_data='dl_submit')],
        ])
        if download_step == 'url':
            # A bare "http" prefix would also accept strings such as "httpfoo".
            if not text.startswith(("http://", "https://")):
                await message.reply_text("Please provide a valid URL.")
                return
            context.user_data['download_url'] = text
            context.user_data['download_step'] = None
            await message.reply_text("Download URL set.", reply_markup=options_markup)
            return
        if download_step == 'model_name':
            if not text:
                await message.reply_text("Please enter the custom model name.")
                return
            # NOTE(review): the name is later handed to download_online_model
            # unvalidated; if it ends up in a filesystem path, consider
            # rejecting path separators here.
            context.user_data['download_model_name'] = text
            context.user_data['download_step'] = None
            await message.reply_text("Model name set.", reply_markup=options_markup)
            return
        await message.reply_text("Please use the provided buttons for setting options.")
    else:
        await message.reply_text("Please choose an option first by clicking one of the main buttons.")
# Entry point: build the bot application, wire up the handlers and poll
def main():
    """Start the bot using the token from the TELEGRAM_BOT_TOKEN variable.

    Registers the /start command handler, the callback query handler and the
    catch-all message handler, in that order, then polls for updates.

    Raises:
        ValueError: If TELEGRAM_BOT_TOKEN is unset or empty.
    """
    token = os.environ.get("TELEGRAM_BOT_TOKEN")
    if not token:
        raise ValueError("Bot token not found. Set the TELEGRAM_BOT_TOKEN environment variable.")

    app = Application.builder().token(token).build()
    handlers = (
        CommandHandler("start", start),
        CallbackQueryHandler(handle_callback),
        MessageHandler(filters.ALL, handle_message),
    )
    for handler in handlers:
        app.add_handler(handler)
    app.run_polling()
# Run the bot only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()