NeoPy commited on
Commit
9260206
·
verified ·
1 Parent(s): 6471467

Create cover.py

Browse files
Files changed (1) hide show
  1. cover.py +239 -0
cover.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
3
+ from telegram.ext import (
4
+ Application,
5
+ CommandHandler,
6
+ CallbackQueryHandler,
7
+ MessageHandler,
8
+ filters,
9
+ ContextTypes,
10
+ )
11
+ from main import song_cover_pipeline # Your song generation pipeline
12
+ from webui import download_online_model # Your model download function
13
+
14
+ # Define paths
15
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
+ output_dir = os.path.join(BASE_DIR, 'song_output')
17
+ os.makedirs(output_dir, exist_ok=True)
18
+
19
+ # Start command shows the main menu
20
+ async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
21
+ keyboard = [
22
+ [InlineKeyboardButton("Generate Song", callback_data='generate')],
23
+ [InlineKeyboardButton("Download Model", callback_data='download_model')],
24
+ [InlineKeyboardButton("Help", callback_data='help')]
25
+ ]
26
+ reply_markup = InlineKeyboardMarkup(keyboard)
27
+ await update.message.reply_text("Welcome to AICoverGen! Choose an option below:", reply_markup=reply_markup)
28
+
29
+ # Callback query handler for all button presses
30
+ async def handle_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
31
+ query = update.callback_query
32
+ await query.answer()
33
+ data = query.data
34
+
35
+ # Main menu selections
36
+ if data == 'generate':
37
+ context.user_data.clear()
38
+ context.user_data['mode'] = 'generate'
39
+ context.user_data['keep_files'] = False # default value
40
+ # Ask which input method to use
41
+ keyboard = [
42
+ [InlineKeyboardButton("Input YouTube Link", callback_data='gen_input_link')],
43
+ [InlineKeyboardButton("Upload Audio", callback_data='gen_input_audio')]
44
+ ]
45
+ reply_markup = InlineKeyboardMarkup(keyboard)
46
+ await query.edit_message_text("Choose input method for your song:", reply_markup=reply_markup)
47
+ return
48
+
49
+ if data == 'download_model':
50
+ context.user_data.clear()
51
+ context.user_data['mode'] = 'download_model'
52
+ keyboard = [
53
+ [InlineKeyboardButton("Set URL", callback_data='dl_set_url')],
54
+ [InlineKeyboardButton("Set Model Name", callback_data='dl_set_model_name')],
55
+ [InlineKeyboardButton("Submit Download", callback_data='dl_submit')]
56
+ ]
57
+ reply_markup = InlineKeyboardMarkup(keyboard)
58
+ await query.edit_message_text("Download Model - set options using the buttons below:", reply_markup=reply_markup)
59
+ return
60
+
61
+ if data == 'help':
62
+ help_text = (
63
+ "To generate a song:\n"
64
+ "1. Click 'Generate Song'.\n"
65
+ "2. Choose input method: YouTube Link or Upload Audio.\n"
66
+ "3. Provide the song input (link or audio file).\n"
67
+ "4. Enter the model name.\n"
68
+ "5. Choose pitch: Female (1) or Male (-1).\n"
69
+ "6. Toggle Keep Files if desired.\n"
70
+ "7. Submit to generate the song.\n\n"
71
+ "To download a model:\n"
72
+ "1. Click 'Download Model'.\n"
73
+ "2. Set the URL and Model Name using the buttons.\n"
74
+ "3. Submit to download the model."
75
+ )
76
+ await query.edit_message_text(help_text)
77
+ return
78
+
79
+ # --- GENERATE MODE CALLBACKS ---
80
+ if context.user_data.get('mode') == 'generate':
81
+ if data == 'gen_input_link':
82
+ context.user_data['song_input_method'] = 'link'
83
+ context.user_data['generate_step'] = 'song_input'
84
+ await query.edit_message_text("Please send the YouTube link for the song.")
85
+ return
86
+ if data == 'gen_input_audio':
87
+ context.user_data['song_input_method'] = 'audio'
88
+ context.user_data['generate_step'] = 'audio_input'
89
+ await query.edit_message_text("Please upload the audio file for the song.")
90
+ return
91
+ if data.startswith('pitch_'):
92
+ # Expecting callback data like 'pitch_1' or 'pitch_-1'
93
+ pitch_val = int(data.split('_')[1])
94
+ context.user_data['pitch'] = pitch_val
95
+ # After setting pitch, offer toggle for keep_files and a submit button
96
+ keyboard = [
97
+ [InlineKeyboardButton(f"Keep Files: {'On' if context.user_data.get('keep_files') else 'Off'}", callback_data='toggle_keep_files')],
98
+ [InlineKeyboardButton("Submit", callback_data='submit_generate')]
99
+ ]
100
+ reply_markup = InlineKeyboardMarkup(keyboard)
101
+ await query.edit_message_text("Pitch set. You can toggle Keep Files or submit.", reply_markup=reply_markup)
102
+ return
103
+ if data == 'toggle_keep_files':
104
+ # Toggle the boolean value
105
+ current = context.user_data.get('keep_files', False)
106
+ context.user_data['keep_files'] = not current
107
+ keyboard = [
108
+ [InlineKeyboardButton(f"Keep Files: {'On' if context.user_data['keep_files'] else 'Off'}", callback_data='toggle_keep_files')],
109
+ [InlineKeyboardButton("Submit", callback_data='submit_generate')]
110
+ ]
111
+ reply_markup = InlineKeyboardMarkup(keyboard)
112
+ await query.edit_message_text("Keep Files toggled. You can toggle again or submit.", reply_markup=reply_markup)
113
+ return
114
+ if data == 'submit_generate':
115
+ # Make sure all required parameters are provided
116
+ if context.user_data.get('song_input_method') == 'link' and not context.user_data.get('song_link'):
117
+ await query.edit_message_text("You haven't provided the YouTube link yet.")
118
+ return
119
+ if context.user_data.get('song_input_method') == 'audio' and not context.user_data.get('audio_file'):
120
+ await query.edit_message_text("You haven't uploaded the audio file yet.")
121
+ return
122
+ if not context.user_data.get('model_name'):
123
+ await query.edit_message_text("You haven't provided the model name yet.")
124
+ return
125
+ if 'pitch' not in context.user_data:
126
+ await query.edit_message_text("You haven't selected the pitch yet.")
127
+ return
128
+
129
+ await query.edit_message_text("Generating your song, please wait...")
130
+ song_input = context.user_data.get('song_link') if context.user_data.get('song_input_method') == 'link' else context.user_data.get('audio_file')
131
+ model_name = context.user_data.get('model_name')
132
+ pitch = context.user_data.get('pitch')
133
+ keep_files = context.user_data.get('keep_files')
134
+ is_webui = False
135
+ # Call the song generation pipeline (adjust as needed for your implementation)
136
+ song_output = song_cover_pipeline(song_input, model_name, pitch, keep_files, is_webui)
137
+ if os.path.exists(song_output):
138
+ await context.bot.send_audio(chat_id=update.effective_chat.id, audio=open(song_output, 'rb'))
139
+ os.remove(song_output)
140
+ await context.bot.send_message(chat_id=update.effective_chat.id, text="Song generated successfully!")
141
+ else:
142
+ await context.bot.send_message(chat_id=update.effective_chat.id, text="An error occurred during song generation.")
143
+ context.user_data.clear()
144
+ return
145
+
146
+ # --- DOWNLOAD MODE CALLBACKS ---
147
+ if context.user_data.get('mode') == 'download_model':
148
+ if data == 'dl_set_url':
149
+ context.user_data['download_step'] = 'url'
150
+ await query.edit_message_text("Please send the download URL for the model.")
151
+ return
152
+ if data == 'dl_set_model_name':
153
+ context.user_data['download_step'] = 'model_name'
154
+ await query.edit_message_text("Please enter the custom model name.")
155
+ return
156
+ if data == 'dl_submit':
157
+ if not context.user_data.get('download_url') or not context.user_data.get('download_model_name'):
158
+ await query.edit_message_text("Both URL and model name must be set before submitting.")
159
+ return
160
+ model_url = context.user_data.get('download_url')
161
+ model_name = context.user_data.get('download_model_name')
162
+ try:
163
+ download_online_model(model_url, model_name)
164
+ await query.edit_message_text(f"Model '{model_name}' downloaded successfully from {model_url}!")
165
+ except Exception as e:
166
+ await query.edit_message_text(f"Failed to download model. Error: {str(e)}")
167
+ context.user_data.clear()
168
+ return
169
+
170
+ # Message handler for receiving text or file inputs
171
+ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
172
+ mode = context.user_data.get('mode')
173
+ if mode == 'generate':
174
+ step = context.user_data.get('generate_step')
175
+ if step == 'song_input':
176
+ # Expecting a YouTube link
177
+ context.user_data['song_link'] = update.message.text.strip()
178
+ context.user_data['generate_step'] = 'model_name'
179
+ await update.message.reply_text("YouTube link set. Please enter the model name.")
180
+ return
181
+ if step == 'audio_input':
182
+ # Check for an audio file upload
183
+ if update.message.audio:
184
+ audio_file = await update.message.audio.get_file()
185
+ file_path = os.path.join(output_dir, f"{update.message.audio.file_id}.mp3")
186
+ await audio_file.download_to_drive(file_path)
187
+ context.user_data['audio_file'] = file_path
188
+ context.user_data['generate_step'] = 'model_name'
189
+ await update.message.reply_text("Audio file received. Please enter the model name.")
190
+ else:
191
+ await update.message.reply_text("Please upload a valid audio file.")
192
+ return
193
+ if step == 'model_name':
194
+ context.user_data['model_name'] = update.message.text.strip()
195
+ # Now ask for pitch via inline keyboard
196
+ keyboard = [
197
+ [InlineKeyboardButton("Female (1)", callback_data='pitch_1')],
198
+ [InlineKeyboardButton("Male (-1)", callback_data='pitch_-1')]
199
+ ]
200
+ reply_markup = InlineKeyboardMarkup(keyboard)
201
+ await update.message.reply_text("Please choose the pitch:", reply_markup=reply_markup)
202
+ context.user_data['generate_step'] = None # move on to waiting for pitch
203
+ return
204
+ await update.message.reply_text("Please use the provided buttons to navigate the process.")
205
+
206
+ elif mode == 'download_model':
207
+ download_step = context.user_data.get('download_step')
208
+ if download_step == 'url':
209
+ text = update.message.text.strip()
210
+ if not text.startswith("http"):
211
+ await update.message.reply_text("Please provide a valid URL.")
212
+ return
213
+ context.user_data['download_url'] = text
214
+ context.user_data['download_step'] = None
215
+ await update.message.reply_text("Download URL set.")
216
+ return
217
+ if download_step == 'model_name':
218
+ context.user_data['download_model_name'] = update.message.text.strip()
219
+ context.user_data['download_step'] = None
220
+ await update.message.reply_text("Model name set.")
221
+ return
222
+ await update.message.reply_text("Please use the provided buttons for setting options.")
223
+ else:
224
+ await update.message.reply_text("Please choose an option first by clicking one of the main buttons.")
225
+
226
+ # Main function to run the bot
227
+ def main():
228
+ bot_token = os.environ.get("TELEGRAM_BOT_TOKEN")
229
+ if not bot_token:
230
+ raise ValueError("Bot token not found. Set the TELEGRAM_BOT_TOKEN environment variable.")
231
+
232
+ application = Application.builder().token(bot_token).build()
233
+ application.add_handler(CommandHandler("start", start))
234
+ application.add_handler(CallbackQueryHandler(handle_callback))
235
+ application.add_handler(MessageHandler(filters.ALL, handle_message))
236
+ application.run_polling()
237
+
238
+ if __name__ == '__main__':
239
+ main()