|
import asyncio
import logging
import os

from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import (
    Application,
    CommandHandler,
    CallbackQueryHandler,
    MessageHandler,
    filters,
    ContextTypes,
)

from main import song_cover_pipeline
from webui import download_online_model
|
|
|
|
|
# Absolute path of the directory that contains this script.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Folder where uploaded audio files are saved; created at import time if missing.
output_dir = os.path.join(BASE_DIR, 'song_output')
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Handle ``/start``: greet the user and show the main-menu keyboard.

    The buttons' callback data ('generate', 'download_model', 'help') is
    dispatched by ``handle_callback``.
    """
    # CommandHandler also fires for edited "/start" messages, where
    # update.message is None; effective_message covers both cases.
    message = update.effective_message
    if message is None:
        return

    keyboard = [
        [InlineKeyboardButton("Generate Song", callback_data='generate')],
        [InlineKeyboardButton("Download Model", callback_data='download_model')],
        [InlineKeyboardButton("Help", callback_data='help')],
    ]
    await message.reply_text(
        "Welcome to AICoverGen! Choose an option below:",
        reply_markup=InlineKeyboardMarkup(keyboard),
    )
|
|
|
|
|
def _keep_files_keyboard(keep_files):
    """Build the 'Keep Files' toggle + 'Submit' keyboard shown after pitch selection."""
    label = f"Keep Files: {'On' if keep_files else 'Off'}"
    return InlineKeyboardMarkup([
        [InlineKeyboardButton(label, callback_data='toggle_keep_files')],
        [InlineKeyboardButton("Submit", callback_data='submit_generate')],
    ])


async def _run_generation(context, chat_id, song_input, model_name, pitch, keep_files):
    """Run ``song_cover_pipeline`` off the event loop and send its output audio to ``chat_id``.

    Every outcome (success, exception, missing output) is reported to the chat.
    """
    is_webui = False
    try:
        # song_cover_pipeline is a synchronous, long-running call; running it in a
        # worker thread keeps the asyncio event loop free while it works.
        song_output = await asyncio.to_thread(
            song_cover_pipeline, song_input, model_name, pitch, keep_files, is_webui
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Song generation failed")
        await context.bot.send_message(
            chat_id=chat_id, text=f"An error occurred during song generation. Error: {e}"
        )
        return

    # Guard against a None/empty return as well as a path that was never written.
    if not song_output or not os.path.exists(song_output):
        await context.bot.send_message(chat_id=chat_id, text="An error occurred during song generation.")
        return

    try:
        # The with-block closes the handle before the file is deleted below.
        with open(song_output, 'rb') as audio:
            await context.bot.send_audio(chat_id=chat_id, audio=audio)
    finally:
        os.remove(song_output)
    await context.bot.send_message(chat_id=chat_id, text="Song generated successfully!")


async def _submit_generate(update, context):
    """Validate the collected generate-flow inputs, then run and deliver the cover."""
    query = update.callback_query
    user_data = context.user_data
    method = user_data.get('song_input_method')

    if method == 'link' and not user_data.get('song_link'):
        await query.edit_message_text("You haven't provided the YouTube link yet.")
        return
    if method == 'audio' and not user_data.get('audio_file'):
        await query.edit_message_text("You haven't uploaded the audio file yet.")
        return
    if not user_data.get('model_name'):
        await query.edit_message_text("You haven't provided the model name yet.")
        return
    if 'pitch' not in user_data:
        await query.edit_message_text("You haven't selected the pitch yet.")
        return

    await query.edit_message_text("Generating your song, please wait...")
    song_input = user_data.get('song_link') if method == 'link' else user_data.get('audio_file')
    try:
        await _run_generation(
            context,
            update.effective_chat.id,
            song_input,
            user_data.get('model_name'),
            user_data.get('pitch'),
            user_data.get('keep_files'),
        )
    finally:
        # The uploaded file was saved only for this run and its path is forgotten
        # once user_data is cleared, so delete it instead of leaving it on disk.
        if method == 'audio' and song_input and os.path.exists(song_input):
            os.remove(song_input)
        user_data.clear()


async def _handle_generate_callback(update, context, data):
    """Handle button presses that belong to the 'Generate Song' flow."""
    query = update.callback_query
    user_data = context.user_data

    if data == 'gen_input_link':
        user_data['song_input_method'] = 'link'
        user_data['generate_step'] = 'song_input'
        await query.edit_message_text("Please send the YouTube link for the song.")
        return

    if data == 'gen_input_audio':
        user_data['song_input_method'] = 'audio'
        user_data['generate_step'] = 'audio_input'
        await query.edit_message_text("Please upload the audio file for the song.")
        return

    if data.startswith('pitch_'):
        # Callback data is e.g. 'pitch_1' or 'pitch_-1'; the value follows the first '_'.
        user_data['pitch'] = int(data.split('_', 1)[1])
        await query.edit_message_text(
            "Pitch set. You can toggle Keep Files or submit.",
            reply_markup=_keep_files_keyboard(user_data.get('keep_files')),
        )
        return

    if data == 'toggle_keep_files':
        user_data['keep_files'] = not user_data.get('keep_files', False)
        await query.edit_message_text(
            "Keep Files toggled. You can toggle again or submit.",
            reply_markup=_keep_files_keyboard(user_data['keep_files']),
        )
        return

    if data == 'submit_generate':
        await _submit_generate(update, context)


async def _handle_download_callback(update, context, data):
    """Handle button presses that belong to the 'Download Model' flow."""
    query = update.callback_query
    user_data = context.user_data

    if data == 'dl_set_url':
        user_data['download_step'] = 'url'
        await query.edit_message_text("Please send the download URL for the model.")
        return

    if data == 'dl_set_model_name':
        user_data['download_step'] = 'model_name'
        await query.edit_message_text("Please enter the custom model name.")
        return

    if data == 'dl_submit':
        model_url = user_data.get('download_url')
        model_name = user_data.get('download_model_name')
        if not model_url or not model_name:
            await query.edit_message_text("Both URL and model name must be set before submitting.")
            return
        try:
            # download_online_model is synchronous; keep it off the event loop.
            await asyncio.to_thread(download_online_model, model_url, model_name)
            outcome = f"Model '{model_name}' downloaded successfully from {model_url}!"
        except Exception as e:
            outcome = f"Failed to download model. Error: {str(e)}"
        user_data.clear()
        await query.edit_message_text(outcome)


async def handle_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Dispatch every inline-keyboard button press.

    The main-menu buttons ('generate', 'download_model', 'help') always work,
    and the first two reset the per-user state kept in ``context.user_data``.
    Any other button is acted on only while the matching ``mode`` is active,
    so presses on stale keyboards from a finished session are ignored.
    """
    query = update.callback_query
    # Answer right away so the Telegram client stops showing a loading spinner.
    await query.answer()
    data = query.data

    if data == 'generate':
        context.user_data.clear()
        context.user_data['mode'] = 'generate'
        context.user_data['keep_files'] = False
        keyboard = [
            [InlineKeyboardButton("Input YouTube Link", callback_data='gen_input_link')],
            [InlineKeyboardButton("Upload Audio", callback_data='gen_input_audio')],
        ]
        await query.edit_message_text(
            "Choose input method for your song:", reply_markup=InlineKeyboardMarkup(keyboard)
        )
        return

    if data == 'download_model':
        context.user_data.clear()
        context.user_data['mode'] = 'download_model'
        keyboard = [
            [InlineKeyboardButton("Set URL", callback_data='dl_set_url')],
            [InlineKeyboardButton("Set Model Name", callback_data='dl_set_model_name')],
            [InlineKeyboardButton("Submit Download", callback_data='dl_submit')],
        ]
        await query.edit_message_text(
            "Download Model - set options using the buttons below:",
            reply_markup=InlineKeyboardMarkup(keyboard),
        )
        return

    if data == 'help':
        help_text = (
            "To generate a song:\n"
            "1. Click 'Generate Song'.\n"
            "2. Choose input method: YouTube Link or Upload Audio.\n"
            "3. Provide the song input (link or audio file).\n"
            "4. Enter the model name.\n"
            "5. Choose pitch: Female (1) or Male (-1).\n"
            "6. Toggle Keep Files if desired.\n"
            "7. Submit to generate the song.\n\n"
            "To download a model:\n"
            "1. Click 'Download Model'.\n"
            "2. Set the URL and Model Name using the buttons.\n"
            "3. Submit to download the model."
        )
        await query.edit_message_text(help_text)
        return

    mode = context.user_data.get('mode')
    if mode == 'generate':
        await _handle_generate_callback(update, context, data)
    elif mode == 'download_model':
        await _handle_download_callback(update, context, data)
|
|
|
|
|
def _is_valid_model_name(name):
    """Return True if ``name`` is non-empty and contains no path components.

    NOTE(review): model names are presumably used as directory names by the
    pipeline / downloader (not visible here). Rejecting separators and '.'/'..'
    keeps user input from escaping that directory; confirm against those helpers.
    """
    return bool(name) and name not in ('.', '..') and '/' not in name and '\\' not in name


async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Consume free-form user input for whichever step the active flow awaits.

    The expected input is tracked in ``context.user_data``: ``generate_step``
    ('song_input', 'audio_input', 'model_name') in generate mode, and
    ``download_step`` ('url', 'model_name') in download mode.
    """
    # The handler receives every update kind (filters.ALL); update.message is
    # None for e.g. edited messages, so use effective_message and bail if absent.
    message = update.effective_message
    if message is None:
        return
    # Non-text messages (stickers, photos, audio...) have text=None.
    text = (message.text or '').strip()
    user_data = context.user_data
    mode = user_data.get('mode')

    if mode == 'generate':
        step = user_data.get('generate_step')

        if step == 'song_input':
            # Uploaded audio reaches song_cover_pipeline as a local file path through
            # the same argument, so require a URL here; otherwise arbitrary text
            # (including server paths) would be handed to the pipeline.
            if not text.startswith(('http://', 'https://')):
                await message.reply_text("Please send a valid YouTube link.")
                return
            user_data['song_link'] = text
            user_data['generate_step'] = 'model_name'
            await message.reply_text("YouTube link set. Please enter the model name.")
            return

        if step == 'audio_input':
            if message.audio:
                audio_file = await message.audio.get_file()
                file_path = os.path.join(output_dir, f"{message.audio.file_id}.mp3")
                await audio_file.download_to_drive(file_path)
                user_data['audio_file'] = file_path
                user_data['generate_step'] = 'model_name'
                await message.reply_text("Audio file received. Please enter the model name.")
            else:
                await message.reply_text("Please upload a valid audio file.")
            return

        if step == 'model_name':
            # Stay on this step until a usable name arrives; once generate_step is
            # cleared there is no way to re-enter the name.
            if not _is_valid_model_name(text):
                await message.reply_text("Please enter a valid model name (text without slashes).")
                return
            user_data['model_name'] = text
            keyboard = [
                [InlineKeyboardButton("Female (1)", callback_data='pitch_1')],
                [InlineKeyboardButton("Male (-1)", callback_data='pitch_-1')],
            ]
            await message.reply_text("Please choose the pitch:", reply_markup=InlineKeyboardMarkup(keyboard))
            user_data['generate_step'] = None
            return

        await message.reply_text("Please use the provided buttons to navigate the process.")

    elif mode == 'download_model':
        download_step = user_data.get('download_step')

        if download_step == 'url':
            if not text.startswith("http"):
                await message.reply_text("Please provide a valid URL.")
                return
            user_data['download_url'] = text
            user_data['download_step'] = None
            await message.reply_text("Download URL set.")
            return

        if download_step == 'model_name':
            if not _is_valid_model_name(text):
                await message.reply_text("Please enter a valid model name (text without slashes).")
                return
            user_data['download_model_name'] = text
            user_data['download_step'] = None
            await message.reply_text("Model name set.")
            return

        await message.reply_text("Please use the provided buttons for setting options.")

    else:
        await message.reply_text("Please choose an option first by clicking one of the main buttons.")
|
|
|
|
|
def main():
    """Create the Telegram application, register the handlers and start polling.

    Raises:
        ValueError: if the TELEGRAM_BOT_TOKEN environment variable is unset or empty.
    """
    token = os.getenv("TELEGRAM_BOT_TOKEN")
    if not token:
        raise ValueError("Bot token not found. Set the TELEGRAM_BOT_TOKEN environment variable.")

    app = Application.builder().token(token).build()

    # Handlers in the same group are tried in registration order and only the
    # first match runs, so /start must precede the catch-all MessageHandler.
    handlers = (
        CommandHandler("start", start),
        CallbackQueryHandler(handle_callback),
        MessageHandler(filters.ALL, handle_message),
    )
    for handler in handlers:
        app.add_handler(handler)

    app.run_polling()


if __name__ == '__main__':
    main()