mrfakename committed
Commit 44ba063 · verified
1 Parent(s): ff6edf2

Update app.py

Files changed (1):
  1. app.py +26 -1441
app.py CHANGED
@@ -1,1448 +1,33 @@
1
- import os
2
- from huggingface_hub import HfApi, hf_hub_download
3
- from apscheduler.schedulers.background import BackgroundScheduler
4
- from concurrent.futures import ThreadPoolExecutor
5
- from datetime import datetime
6
- import threading # Added for locking
7
- from sqlalchemy import or_ # Added for vote counting query
8
-
9
- year = datetime.now().year
10
- month = datetime.now().month
11
-
12
- # Check if running in a Hugging Face Space
13
- IS_SPACES = False
14
- if os.getenv("SPACE_REPO_NAME"):
15
- print("Running in a Hugging Face Space 🤗")
16
- IS_SPACES = True
17
-
18
- # Setup database sync for HF Spaces
19
- if not os.path.exists("instance/tts_arena.db"):
20
- os.makedirs("instance", exist_ok=True)
21
- try:
22
- print("Database not found, downloading from HF dataset...")
23
- hf_hub_download(
24
- repo_id="TTS-AGI/database-arena-v2",
25
- filename="tts_arena.db",
26
- repo_type="dataset",
27
- local_dir="instance",
28
- token=os.getenv("HF_TOKEN"),
29
- )
30
- print("Database downloaded successfully ✅")
31
- except Exception as e:
32
- print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
33
-
34
- from flask import (
35
- Flask,
36
- render_template,
37
- g,
38
- request,
39
- jsonify,
40
- send_file,
41
- redirect,
42
- url_for,
43
- session,
44
- abort,
45
- )
46
- from flask_login import LoginManager, current_user
47
- from models import *
48
- from auth import auth, init_oauth, is_admin
49
- from admin import admin
50
- import os
51
- from dotenv import load_dotenv
52
- from flask_limiter import Limiter
53
- from flask_limiter.util import get_remote_address
54
- import uuid
55
- import tempfile
56
- import shutil
57
- from tts import predict_tts
58
- import random
59
- import json
60
- from datetime import datetime, timedelta
61
- from flask_migrate import Migrate
62
- import requests
63
- import functools
64
- import time # Added for potential retries
65
-
66
-
67
- # Load environment variables
68
- if not IS_SPACES:
69
- load_dotenv() # Only load .env if not running in a Hugging Face Space
70
 
71
  app = Flask(__name__)
72
- app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(24))
73
- app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv(
74
- "DATABASE_URI", "sqlite:///tts_arena.db"
75
- )
76
- app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
77
- app.config["SESSION_COOKIE_SECURE"] = True
78
- app.config["SESSION_COOKIE_SAMESITE"] = (
79
- "None" if IS_SPACES else "Lax"
80
- ) # HF Spaces uses iframes to load the app, so we need to set SAMESITE to None
81
- app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=30) # Set to desired duration
82
-
83
- # Force HTTPS when running in HuggingFace Spaces
84
- if IS_SPACES:
85
- app.config["PREFERRED_URL_SCHEME"] = "https"
86
-
87
- # Cloudflare Turnstile settings
88
- app.config["TURNSTILE_ENABLED"] = (
89
- os.getenv("TURNSTILE_ENABLED", "False").lower() == "true"
90
- )
91
- app.config["TURNSTILE_SITE_KEY"] = os.getenv("TURNSTILE_SITE_KEY", "")
92
- app.config["TURNSTILE_SECRET_KEY"] = os.getenv("TURNSTILE_SECRET_KEY", "")
93
- app.config["TURNSTILE_VERIFY_URL"] = (
94
- "https://challenges.cloudflare.com/turnstile/v0/siteverify"
95
- )
96
-
97
- migrate = Migrate(app, db)
98
-
99
- # Initialize extensions
100
- db.init_app(app)
101
- login_manager = LoginManager()
102
- login_manager.init_app(app)
103
- login_manager.login_view = "auth.login"
104
-
105
- # Initialize OAuth
106
- init_oauth(app)
107
-
108
- # Configure rate limits
109
- limiter = Limiter(
110
- app=app,
111
- key_func=get_remote_address,
112
- default_limits=["2000 per day", "50 per minute"],
113
- storage_uri="memory://",
114
- )
115
-
116
- # TTS Cache Configuration - Read from environment
117
- TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
118
- CACHE_AUDIO_SUBDIR = "cache"
119
- tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
120
- tts_cache_lock = threading.Lock()
121
- SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
122
- # Increased max_workers to 8 for concurrent generation/refill
123
- cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
124
- all_harvard_sentences = [] # Keep the full list available
125
-
126
- # Create temp directories
127
- TEMP_AUDIO_DIR = os.path.join(tempfile.gettempdir(), "tts_arena_audio")
128
- CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
129
- os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
130
- os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
131
-
132
-
133
- # Store active TTS sessions
134
- app.tts_sessions = {}
135
- tts_sessions = app.tts_sessions
136
-
137
- # Store active conversational sessions
138
- app.conversational_sessions = {}
139
- conversational_sessions = app.conversational_sessions
140
-
141
- # Register blueprints
142
- app.register_blueprint(auth, url_prefix="/auth")
143
- app.register_blueprint(admin)
144
-
145
-
146
- @login_manager.user_loader
147
- def load_user(user_id):
148
- return User.query.get(int(user_id))
149
-
150
-
151
- @app.before_request
152
- def before_request():
153
- g.user = current_user
154
- g.is_admin = is_admin(current_user)
155
-
156
- # Ensure HTTPS for HuggingFace Spaces environment
157
- if IS_SPACES and request.headers.get("X-Forwarded-Proto") == "http":
158
- url = request.url.replace("http://", "https://", 1)
159
- return redirect(url, code=301)
160
-
161
- # Check if Turnstile verification is required
162
- if app.config["TURNSTILE_ENABLED"]:
163
- # Exclude verification routes
164
- excluded_routes = ["verify_turnstile", "turnstile_page", "static"]
165
- if request.endpoint not in excluded_routes:
166
- # Check if user is verified
167
- if not session.get("turnstile_verified"):
168
- # Save original URL for redirect after verification
169
- redirect_url = request.url
170
- # Force HTTPS in HuggingFace Spaces
171
- if IS_SPACES and redirect_url.startswith("http://"):
172
- redirect_url = redirect_url.replace("http://", "https://", 1)
173
-
174
- # If it's an API request, return a JSON response
175
- if request.path.startswith("/api/"):
176
- return jsonify({"error": "Turnstile verification required"}), 403
177
- # For regular requests, redirect to verification page
178
- return redirect(url_for("turnstile_page", redirect_url=redirect_url))
179
- else:
180
- # Check if verification has expired (default: 24 hours)
181
- verification_timeout = (
182
- int(os.getenv("TURNSTILE_TIMEOUT_HOURS", "24")) * 3600
183
- ) # Convert hours to seconds
184
- verified_at = session.get("turnstile_verified_at", 0)
185
- current_time = datetime.utcnow().timestamp()
186
-
187
- if current_time - verified_at > verification_timeout:
188
- # Verification expired, clear status and redirect to verification page
189
- session.pop("turnstile_verified", None)
190
- session.pop("turnstile_verified_at", None)
191
-
192
- redirect_url = request.url
193
- # Force HTTPS in HuggingFace Spaces
194
- if IS_SPACES and redirect_url.startswith("http://"):
195
- redirect_url = redirect_url.replace("http://", "https://", 1)
196
-
197
- if request.path.startswith("/api/"):
198
- return jsonify({"error": "Turnstile verification expired"}), 403
199
- return redirect(
200
- url_for("turnstile_page", redirect_url=redirect_url)
201
- )
202
-
203
-
204
- @app.route("/turnstile", methods=["GET"])
205
- def turnstile_page():
206
- """Display Cloudflare Turnstile verification page"""
207
- redirect_url = request.args.get("redirect_url", url_for("arena", _external=True))
208
-
209
- # Force HTTPS in HuggingFace Spaces
210
- if IS_SPACES and redirect_url.startswith("http://"):
211
- redirect_url = redirect_url.replace("http://", "https://", 1)
212
 
213
- return render_template(
214
- "turnstile.html",
215
- turnstile_site_key=app.config["TURNSTILE_SITE_KEY"],
216
- redirect_url=redirect_url,
217
- )
218
-
219
-
220
- @app.route("/verify-turnstile", methods=["POST"])
221
- def verify_turnstile():
222
- """Verify Cloudflare Turnstile token"""
223
- token = request.form.get("cf-turnstile-response")
224
- redirect_url = request.form.get("redirect_url", url_for("arena", _external=True))
225
-
226
- # Force HTTPS in HuggingFace Spaces
227
- if IS_SPACES and redirect_url.startswith("http://"):
228
- redirect_url = redirect_url.replace("http://", "https://", 1)
229
-
230
- if not token:
231
- # If AJAX request, return JSON error
232
- if request.headers.get("X-Requested-With") == "XMLHttpRequest":
233
- return (
234
- jsonify({"success": False, "error": "Missing verification token"}),
235
- 400,
236
- )
237
- # Otherwise redirect back to turnstile page
238
- return redirect(url_for("turnstile_page", redirect_url=redirect_url))
239
-
240
- # Verify token with Cloudflare
241
- data = {
242
- "secret": app.config["TURNSTILE_SECRET_KEY"],
243
- "response": token,
244
- "remoteip": request.remote_addr,
245
- }
246
-
247
- try:
248
- response = requests.post(app.config["TURNSTILE_VERIFY_URL"], data=data)
249
- result = response.json()
250
-
251
- if result.get("success"):
252
- # Set verification status in session
253
- session["turnstile_verified"] = True
254
- session["turnstile_verified_at"] = datetime.utcnow().timestamp()
255
-
256
- # Determine response type based on request
257
- is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest"
258
- accepts_json = "application/json" in request.headers.get("Accept", "")
259
-
260
- # If AJAX or JSON request, return success JSON
261
- if is_xhr or accepts_json:
262
- return jsonify({"success": True, "redirect": redirect_url})
263
-
264
- # For regular form submissions, redirect to the target URL
265
- return redirect(redirect_url)
266
- else:
267
- # Verification failed
268
- app.logger.warning(f"Turnstile verification failed: {result}")
269
-
270
- # If AJAX request, return JSON error
271
- if request.headers.get("X-Requested-With") == "XMLHttpRequest":
272
- return jsonify({"success": False, "error": "Verification failed"}), 403
273
-
274
- # Otherwise redirect back to turnstile page
275
- return redirect(url_for("turnstile_page", redirect_url=redirect_url))
276
-
277
- except Exception as e:
278
- app.logger.error(f"Turnstile verification error: {str(e)}")
279
-
280
- # If AJAX request, return JSON error
281
- if request.headers.get("X-Requested-With") == "XMLHttpRequest":
282
- return (
283
- jsonify(
284
- {"success": False, "error": "Server error during verification"}
285
- ),
286
- 500,
287
- )
288
-
289
- # Otherwise redirect back to turnstile page
290
- return redirect(url_for("turnstile_page", redirect_url=redirect_url))
291
-
292
- with open("sentences.txt", "r") as f, open("emotional_sentences.txt", "r") as f_emotional:
293
- # Store all sentences and clean them up
294
- all_harvard_sentences = [line.strip() for line in f.readlines() if line.strip()] + [line.strip() for line in f_emotional.readlines() if line.strip()]
295
- # Shuffle for initial random selection if needed, but main list remains ordered
296
- initial_sentences = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), 500)) # Limit initial pass for template
297
 
298
  @app.route("/")
299
- def arena():
300
- # Pass a subset of sentences for the random button fallback
301
- return render_template("arena.html", harvard_sentences=json.dumps(initial_sentences))
302
-
303
-
304
- @app.route("/leaderboard")
305
- def leaderboard():
306
- tts_leaderboard = get_leaderboard_data(ModelType.TTS)
307
- conversational_leaderboard = get_leaderboard_data(ModelType.CONVERSATIONAL)
308
- top_voters = get_top_voters(10) # Get top 10 voters
309
-
310
- # Initialize personal leaderboard data
311
- tts_personal_leaderboard = None
312
- conversational_personal_leaderboard = None
313
- user_leaderboard_visibility = None
314
-
315
- # If user is logged in, get their personal leaderboard and visibility setting
316
- if current_user.is_authenticated:
317
- tts_personal_leaderboard = get_user_leaderboard(current_user.id, ModelType.TTS)
318
- conversational_personal_leaderboard = get_user_leaderboard(
319
- current_user.id, ModelType.CONVERSATIONAL
320
- )
321
- user_leaderboard_visibility = current_user.show_in_leaderboard
322
-
323
- # Get key dates for the timeline
324
- tts_key_dates = get_key_historical_dates(ModelType.TTS)
325
- conversational_key_dates = get_key_historical_dates(ModelType.CONVERSATIONAL)
326
-
327
- # Format dates for display in the dropdown
328
- formatted_tts_dates = [date.strftime("%B %Y") for date in tts_key_dates]
329
- formatted_conversational_dates = [
330
- date.strftime("%B %Y") for date in conversational_key_dates
331
- ]
332
-
333
- return render_template(
334
- "leaderboard.html",
335
- tts_leaderboard=tts_leaderboard,
336
- conversational_leaderboard=conversational_leaderboard,
337
- tts_personal_leaderboard=tts_personal_leaderboard,
338
- conversational_personal_leaderboard=conversational_personal_leaderboard,
339
- tts_key_dates=tts_key_dates,
340
- conversational_key_dates=conversational_key_dates,
341
- formatted_tts_dates=formatted_tts_dates,
342
- formatted_conversational_dates=formatted_conversational_dates,
343
- top_voters=top_voters,
344
- user_leaderboard_visibility=user_leaderboard_visibility
345
- )
346
-
347
-
348
- @app.route("/api/historical-leaderboard/<model_type>")
349
- def historical_leaderboard(model_type):
350
- """Get historical leaderboard data for a specific date"""
351
- if model_type not in [ModelType.TTS, ModelType.CONVERSATIONAL]:
352
- return jsonify({"error": "Invalid model type"}), 400
353
-
354
- # Get date from query parameter
355
- date_str = request.args.get("date")
356
- if not date_str:
357
- return jsonify({"error": "Date parameter is required"}), 400
358
-
359
- try:
360
- # Parse date from URL parameter (format: YYYY-MM-DD)
361
- target_date = datetime.strptime(date_str, "%Y-%m-%d")
362
-
363
- # Get historical leaderboard data
364
- leaderboard_data = get_historical_leaderboard_data(model_type, target_date)
365
-
366
- return jsonify(
367
- {"date": target_date.strftime("%B %d, %Y"), "leaderboard": leaderboard_data}
368
- )
369
- except ValueError:
370
- return jsonify({"error": "Invalid date format. Use YYYY-MM-DD"}), 400
371
-
372
-
373
- @app.route("/about")
374
- def about():
375
- return render_template("about.html")
376
-
377
-
378
- # --- TTS Caching Functions ---
379
-
380
- def generate_and_save_tts(text, model_id, output_dir):
381
- """Generates TTS and saves it to a specific directory, returning the full path."""
382
- temp_audio_path = None # Initialize to None
383
- try:
384
- app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
385
- # If predict_tts saves file itself and returns path:
386
- temp_audio_path = predict_tts(text, model_id)
387
- app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
388
-
389
- if not temp_audio_path or not os.path.exists(temp_audio_path):
390
- app.logger.warning(f"[TTS Gen {model_id}] predict_tts failed or returned invalid path: {temp_audio_path}")
391
- raise ValueError("predict_tts did not return a valid path or file does not exist")
392
-
393
- file_uuid = str(uuid.uuid4())
394
- dest_path = os.path.join(output_dir, f"{file_uuid}.wav")
395
- app.logger.debug(f"[TTS Gen {model_id}] Moving {temp_audio_path} to {dest_path}")
396
- # Move the file generated by predict_tts to the target cache directory
397
- shutil.move(temp_audio_path, dest_path)
398
- app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
399
- return dest_path
400
-
401
- except Exception as e:
402
- app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
403
- # Ensure temporary file from predict_tts (if any) is cleaned up
404
- if temp_audio_path and os.path.exists(temp_audio_path):
405
- try:
406
- app.logger.debug(f"[TTS Gen {model_id}] Cleaning up temporary file {temp_audio_path} after error.")
407
- os.remove(temp_audio_path)
408
- except OSError:
409
- pass # Ignore error if file couldn't be removed
410
- return None
411
-
412
-
413
- def _generate_cache_entry_task(sentence):
414
- """Task function to generate audio for a sentence and add to cache."""
415
- # Wrap the entire task in an application context
416
- with app.app_context():
417
- if not sentence:
418
- # Select a new sentence if not provided (for replacement)
419
- with tts_cache_lock:
420
- cached_keys = set(tts_cache.keys())
421
- available_sentences = [s for s in all_harvard_sentences if s not in cached_keys]
422
- if not available_sentences:
423
- app.logger.warning("No more unique Harvard sentences available for caching.")
424
- return
425
- sentence = random.choice(available_sentences)
426
-
427
- # app.logger.info removed duplicate log
428
- print(f"[Cache Task] Querying models for: '{sentence[:50]}...'")
429
- available_models = Model.query.filter_by(
430
- model_type=ModelType.TTS, is_active=True
431
- ).all()
432
-
433
- if len(available_models) < 2:
434
- app.logger.error("Not enough active TTS models to generate cache entry.")
435
- return
436
-
437
- try:
438
- models = get_weighted_random_models(available_models, 2, ModelType.TTS)
439
- model_a_id = models[0].id
440
- model_b_id = models[1].id
441
-
442
- # Generate audio concurrently using a local executor for clarity within the task
443
- with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
444
- future_a = audio_executor.submit(generate_and_save_tts, sentence, model_a_id, CACHE_AUDIO_DIR)
445
- future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
446
-
447
- timeout_seconds = 120
448
- audio_a_path = future_a.result(timeout=timeout_seconds)
449
- audio_b_path = future_b.result(timeout=timeout_seconds)
450
-
451
- if audio_a_path and audio_b_path:
452
- with tts_cache_lock:
453
- # Only add if the sentence isn't already back in the cache
454
- # And ensure cache size doesn't exceed limit
455
- if sentence not in tts_cache and len(tts_cache) < TTS_CACHE_SIZE:
456
- tts_cache[sentence] = {
457
- "model_a": model_a_id,
458
- "model_b": model_b_id,
459
- "audio_a": audio_a_path,
460
- "audio_b": audio_b_path,
461
- "created_at": datetime.utcnow(),
462
- }
463
- app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
464
- elif sentence in tts_cache:
465
- app.logger.warning(f"Sentence '{sentence[:50]}...' already re-cached. Discarding new generation.")
466
- # Clean up the newly generated files if not added
467
- if os.path.exists(audio_a_path): os.remove(audio_a_path)
468
- if os.path.exists(audio_b_path): os.remove(audio_b_path)
469
- else: # Cache is full
470
- app.logger.warning(f"Cache is full ({len(tts_cache)} entries). Discarding new generation for '{sentence[:50]}...'.")
471
- # Clean up the newly generated files if not added
472
- if os.path.exists(audio_a_path): os.remove(audio_a_path)
473
- if os.path.exists(audio_b_path): os.remove(audio_b_path)
474
-
475
- else:
476
- app.logger.error(f"Failed to generate one or both audio files for cache: '{sentence[:50]}...'")
477
- # Clean up whichever file might have been created
478
- if audio_a_path and os.path.exists(audio_a_path): os.remove(audio_a_path)
479
- if audio_b_path and os.path.exists(audio_b_path): os.remove(audio_b_path)
480
-
481
- except Exception as e:
482
- # Log the exception within the app context
483
- app.logger.error(f"Exception in _generate_cache_entry_task for '{sentence[:50]}...': {str(e)}", exc_info=True)
484
-
485
-
486
- def initialize_tts_cache():
487
- print("Initializing TTS cache")
488
- """Selects initial sentences and starts generation tasks."""
489
- with app.app_context(): # Ensure access to models
490
- if not all_harvard_sentences:
491
- app.logger.error("Harvard sentences not loaded. Cannot initialize cache.")
492
- return
493
-
494
- initial_selection = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), TTS_CACHE_SIZE))
495
- app.logger.info(f"Initializing TTS cache with {len(initial_selection)} sentences...")
496
-
497
- for sentence in initial_selection:
498
- # Use the main cache_executor for initial population too
499
- cache_executor.submit(_generate_cache_entry_task, sentence)
500
- app.logger.info("Submitted initial cache generation tasks.")
501
-
502
- # --- End TTS Caching Functions ---
503
-
504
-
505
- @app.route("/api/tts/generate", methods=["POST"])
506
- @limiter.limit("10 per minute") # Keep limit, cached responses are still requests
507
- def generate_tts():
508
- # If verification not setup, handle it first
509
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
510
- return jsonify({"error": "Turnstile verification required"}), 403
511
-
512
- data = request.json
513
- text = data.get("text", "").strip() # Ensure text is stripped
514
-
515
- if not text or len(text) > 1000:
516
- return jsonify({"error": "Invalid or too long text"}), 400
517
-
518
- # --- Cache Check ---
519
- cache_hit = False
520
- session_data_from_cache = None
521
- with tts_cache_lock:
522
- if text in tts_cache:
523
- cache_hit = True
524
- cached_entry = tts_cache.pop(text) # Remove from cache immediately
525
- app.logger.info(f"TTS Cache HIT for: '{text[:50]}...'")
526
-
527
- # Prepare session data using cached info
528
- session_id = str(uuid.uuid4())
529
- session_data_from_cache = {
530
- "model_a": cached_entry["model_a"],
531
- "model_b": cached_entry["model_b"],
532
- "audio_a": cached_entry["audio_a"], # Paths are now from cache_dir
533
- "audio_b": cached_entry["audio_b"],
534
- "text": text,
535
- "created_at": datetime.utcnow(),
536
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
537
- "voted": False,
538
- }
539
- app.tts_sessions[session_id] = session_data_from_cache
540
-
541
- # --- Trigger background tasks to refill the cache ---
542
- # Calculate how many slots need refilling
543
- current_cache_size = len(tts_cache) # Size *before* adding potentially new items
544
- needed_refills = TTS_CACHE_SIZE - current_cache_size
545
- # Limit concurrent refills to 8 or the actual need
546
- refills_to_submit = min(needed_refills, 8)
547
-
548
- if refills_to_submit > 0:
549
- app.logger.info(f"Cache hit: Submitting {refills_to_submit} background task(s) to refill cache (current size: {current_cache_size}, target: {TTS_CACHE_SIZE}).")
550
- for _ in range(refills_to_submit):
551
- # Pass None to signal replacement selection within the task
552
- cache_executor.submit(_generate_cache_entry_task, None)
553
- else:
554
- app.logger.info(f"Cache hit: Cache is already full or at target size ({current_cache_size}/{TTS_CACHE_SIZE}). No refill tasks submitted.")
555
- # --- End Refill Trigger ---
556
-
557
- if cache_hit and session_data_from_cache:
558
- # Return response using cached data
559
- # Note: The files are now managed by the session lifecycle (cleanup_session)
560
- return jsonify(
561
- {
562
- "session_id": session_id,
563
- "audio_a": f"/api/tts/audio/{session_id}/a",
564
- "audio_b": f"/api/tts/audio/{session_id}/b",
565
- "expires_in": 1800, # 30 minutes in seconds
566
- "cache_hit": True,
567
- }
568
- )
569
- # --- End Cache Check ---
570
-
571
- # --- Cache Miss: Generate on the fly ---
572
- app.logger.info(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
573
- available_models = Model.query.filter_by(
574
- model_type=ModelType.TTS, is_active=True
575
- ).all()
576
- if len(available_models) < 2:
577
- return jsonify({"error": "Not enough TTS models available"}), 500
578
-
579
- selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
580
-
581
- try:
582
- audio_files = []
583
- model_ids = []
584
-
585
- # Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
586
- def process_model_on_the_fly(model):
587
- # Generate and save directly to the main temp dir
588
- # Assume predict_tts handles saving temporary files
589
- temp_audio_path = predict_tts(text, model.id)
590
- if not temp_audio_path or not os.path.exists(temp_audio_path):
591
- raise ValueError(f"predict_tts failed for model {model.id}")
592
-
593
- # Create a unique name in the main TEMP_AUDIO_DIR for the session
594
- file_uuid = str(uuid.uuid4())
595
- dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
596
- shutil.move(temp_audio_path, dest_path) # Move from predict_tts's temp location
597
-
598
- return {"model_id": model.id, "audio_path": dest_path}
599
-
600
-
601
- # Use ThreadPoolExecutor to process models concurrently
602
- with ThreadPoolExecutor(max_workers=2) as executor:
603
- results = list(executor.map(process_model_on_the_fly, selected_models))
604
-
605
- # Extract results
606
- for result in results:
607
- model_ids.append(result["model_id"])
608
- audio_files.append(result["audio_path"])
609
-
610
- # Create session
611
- session_id = str(uuid.uuid4())
612
- app.tts_sessions[session_id] = {
613
- "model_a": model_ids[0],
614
- "model_b": model_ids[1],
615
- "audio_a": audio_files[0], # Paths are now from TEMP_AUDIO_DIR directly
616
- "audio_b": audio_files[1],
617
- "text": text,
618
- "created_at": datetime.utcnow(),
619
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
620
- "voted": False,
621
- }
622
-
623
- # Return audio file paths and session
624
- return jsonify(
625
- {
626
- "session_id": session_id,
627
- "audio_a": f"/api/tts/audio/{session_id}/a",
628
- "audio_b": f"/api/tts/audio/{session_id}/b",
629
- "expires_in": 1800,
630
- "cache_hit": False,
631
- }
632
- )
633
-
634
- except Exception as e:
635
- app.logger.error(f"TTS on-the-fly generation error: {str(e)}", exc_info=True)
636
- # Cleanup any files potentially created during the failed attempt
637
- if 'results' in locals():
638
- for res in results:
639
- if 'audio_path' in res and os.path.exists(res['audio_path']):
640
- try:
641
- os.remove(res['audio_path'])
642
- except OSError:
643
- pass
644
- return jsonify({"error": "Failed to generate TTS"}), 500
645
- # --- End Cache Miss ---
646
-
647
-
648
- @app.route("/api/tts/audio/<session_id>/<model_key>")
649
- def get_audio(session_id, model_key):
650
- # If verification not setup, handle it first
651
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
652
- return jsonify({"error": "Turnstile verification required"}), 403
653
-
654
- if session_id not in app.tts_sessions:
655
- return jsonify({"error": "Invalid or expired session"}), 404
656
-
657
- session_data = app.tts_sessions[session_id]
658
-
659
- # Check if session expired
660
- if datetime.utcnow() > session_data["expires_at"]:
661
- cleanup_session(session_id)
662
- return jsonify({"error": "Session expired"}), 410
663
-
664
- if model_key == "a":
665
- audio_path = session_data["audio_a"]
666
- elif model_key == "b":
667
- audio_path = session_data["audio_b"]
668
- else:
669
- return jsonify({"error": "Invalid model key"}), 400
670
-
671
- # Check if file exists
672
- if not os.path.exists(audio_path):
673
- return jsonify({"error": "Audio file not found"}), 404
674
-
675
- return send_file(audio_path, mimetype="audio/wav")
676
-
677
-
678
- @app.route("/api/tts/vote", methods=["POST"])
679
- @limiter.limit("30 per minute")
680
- def submit_vote():
681
- # If verification not setup, handle it first
682
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
683
- return jsonify({"error": "Turnstile verification required"}), 403
684
-
685
- data = request.json
686
- session_id = data.get("session_id")
687
- chosen_model_key = data.get("chosen_model") # "a" or "b"
688
-
689
- if not session_id or session_id not in app.tts_sessions:
690
- return jsonify({"error": "Invalid or expired session"}), 404
691
-
692
- if not chosen_model_key or chosen_model_key not in ["a", "b"]:
693
- return jsonify({"error": "Invalid chosen model"}), 400
694
-
695
- session_data = app.tts_sessions[session_id]
696
-
697
- # Check if session expired
698
- if datetime.utcnow() > session_data["expires_at"]:
699
- cleanup_session(session_id)
700
- return jsonify({"error": "Session expired"}), 410
701
-
702
- # Check if already voted
703
- if session_data["voted"]:
704
- return jsonify({"error": "Vote already submitted for this session"}), 400
705
-
706
- # Get model IDs and audio paths
707
- chosen_id = (
708
- session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
709
- )
710
- rejected_id = (
711
- session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
712
- )
713
- chosen_audio_path = (
714
- session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
715
- )
716
- rejected_audio_path = (
717
- session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
718
- )
719
-
720
- # Record vote in database
721
- user_id = current_user.id if current_user.is_authenticated else None
722
- vote, error = record_vote(
723
- user_id, session_data["text"], chosen_id, rejected_id, ModelType.TTS
724
- )
725
-
726
- if error:
727
- return jsonify({"error": error}), 500
728
-
729
- # --- Save preference data ---
730
- try:
731
- vote_uuid = str(uuid.uuid4())
732
- vote_dir = os.path.join("./votes", vote_uuid)
733
- os.makedirs(vote_dir, exist_ok=True)
734
-
735
- # Copy audio files
736
- shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
737
- shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
738
-
739
- # Create metadata
740
- chosen_model_obj = Model.query.get(chosen_id)
741
- rejected_model_obj = Model.query.get(rejected_id)
742
- metadata = {
743
- "text": session_data["text"],
744
- "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
745
- "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
746
- "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
747
- "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
748
- "session_id": session_id,
749
- "timestamp": datetime.utcnow().isoformat(),
750
- "username": current_user.username if current_user.is_authenticated else None,
751
- "model_type": "TTS"
752
- }
753
- with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
754
- json.dump(metadata, f, indent=2)
755
-
756
- except Exception as e:
757
- app.logger.error(f"Error saving preference data for vote {session_id}: {str(e)}")
758
- # Continue even if saving preference data fails, vote is already recorded
759
-
760
- # Mark session as voted
761
- session_data["voted"] = True
762
-
763
- # Return updated models (use previously fetched objects)
764
- return jsonify(
765
- {
766
- "success": True,
767
- "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
768
- "rejected_model": {
769
- "id": rejected_id,
770
- "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
771
- },
772
- "names": {
773
- "a": (
774
- chosen_model_obj.name if chosen_model_key == "a" else rejected_model_obj.name
775
- if chosen_model_obj and rejected_model_obj else "Unknown"
776
- ),
777
- "b": (
778
- rejected_model_obj.name if chosen_model_key == "a" else chosen_model_obj.name
779
- if chosen_model_obj and rejected_model_obj else "Unknown"
780
- ),
781
- },
782
- }
783
- )
784
-
785
-
786
- def cleanup_session(session_id):
787
- """Remove session and its audio files"""
788
- if session_id in app.tts_sessions:
789
- session = app.tts_sessions[session_id]
790
-
791
- # Remove audio files
792
- for audio_file in [session["audio_a"], session["audio_b"]]:
793
- if os.path.exists(audio_file):
794
- try:
795
- os.remove(audio_file)
796
- except Exception as e:
797
- app.logger.error(f"Error removing audio file: {str(e)}")
798
-
799
- # Remove session
800
- del app.tts_sessions[session_id]
801
-
802
-
803
- @app.route("/api/conversational/generate", methods=["POST"])
804
- @limiter.limit("5 per minute")
805
- def generate_podcast():
806
- # If verification not setup, handle it first
807
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
808
- return jsonify({"error": "Turnstile verification required"}), 403
809
-
810
- data = request.json
811
- script = data.get("script")
812
-
813
- if not script or not isinstance(script, list) or len(script) < 2:
814
- return jsonify({"error": "Invalid script format or too short"}), 400
815
-
816
- # Validate script format
817
- for line in script:
818
- if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line:
819
- return (
820
- jsonify(
821
- {
822
- "error": "Invalid script line format. Each line must have text and speaker_id"
823
- }
824
- ),
825
- 400,
826
- )
827
- if (
828
- not line["text"]
829
- or not isinstance(line["speaker_id"], int)
830
- or line["speaker_id"] not in [0, 1]
831
- ):
832
- return (
833
- jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}),
834
- 400,
835
- )
836
-
837
- # Get two conversational models (currently only CSM and PlayDialog)
838
- available_models = Model.query.filter_by(
839
- model_type=ModelType.CONVERSATIONAL, is_active=True
840
- ).all()
841
-
842
- if len(available_models) < 2:
843
- return jsonify({"error": "Not enough conversational models available"}), 500
844
-
845
- selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL)
846
-
847
- try:
848
- # Generate audio for both models concurrently
849
- audio_files = []
850
- model_ids = []
851
-
852
- # Function to process a single model
853
- def process_model(model):
854
- # Call conversational TTS service
855
- audio_content = predict_tts(script, model.id)
856
-
857
- # Save to temp file with unique name
858
- file_uuid = str(uuid.uuid4())
859
- dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
860
-
861
- with open(dest_path, "wb") as f:
862
- f.write(audio_content)
863
-
864
- return {"model_id": model.id, "audio_path": dest_path}
865
-
866
- # Use ThreadPoolExecutor to process models concurrently
867
- with ThreadPoolExecutor(max_workers=2) as executor:
868
- results = list(executor.map(process_model, selected_models))
869
-
870
- # Extract results
871
- for result in results:
872
- model_ids.append(result["model_id"])
873
- audio_files.append(result["audio_path"])
874
-
875
- # Create session
876
- session_id = str(uuid.uuid4())
877
- script_text = " ".join([line["text"] for line in script])
878
- app.conversational_sessions[session_id] = {
879
- "model_a": model_ids[0],
880
- "model_b": model_ids[1],
881
- "audio_a": audio_files[0],
882
- "audio_b": audio_files[1],
883
- "text": script_text[:1000], # Limit text length
884
- "created_at": datetime.utcnow(),
885
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
886
- "voted": False,
887
- "script": script,
888
- }
889
-
890
- # Return audio file paths and session
891
- return jsonify(
892
- {
893
- "session_id": session_id,
894
- "audio_a": f"/api/conversational/audio/{session_id}/a",
895
- "audio_b": f"/api/conversational/audio/{session_id}/b",
896
- "expires_in": 1800, # 30 minutes in seconds
897
- }
898
- )
899
-
900
- except Exception as e:
901
- app.logger.error(f"Conversational generation error: {str(e)}")
902
- return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500
903
-
904
-
905
- @app.route("/api/conversational/audio/<session_id>/<model_key>")
906
- def get_podcast_audio(session_id, model_key):
907
- # If verification not setup, handle it first
908
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
909
- return jsonify({"error": "Turnstile verification required"}), 403
910
-
911
- if session_id not in app.conversational_sessions:
912
- return jsonify({"error": "Invalid or expired session"}), 404
913
-
914
- session_data = app.conversational_sessions[session_id]
915
-
916
- # Check if session expired
917
- if datetime.utcnow() > session_data["expires_at"]:
918
- cleanup_conversational_session(session_id)
919
- return jsonify({"error": "Session expired"}), 410
920
-
921
- if model_key == "a":
922
- audio_path = session_data["audio_a"]
923
- elif model_key == "b":
924
- audio_path = session_data["audio_b"]
925
- else:
926
- return jsonify({"error": "Invalid model key"}), 400
927
-
928
- # Check if file exists
929
- if not os.path.exists(audio_path):
930
- return jsonify({"error": "Audio file not found"}), 404
931
-
932
- return send_file(audio_path, mimetype="audio/wav")
933
-
934
-
935
- @app.route("/api/conversational/vote", methods=["POST"])
936
- @limiter.limit("30 per minute")
937
- def submit_podcast_vote():
938
- # If verification not setup, handle it first
939
- if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
940
- return jsonify({"error": "Turnstile verification required"}), 403
941
-
942
- data = request.json
943
- session_id = data.get("session_id")
944
- chosen_model_key = data.get("chosen_model") # "a" or "b"
945
-
946
- if not session_id or session_id not in app.conversational_sessions:
947
- return jsonify({"error": "Invalid or expired session"}), 404
948
-
949
- if not chosen_model_key or chosen_model_key not in ["a", "b"]:
950
- return jsonify({"error": "Invalid chosen model"}), 400
951
-
952
- session_data = app.conversational_sessions[session_id]
953
-
954
- # Check if session expired
955
- if datetime.utcnow() > session_data["expires_at"]:
956
- cleanup_conversational_session(session_id)
957
- return jsonify({"error": "Session expired"}), 410
958
-
959
- # Check if already voted
960
- if session_data["voted"]:
961
- return jsonify({"error": "Vote already submitted for this session"}), 400
962
-
963
- # Get model IDs and audio paths
964
- chosen_id = (
965
- session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
966
- )
967
- rejected_id = (
968
- session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
969
- )
970
- chosen_audio_path = (
971
- session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
972
- )
973
- rejected_audio_path = (
974
- session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
975
- )
976
-
977
- # Record vote in database
978
- user_id = current_user.id if current_user.is_authenticated else None
979
- vote, error = record_vote(
980
- user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
981
- )
982
-
983
- if error:
984
- return jsonify({"error": error}), 500
985
-
986
- # --- Save preference data ---
987
- try:
988
- vote_uuid = str(uuid.uuid4())
989
- vote_dir = os.path.join("./votes", vote_uuid)
990
- os.makedirs(vote_dir, exist_ok=True)
991
-
992
- # Copy audio files
993
- shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
994
- shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
995
-
996
- # Create metadata
997
- chosen_model_obj = Model.query.get(chosen_id)
998
- rejected_model_obj = Model.query.get(rejected_id)
999
- metadata = {
1000
- "script": session_data["script"], # Save the full script
1001
- "chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
1002
- "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
1003
- "rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
1004
- "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
1005
- "session_id": session_id,
1006
- "timestamp": datetime.utcnow().isoformat(),
1007
- "username": current_user.username if current_user.is_authenticated else None,
1008
- "model_type": "CONVERSATIONAL"
1009
- }
1010
- with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
1011
- json.dump(metadata, f, indent=2)
1012
-
1013
- except Exception as e:
1014
- app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")
1015
- # Continue even if saving preference data fails, vote is already recorded
1016
-
1017
- # Mark session as voted
1018
- session_data["voted"] = True
1019
-
1020
- # Return updated models (use previously fetched objects)
1021
- return jsonify(
1022
- {
1023
- "success": True,
1024
- "chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
1025
- "rejected_model": {
1026
- "id": rejected_id,
1027
- "name": rejected_model_obj.name if rejected_model_obj else "Unknown",
1028
- },
1029
- "names": {
1030
- "a": Model.query.get(session_data["model_a"]).name,
1031
- "b": Model.query.get(session_data["model_b"]).name,
1032
- },
1033
- }
1034
- )
1035
-
1036
-
1037
- def cleanup_conversational_session(session_id):
1038
- """Remove conversational session and its audio files"""
1039
- if session_id in app.conversational_sessions:
1040
- session = app.conversational_sessions[session_id]
1041
-
1042
- # Remove audio files
1043
- for audio_file in [session["audio_a"], session["audio_b"]]:
1044
- if os.path.exists(audio_file):
1045
- try:
1046
- os.remove(audio_file)
1047
- except Exception as e:
1048
- app.logger.error(
1049
- f"Error removing conversational audio file: {str(e)}"
1050
- )
1051
-
1052
- # Remove session
1053
- del app.conversational_sessions[session_id]
1054
-
1055
-
1056
- # Schedule periodic cleanup
1057
- def setup_cleanup():
1058
- def cleanup_expired_sessions():
1059
- with app.app_context(): # Ensure app context for logging
1060
- current_time = datetime.utcnow()
1061
- # Cleanup TTS sessions
1062
- expired_tts_sessions = [
1063
- sid
1064
- for sid, session_data in app.tts_sessions.items()
1065
- if current_time > session_data["expires_at"]
1066
- ]
1067
- for sid in expired_tts_sessions:
1068
- cleanup_session(sid)
1069
-
1070
- # Cleanup conversational sessions
1071
- expired_conv_sessions = [
1072
- sid
1073
- for sid, session_data in app.conversational_sessions.items()
1074
- if current_time > session_data["expires_at"]
1075
- ]
1076
- for sid in expired_conv_sessions:
1077
- cleanup_conversational_session(sid)
1078
- app.logger.info(f"Cleaned up {len(expired_tts_sessions)} TTS and {len(expired_conv_sessions)} conversational sessions.")
1079
-
1080
- # Also cleanup potentially expired cache entries (e.g., > 1 hour old)
1081
- # This prevents stale cache entries if generation is slow or failing
1082
- # cleanup_stale_cache_entries()
1083
-
1084
- # Run cleanup every 15 minutes
1085
- scheduler = BackgroundScheduler(daemon=True) # Run scheduler as daemon thread
1086
- scheduler.add_job(cleanup_expired_sessions, "interval", minutes=15)
1087
- scheduler.start()
1088
- print("Cleanup scheduler started") # Use print for startup messages
1089
-
1090
-
1091
- # Schedule periodic tasks (database sync and preference upload)
1092
- def setup_periodic_tasks():
1093
- """Setup periodic database synchronization and preference data upload for Spaces"""
1094
- if not IS_SPACES:
1095
- return
1096
-
1097
- db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
1098
- preferences_repo_id = "TTS-AGI/arena-v2-preferences"
1099
- database_repo_id = "TTS-AGI/database-arena-v2"
1100
- votes_dir = "./votes"
1101
-
1102
- def sync_database():
1103
- """Uploads the database to HF dataset"""
1104
- with app.app_context(): # Ensure app context for logging
1105
- try:
1106
- if not os.path.exists(db_path):
1107
- app.logger.warning(f"Database file not found at {db_path}, skipping sync.")
1108
- return
1109
-
1110
- api = HfApi(token=os.getenv("HF_TOKEN"))
1111
- api.upload_file(
1112
- path_or_fileobj=db_path,
1113
- path_in_repo="tts_arena.db",
1114
- repo_id=database_repo_id,
1115
- repo_type="dataset",
1116
- )
1117
- app.logger.info(f"Database uploaded to {database_repo_id} at {datetime.utcnow()}")
1118
- except Exception as e:
1119
- app.logger.error(f"Error uploading database to {database_repo_id}: {str(e)}")
1120
-
1121
- def sync_preferences_data():
1122
- """Zips and uploads preference data folders in batches to HF dataset"""
1123
- with app.app_context(): # Ensure app context for logging
1124
- if not os.path.isdir(votes_dir):
1125
- return # Don't log every 5 mins if dir doesn't exist yet
1126
-
1127
- temp_batch_dir = None # Initialize to manage cleanup
1128
- temp_individual_zip_dir = None # Initialize for individual zips
1129
- local_batch_zip_path = None # Initialize for batch zip path
1130
-
1131
- try:
1132
- api = HfApi(token=os.getenv("HF_TOKEN"))
1133
- vote_uuids = [d for d in os.listdir(votes_dir) if os.path.isdir(os.path.join(votes_dir, d))]
1134
-
1135
- if not vote_uuids:
1136
- return # No data to process
1137
-
1138
- app.logger.info(f"Found {len(vote_uuids)} vote directories to process.")
1139
-
1140
- # Create temporary directories
1141
- temp_batch_dir = tempfile.mkdtemp(prefix="hf_batch_")
1142
- temp_individual_zip_dir = tempfile.mkdtemp(prefix="hf_indiv_zips_")
1143
- app.logger.debug(f"Created temp directories: {temp_batch_dir}, {temp_individual_zip_dir}")
1144
-
1145
- processed_vote_dirs = []
1146
- individual_zips_in_batch = []
1147
-
1148
- # 1. Create individual zips and move them to the batch directory
1149
- for vote_uuid in vote_uuids:
1150
- dir_path = os.path.join(votes_dir, vote_uuid)
1151
- individual_zip_base_path = os.path.join(temp_individual_zip_dir, vote_uuid)
1152
- individual_zip_path = f"{individual_zip_base_path}.zip"
1153
-
1154
- try:
1155
- shutil.make_archive(individual_zip_base_path, 'zip', dir_path)
1156
- app.logger.debug(f"Created individual zip: {individual_zip_path}")
1157
-
1158
- # Move the created zip into the batch directory
1159
- final_individual_zip_path = os.path.join(temp_batch_dir, f"{vote_uuid}.zip")
1160
- shutil.move(individual_zip_path, final_individual_zip_path)
1161
- app.logger.debug(f"Moved individual zip to batch dir: {final_individual_zip_path}")
1162
-
1163
- processed_vote_dirs.append(dir_path) # Mark original dir for later cleanup
1164
- individual_zips_in_batch.append(final_individual_zip_path)
1165
-
1166
- except Exception as zip_err:
1167
- app.logger.error(f"Error creating or moving zip for {vote_uuid}: {str(zip_err)}")
1168
- # Clean up partial zip if it exists
1169
- if os.path.exists(individual_zip_path):
1170
- try:
1171
- os.remove(individual_zip_path)
1172
- except OSError:
1173
- pass
1174
- # Continue processing other votes
1175
-
1176
- # Clean up the temporary dir used for creating individual zips
1177
- shutil.rmtree(temp_individual_zip_dir)
1178
- temp_individual_zip_dir = None # Mark as cleaned
1179
- app.logger.debug("Cleaned up temporary individual zip directory.")
1180
-
1181
- if not individual_zips_in_batch:
1182
- app.logger.warning("No individual zips were successfully created for batching.")
1183
- # Clean up batch dir if it's empty or only contains failed attempts
1184
- if temp_batch_dir and os.path.exists(temp_batch_dir):
1185
- shutil.rmtree(temp_batch_dir)
1186
- temp_batch_dir = None
1187
- return
1188
-
1189
- # 2. Create the batch zip file
1190
- batch_timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
1191
- batch_uuid_short = str(uuid.uuid4())[:8]
1192
- batch_zip_filename = f"{batch_timestamp}_batch_{batch_uuid_short}.zip"
1193
- # Create batch zip in a standard temp location first
1194
- local_batch_zip_base = os.path.join(tempfile.gettempdir(), batch_zip_filename.replace('.zip', ''))
1195
- local_batch_zip_path = f"{local_batch_zip_base}.zip"
1196
-
1197
- app.logger.info(f"Creating batch zip: {local_batch_zip_path} with {len(individual_zips_in_batch)} individual zips.")
1198
- shutil.make_archive(local_batch_zip_base, 'zip', temp_batch_dir)
1199
- app.logger.info(f"Batch zip created successfully: {local_batch_zip_path}")
1200
-
1201
- # 3. Upload the batch zip file
1202
- hf_repo_path = f"votes/{year}/{month}/{batch_zip_filename}"
1203
- app.logger.info(f"Uploading batch zip to HF Hub: {preferences_repo_id}/{hf_repo_path}")
1204
-
1205
- api.upload_file(
1206
- path_or_fileobj=local_batch_zip_path,
1207
- path_in_repo=hf_repo_path,
1208
- repo_id=preferences_repo_id,
1209
- repo_type="dataset",
1210
- commit_message=f"Add batch preference data {batch_zip_filename} ({len(individual_zips_in_batch)} votes)"
1211
- )
1212
- app.logger.info(f"Successfully uploaded batch {batch_zip_filename} to {preferences_repo_id}")
1213
-
1214
- # 4. Cleanup after successful upload
1215
- app.logger.info("Cleaning up local files after successful upload.")
1216
- # Remove original vote directories that were successfully zipped and uploaded
1217
- for dir_path in processed_vote_dirs:
1218
- try:
1219
- shutil.rmtree(dir_path)
1220
- app.logger.debug(f"Removed original vote directory: {dir_path}")
1221
- except OSError as e:
1222
- app.logger.error(f"Error removing processed vote directory {dir_path}: {str(e)}")
1223
-
1224
- # Remove the temporary batch directory (containing the individual zips)
1225
- shutil.rmtree(temp_batch_dir)
1226
- temp_batch_dir = None
1227
- app.logger.debug("Removed temporary batch directory.")
1228
-
1229
- # Remove the local batch zip file
1230
- os.remove(local_batch_zip_path)
1231
- local_batch_zip_path = None
1232
- app.logger.debug("Removed local batch zip file.")
1233
-
1234
- app.logger.info(f"Finished preference data sync. Uploaded batch {batch_zip_filename}.")
1235
-
1236
- except Exception as e:
1237
- app.logger.error(f"Error during preference data batch sync: {str(e)}", exc_info=True)
1238
- # If upload failed, the local batch zip might exist, clean it up.
1239
- if local_batch_zip_path and os.path.exists(local_batch_zip_path):
1240
- try:
1241
- os.remove(local_batch_zip_path)
1242
- app.logger.debug("Cleaned up local batch zip after failed upload.")
1243
- except OSError as clean_err:
1244
- app.logger.error(f"Error cleaning up batch zip after failed upload: {clean_err}")
1245
- # Do NOT remove temp_batch_dir if it exists; its contents will be retried next time.
1246
- # Do NOT remove original vote directories if upload failed.
1247
-
1248
- finally:
1249
- # Final cleanup for temporary directories in case of unexpected exits
1250
- if temp_individual_zip_dir and os.path.exists(temp_individual_zip_dir):
1251
- try:
1252
- shutil.rmtree(temp_individual_zip_dir)
1253
- except Exception as final_clean_err:
1254
- app.logger.error(f"Error in final cleanup (indiv zips): {final_clean_err}")
1255
- # Only clean up batch dir in finally block if it *wasn't* kept intentionally after upload failure
1256
- if temp_batch_dir and os.path.exists(temp_batch_dir):
1257
- # Check if an upload attempt happened and failed
1258
- upload_failed = 'e' in locals() and isinstance(e, Exception) # Crude check if exception occurred
1259
- if not upload_failed: # If no upload error or upload succeeded, clean up
1260
- try:
1261
- shutil.rmtree(temp_batch_dir)
1262
- except Exception as final_clean_err:
1263
- app.logger.error(f"Error in final cleanup (batch dir): {final_clean_err}")
1264
- else:
1265
- app.logger.warning("Keeping temporary batch directory due to upload failure for next attempt.")
1266
-
1267
-
1268
- # Schedule periodic tasks
1269
- scheduler = BackgroundScheduler()
1270
- # Sync database less frequently if needed, e.g., every 15 minutes
1271
- scheduler.add_job(sync_database, "interval", minutes=15, id="sync_db_job")
1272
- # Sync preferences more frequently
1273
- scheduler.add_job(sync_preferences_data, "interval", minutes=5, id="sync_pref_job")
1274
- scheduler.start()
1275
- print("Periodic tasks scheduler started (DB sync and Preferences upload)") # Use print for startup
1276
-
1277
-
1278
- @app.cli.command("init-db")
1279
- def init_db():
1280
- """Initialize the database."""
1281
- with app.app_context():
1282
- db.create_all()
1283
- print("Database initialized!")
1284
-
1285
-
1286
- @app.route("/api/toggle-leaderboard-visibility", methods=["POST"])
1287
- def toggle_leaderboard_visibility():
1288
- """Toggle whether the current user appears in the top voters leaderboard"""
1289
- if not current_user.is_authenticated:
1290
- return jsonify({"error": "You must be logged in to change this setting"}), 401
1291
-
1292
- new_status = toggle_user_leaderboard_visibility(current_user.id)
1293
- if new_status is None:
1294
- return jsonify({"error": "User not found"}), 404
1295
-
1296
- return jsonify({
1297
- "success": True,
1298
- "visible": new_status,
1299
- "message": "You are now visible in the voters leaderboard" if new_status else "You are now hidden from the voters leaderboard"
1300
- })
1301
-
1302
-
1303
- @app.route("/api/tts/cached-sentences")
1304
- def get_cached_sentences():
1305
- """Returns a list of sentences currently available in the TTS cache."""
1306
- with tts_cache_lock:
1307
- cached_keys = list(tts_cache.keys())
1308
- return jsonify(cached_keys)
1309
-
1310
-
1311
- def get_weighted_random_models(
1312
- applicable_models: list[Model], num_to_select: int, model_type: ModelType
1313
- ) -> list[Model]:
1314
- """
1315
- Selects a specified number of models randomly from a list of applicable_models,
1316
- weighting models with fewer votes higher. A smoothing factor is used to ensure
1317
- the preference is slight and to prevent models with zero votes from being
1318
- overwhelmingly favored. Models are selected without replacement.
1319
-
1320
- Assumes len(applicable_models) >= num_to_select, which should be checked by the caller.
1321
- """
1322
- model_votes_counts = {}
1323
- for model in applicable_models:
1324
- votes = (
1325
- Vote.query.filter(Vote.model_type == model_type)
1326
- .filter(or_(Vote.model_chosen == model.id, Vote.model_rejected == model.id))
1327
- .count()
1328
- )
1329
- model_votes_counts[model.id] = votes
1330
-
1331
- weights = [
1332
- 1.0 / (model_votes_counts[model.id] + SMOOTHING_FACTOR_MODEL_SELECTION)
1333
- for model in applicable_models
1334
- ]
1335
-
1336
- selected_models_list = []
1337
- # Create copies to modify during selection process
1338
- current_candidates = list(applicable_models)
1339
- current_weights = list(weights)
1340
-
1341
- # Assumes num_to_select is positive and less than or equal to len(current_candidates)
1342
- # Callers should ensure this (e.g., len(available_models) >= 2).
1343
- for _ in range(num_to_select):
1344
- if not current_candidates: # Safety break
1345
- app.logger.warning("Not enough candidates left for weighted selection.")
1346
- break
1347
-
1348
- chosen_model = random.choices(current_candidates, weights=current_weights, k=1)[0]
1349
- selected_models_list.append(chosen_model)
1350
-
1351
- try:
1352
- idx_to_remove = current_candidates.index(chosen_model)
1353
- current_candidates.pop(idx_to_remove)
1354
- current_weights.pop(idx_to_remove)
1355
- except ValueError:
1356
- # This should ideally not happen if chosen_model came from current_candidates.
1357
- app.logger.error(f"Error removing model {chosen_model.id} from weighted selection candidates.")
1358
- break # Avoid potential issues
1359
-
1360
- return selected_models_list
1361
-
1362
 
1363
  if __name__ == "__main__":
1364
- with app.app_context():
1365
- # Ensure ./instance and ./votes directories exist
1366
- os.makedirs("instance", exist_ok=True)
1367
- os.makedirs("./votes", exist_ok=True) # Create votes directory if it doesn't exist
1368
- os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache audio dir exists
1369
-
1370
- # Clean up old cache audio files on startup
1371
- try:
1372
- app.logger.info(f"Clearing old cache audio files from {CACHE_AUDIO_DIR}")
1373
- for filename in os.listdir(CACHE_AUDIO_DIR):
1374
- file_path = os.path.join(CACHE_AUDIO_DIR, filename)
1375
- try:
1376
- if os.path.isfile(file_path) or os.path.islink(file_path):
1377
- os.unlink(file_path)
1378
- elif os.path.isdir(file_path):
1379
- shutil.rmtree(file_path)
1380
- except Exception as e:
1381
- app.logger.error(f'Failed to delete {file_path}. Reason: {e}')
1382
- except Exception as e:
1383
- app.logger.error(f"Error clearing cache directory {CACHE_AUDIO_DIR}: {e}")
1384
-
1385
-
1386
- # Download database if it doesn't exist (only on initial space start)
1387
- if IS_SPACES and not os.path.exists(app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "")):
1388
- try:
1389
- print("Database not found, downloading from HF dataset...")
1390
- hf_hub_download(
1391
- repo_id="TTS-AGI/database-arena-v2",
1392
- filename="tts_arena.db",
1393
- repo_type="dataset",
1394
- local_dir="instance", # download to instance/
1395
- token=os.getenv("HF_TOKEN"),
1396
- )
1397
- print("Database downloaded successfully ✅")
1398
- except Exception as e:
1399
- print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
1400
-
1401
-
1402
- db.create_all() # Create tables if they don't exist
1403
- insert_initial_models()
1404
- # Setup background tasks
1405
- initialize_tts_cache() # Start populating the cache
1406
- setup_cleanup()
1407
- setup_periodic_tasks() # Renamed function call
1408
-
1409
- # Configure Flask to recognize HTTPS when behind a reverse proxy
1410
- from werkzeug.middleware.proxy_fix import ProxyFix
1411
-
1412
- # Apply ProxyFix middleware to handle reverse proxy headers
1413
- # This ensures Flask generates correct URLs with https scheme
1414
- # X-Forwarded-Proto header will be used to detect the original protocol
1415
- app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)
1416
-
1417
- # Force Flask to prefer HTTPS for generated URLs
1418
- app.config["PREFERRED_URL_SCHEME"] = "https"
1419
-
1420
- from waitress import serve
1421
-
1422
- # Configuration for 2 vCPUs:
1423
- # - threads: typically 4-8 threads per CPU core is a good balance
1424
- # - connection_limit: maximum concurrent connections
1425
- # - channel_timeout: prevent hanging connections
1426
- threads = 12 # 6 threads per vCPU is a good balance for mixed IO/CPU workloads
1427
-
1428
- if IS_SPACES:
1429
- serve(
1430
- app,
1431
- host="0.0.0.0",
1432
- port=int(os.environ.get("PORT", 7860)),
1433
- threads=threads,
1434
- connection_limit=100,
1435
- channel_timeout=30,
1436
- url_scheme='https'
1437
- )
1438
- else:
1439
- print(f"Starting Waitress server with {threads} threads")
1440
- serve(
1441
- app,
1442
- host="0.0.0.0",
1443
- port=5000,
1444
- threads=threads,
1445
- connection_limit=100,
1446
- channel_timeout=30,
1447
- url_scheme='https' # Keep https for local dev if using proxy/tunnel
1448
- )
 
1
+ from flask import Flask, render_template_string
2
 
3
  app = Flask(__name__)
4
 
5
+ HTML = """
6
+ <!DOCTYPE html>
7
+ <html lang="en">
8
+ <head>
9
+ <meta charset="UTF-8">
10
+ <title>Maintenance</title>
11
+ <meta name="viewport" content="width=device-width, initial-scale=1">
12
+ <script src="https://cdn.tailwindcss.com"></script>
13
+ </head>
14
+ <body class="bg-gray-100 flex items-center justify-center h-screen">
15
+ <div class="bg-white p-8 rounded-2xl shadow-lg text-center max-w-md">
16
+ <svg class="mx-auto mb-4 w-16 h-16 text-yellow-500" fill="none" stroke="currentColor" stroke-width="1.5"
17
+ viewBox="0 0 24 24">
18
+ <path stroke-linecap="round" stroke-linejoin="round"
19
+ d="M12 9v2m0 4h.01M4.93 4.93a10 10 0 0114.14 0 10 10 0 010 14.14 10 10 0 01-14.14 0 10 10 0 010-14.14z"/>
20
+ </svg>
21
+ <h1 class="text-2xl font-bold text-gray-800 mb-2">We'll be back soon!</h1>
22
+ <p class="text-gray-600">The TTS Arena is temporarily undergoing maintenance.<br>Thank you for your patience.</p>
23
+ </div>
24
+ </body>
25
+ </html>
26
+ """
27
 
28
  @app.route("/")
29
+ def maintenance():
30
+ return render_template_string(HTML)
31
 
32
  if __name__ == "__main__":
33
+ app.run(debug=True)