"""Telegram bot front-end for AICoverGen.

Two flows are driven by inline keyboards:

* "Generate Song": collect a song input (YouTube link or uploaded audio), a
  model name and a pitch, run ``song_cover_pipeline`` and send the result.
* "Download Model": collect a URL and a model name, then call
  ``download_online_model``.

Per-user state lives in ``context.user_data``: ``mode`` selects the active
flow, and ``generate_step`` / ``download_step`` select which free-text message
is expected next.
"""
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor

from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import (
    Application,
    CommandHandler,
    CallbackQueryHandler,
    MessageHandler,
    filters,
    ContextTypes,
)

from main import song_cover_pipeline  # Your song generation pipeline
from webui import download_online_model  # Your model download function

logger = logging.getLogger(__name__)

# Define paths
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(BASE_DIR, 'song_output')
os.makedirs(output_dir, exist_ok=True)

# song_cover_pipeline is synchronous (and presumably GPU-heavy), so generations
# run one at a time on a dedicated worker thread. This keeps the asyncio event
# loop -- and every other user's interaction -- responsive in the meantime.
_GENERATION_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="song-gen")

# Schemes accepted for both the song link and the model download URL.
_URL_PREFIXES = ("http://", "https://")


def _looks_like_url(text: str) -> bool:
    """Return True if *text* starts with an http:// or https:// scheme."""
    return text.startswith(_URL_PREFIXES)


def _is_safe_name(name: str) -> bool:
    """Return True if *name* is a non-empty single path component.

    NOTE(review): model names are presumably used as directory/file names by
    the pipeline and by download_online_model -- confirm. Rejecting path
    separators and '.'/'..' keeps user input from escaping the models folder.
    """
    return bool(name) and name not in ('.', '..') and '/' not in name and '\\' not in name


def _audio_extension(file_name) -> str:
    """Return the extension of *file_name* (e.g. '.m4a'), or '.mp3' if it has no usable one."""
    ext = os.path.splitext(file_name or '')[1]
    # Only trust short alphanumeric extensions; otherwise fall back to '.mp3'.
    return ext if ext[1:].isalnum() else '.mp3'


def _keep_files_markup(keep_files) -> InlineKeyboardMarkup:
    """Build the Keep Files toggle + Submit keyboard shown after the pitch is chosen."""
    keyboard = [
        [InlineKeyboardButton(f"Keep Files: {'On' if keep_files else 'Off'}", callback_data='toggle_keep_files')],
        [InlineKeyboardButton("Submit", callback_data='submit_generate')],
    ]
    return InlineKeyboardMarkup(keyboard)


def _download_menu_markup() -> InlineKeyboardMarkup:
    """Build the Download Model options keyboard (set URL, set name, submit)."""
    keyboard = [
        [InlineKeyboardButton("Set URL", callback_data='dl_set_url')],
        [InlineKeyboardButton("Set Model Name", callback_data='dl_set_model_name')],
        [InlineKeyboardButton("Submit Download", callback_data='dl_submit')],
    ]
    return InlineKeyboardMarkup(keyboard)


# Start command shows the main menu
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Reply to /start with the main menu keyboard."""
    keyboard = [
        [InlineKeyboardButton("Generate Song", callback_data='generate')],
        [InlineKeyboardButton("Download Model", callback_data='download_model')],
        [InlineKeyboardButton("Help", callback_data='help')],
    ]
    reply_markup = InlineKeyboardMarkup(keyboard)
    # effective_message also covers an edited "/start", where update.message is None.
    await update.effective_message.reply_text(
        "Welcome to AICoverGen! Choose an option below:", reply_markup=reply_markup
    )


# Callback query handler for all button presses
async def handle_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Dispatch an inline-button press.

    Main-menu buttons start a fresh session. Any other button is routed by
    ``context.user_data['mode']``. A button that belongs to no live session
    (for example, one pressed after a generation finished or after a restart)
    gets an explicit "expired" reply instead of being silently ignored.
    """
    query = update.callback_query
    await query.answer()
    data = query.data

    # Main menu selections
    if data == 'generate':
        context.user_data.clear()
        context.user_data['mode'] = 'generate'
        context.user_data['keep_files'] = False  # default value
        # Ask which input method to use
        keyboard = [
            [InlineKeyboardButton("Input YouTube Link", callback_data='gen_input_link')],
            [InlineKeyboardButton("Upload Audio", callback_data='gen_input_audio')],
        ]
        await query.edit_message_text(
            "Choose input method for your song:", reply_markup=InlineKeyboardMarkup(keyboard)
        )
        return

    if data == 'download_model':
        context.user_data.clear()
        context.user_data['mode'] = 'download_model'
        await query.edit_message_text(
            "Download Model - set options using the buttons below:", reply_markup=_download_menu_markup()
        )
        return

    if data == 'help':
        help_text = (
            "To generate a song:\n"
            "1. Click 'Generate Song'.\n"
            "2. Choose input method: YouTube Link or Upload Audio.\n"
            "3. Provide the song input (link or audio file).\n"
            "4. Enter the model name.\n"
            "5. Choose pitch: Female (1) or Male (-1).\n"
            "6. Toggle Keep Files if desired.\n"
            "7. Submit to generate the song.\n\n"
            "To download a model:\n"
            "1. Click 'Download Model'.\n"
            "2. Set the URL and Model Name using the buttons.\n"
            "3. Submit to download the model."
        )
        await query.edit_message_text(help_text)
        return

    mode = context.user_data.get('mode')
    if mode == 'generate' and await _handle_generate_callback(query, update, context, data):
        return
    if mode == 'download_model' and await _handle_download_callback(query, context, data):
        return

    await query.edit_message_text("This session has expired. Send /start to begin again.")


async def _handle_generate_callback(query, update, context, data) -> bool:
    """Handle a generate-mode button; return True if *data* was recognised."""
    user_data = context.user_data

    if data == 'gen_input_link':
        user_data['song_input_method'] = 'link'
        user_data['generate_step'] = 'song_input'
        await query.edit_message_text("Please send the YouTube link for the song.")
        return True

    if data == 'gen_input_audio':
        user_data['song_input_method'] = 'audio'
        user_data['generate_step'] = 'audio_input'
        await query.edit_message_text("Please upload the audio file for the song.")
        return True

    if data.startswith('pitch_'):
        # Callback data is 'pitch_1' or 'pitch_-1'.
        user_data['pitch'] = int(data.split('_', 1)[1])
        await query.edit_message_text(
            "Pitch set. You can toggle Keep Files or submit.",
            reply_markup=_keep_files_markup(user_data.get('keep_files')),
        )
        return True

    if data == 'toggle_keep_files':
        user_data['keep_files'] = not user_data.get('keep_files', False)
        await query.edit_message_text(
            "Keep Files toggled. You can toggle again or submit.",
            reply_markup=_keep_files_markup(user_data['keep_files']),
        )
        return True

    if data == 'submit_generate':
        await _submit_generate(query, update, context)
        return True

    return False


async def _submit_generate(query, update, context):
    """Validate the collected inputs, run the pipeline off-loop and send the result."""
    user_data = context.user_data
    method = user_data.get('song_input_method')

    # Make sure all required parameters are provided
    if method == 'link' and not user_data.get('song_link'):
        await query.edit_message_text("You haven't provided the YouTube link yet.")
        return
    if method == 'audio' and not user_data.get('audio_file'):
        await query.edit_message_text("You haven't uploaded the audio file yet.")
        return
    if not user_data.get('model_name'):
        await query.edit_message_text("You haven't provided the model name yet.")
        return
    if 'pitch' not in user_data:
        await query.edit_message_text("You haven't selected the pitch yet.")
        return

    song_input = user_data.get('song_link') if method == 'link' else user_data.get('audio_file')
    model_name = user_data.get('model_name')
    pitch = user_data.get('pitch')
    keep_files = user_data.get('keep_files')
    is_webui = False
    chat_id = update.effective_chat.id
    # End the session before the long-running work, so that a second Submit
    # press is answered as "expired" instead of starting a duplicate job.
    user_data.clear()

    await query.edit_message_text("Generating your song, please wait...")

    song_output = None
    try:
        loop = asyncio.get_running_loop()
        song_output = await loop.run_in_executor(
            _GENERATION_EXECUTOR, song_cover_pipeline, song_input, model_name, pitch, keep_files, is_webui
        )
    except Exception:
        logger.exception("Song generation failed")

    if not song_output or not os.path.exists(song_output):
        await context.bot.send_message(chat_id=chat_id, text="An error occurred during song generation.")
        return

    try:
        # The context manager closes the file handle even if the upload fails.
        with open(song_output, 'rb') as audio_fh:
            await context.bot.send_audio(chat_id=chat_id, audio=audio_fh)
    except Exception:
        logger.exception("Failed to send generated song %s", song_output)
        await context.bot.send_message(chat_id=chat_id, text="An error occurred while sending the generated song.")
        return
    finally:
        os.remove(song_output)

    await context.bot.send_message(chat_id=chat_id, text="Song generated successfully!")


async def _handle_download_callback(query, context, data) -> bool:
    """Handle a download-mode button; return True if *data* was recognised."""
    user_data = context.user_data

    if data == 'dl_set_url':
        user_data['download_step'] = 'url'
        await query.edit_message_text("Please send the download URL for the model.")
        return True

    if data == 'dl_set_model_name':
        user_data['download_step'] = 'model_name'
        await query.edit_message_text("Please enter the custom model name.")
        return True

    if data == 'dl_submit':
        model_url = user_data.get('download_url')
        model_name = user_data.get('download_model_name')
        if not model_url or not model_name:
            # Keep the menu attached so the user can still set the missing option.
            await query.edit_message_text(
                "Both URL and model name must be set before submitting.", reply_markup=_download_menu_markup()
            )
            return True

        user_data.clear()
        await query.edit_message_text(f"Downloading model '{model_name}', please wait...")
        try:
            # download_online_model is synchronous; keep it off the event loop.
            await asyncio.get_running_loop().run_in_executor(None, download_online_model, model_url, model_name)
        except Exception as e:
            logger.exception("Model download failed for %s", model_url)
            await query.edit_message_text(f"Failed to download model. Error: {str(e)}")
        else:
            await query.edit_message_text(f"Model '{model_name}' downloaded successfully from {model_url}!")
        return True

    return False


# Message handler for receiving text or file inputs
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Route a free-text or file message to the active flow's current step."""
    message = update.message
    if message is None:
        # Edited messages and channel posts have no update.message; ignore them.
        return
    # Non-text messages (stickers, photos, ...) have text=None.
    text = (message.text or '').strip()

    mode = context.user_data.get('mode')
    if mode == 'generate':
        await _handle_generate_message(message, text, context)
    elif mode == 'download_model':
        await _handle_download_message(message, text, context)
    else:
        await message.reply_text("Please choose an option first by clicking one of the main buttons.")


async def _handle_generate_message(message, text, context):
    """Consume the message expected by the current generate step."""
    user_data = context.user_data
    step = user_data.get('generate_step')

    if step == 'song_input':
        # Only accept web links so arbitrary server-side paths are never forwarded.
        if not _looks_like_url(text):
            await message.reply_text("Please send a valid YouTube link.")
            return
        user_data['song_link'] = text
        user_data['generate_step'] = 'model_name'
        await message.reply_text("YouTube link set. Please enter the model name.")
        return

    if step == 'audio_input':
        audio = message.audio
        document = message.document
        # Audio sent "as a file" arrives as a document; accept it when Telegram
        # reports an audio MIME type.
        if audio is None and document is not None and (document.mime_type or '').startswith('audio/'):
            audio = document
        if audio is None:
            await message.reply_text("Please upload a valid audio file.")
            return
        file_path = os.path.join(output_dir, f"{audio.file_id}{_audio_extension(audio.file_name)}")
        try:
            tg_file = await audio.get_file()
            await tg_file.download_to_drive(file_path)
        except Exception:
            logger.exception("Failed to download uploaded audio")
            await message.reply_text("Could not download the audio file. Please try again.")
            return
        user_data['audio_file'] = file_path
        user_data['generate_step'] = 'model_name'
        await message.reply_text("Audio file received. Please enter the model name.")
        return

    if step == 'model_name':
        if not _is_safe_name(text):
            await message.reply_text("Please enter a valid model name.")
            return
        user_data['model_name'] = text
        # Now ask for pitch via inline keyboard
        keyboard = [
            [InlineKeyboardButton("Female (1)", callback_data='pitch_1')],
            [InlineKeyboardButton("Male (-1)", callback_data='pitch_-1')],
        ]
        await message.reply_text("Please choose the pitch:", reply_markup=InlineKeyboardMarkup(keyboard))
        user_data['generate_step'] = None  # move on to waiting for pitch
        return

    await message.reply_text("Please use the provided buttons to navigate the process.")


async def _handle_download_message(message, text, context):
    """Consume the message expected by the current download step."""
    user_data = context.user_data
    download_step = user_data.get('download_step')

    if download_step == 'url':
        if not _looks_like_url(text):
            await message.reply_text("Please provide a valid URL.")
            return
        user_data['download_url'] = text
        user_data['download_step'] = None
        # Re-show the menu: the original menu message lost its buttons when edited.
        await message.reply_text("Download URL set.", reply_markup=_download_menu_markup())
        return

    if download_step == 'model_name':
        if not _is_safe_name(text):
            await message.reply_text("Please enter a valid model name.")
            return
        user_data['download_model_name'] = text
        user_data['download_step'] = None
        await message.reply_text("Model name set.", reply_markup=_download_menu_markup())
        return

    await message.reply_text("Please use the provided buttons for setting options.")


# Main function to run the bot
def main():
    """Configure logging, build the Application and start long polling.

    Raises:
        ValueError: if the TELEGRAM_BOT_TOKEN environment variable is not set.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    bot_token = os.environ.get("TELEGRAM_BOT_TOKEN")
    if not bot_token:
        raise ValueError("Bot token not found. Set the TELEGRAM_BOT_TOKEN environment variable.")

    # Concurrent updates let other users keep using the menus while one
    # generation is running. Generations themselves stay serialized by
    # _GENERATION_EXECUTOR.
    application = Application.builder().token(bot_token).concurrent_updates(True).build()
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CallbackQueryHandler(handle_callback))
    # UpdateType.MESSAGE excludes edited messages and channel posts, for which
    # update.message would be None.
    application.add_handler(MessageHandler(filters.UpdateType.MESSAGE, handle_message))
    application.run_polling()


if __name__ == '__main__':
    main()