TTS-Arena-V2 / models.py
GitHub Actions
Sync from GitHub repo
87c9725
raw
history blame contribute delete
17.8 kB
from flask_sqlalchemy import SQLAlchemy
from flask_login import UserMixin
from datetime import datetime
import math
from sqlalchemy import func
db = SQLAlchemy()
class User(db.Model, UserMixin):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(100), unique=True, nullable=False)
hf_id = db.Column(db.String(100), unique=True, nullable=False)
join_date = db.Column(db.DateTime, default=datetime.utcnow)
votes = db.relationship("Vote", backref="user", lazy=True)
show_in_leaderboard = db.Column(db.Boolean, default=True)
def __repr__(self):
return f"<User {self.username}>"
class ModelType:
TTS = "tts"
CONVERSATIONAL = "conversational"
class Model(db.Model):
id = db.Column(db.String(100), primary_key=True)
name = db.Column(db.String(100), nullable=False)
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
# Fix ambiguous foreign keys by specifying which foreign key to use
votes = db.relationship(
"Vote",
primaryjoin="or_(Model.id==Vote.model_chosen, Model.id==Vote.model_rejected)",
viewonly=True,
)
current_elo = db.Column(db.Float, default=1500.0)
win_count = db.Column(db.Integer, default=0)
match_count = db.Column(db.Integer, default=0)
is_open = db.Column(db.Boolean, default=False)
is_active = db.Column(
db.Boolean, default=True
) # Whether the model is active and can be voted on
model_url = db.Column(db.String(255), nullable=True)
@property
def win_rate(self):
if self.match_count == 0:
return 0
return (self.win_count / self.match_count) * 100
def __repr__(self):
return f"<Model {self.name} ({self.model_type})>"
class Vote(db.Model):
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=True)
text = db.Column(db.String(1000), nullable=False)
vote_date = db.Column(db.DateTime, default=datetime.utcnow)
model_chosen = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
model_rejected = db.Column(
db.String(100), db.ForeignKey("model.id"), nullable=False
)
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
chosen = db.relationship(
"Model",
foreign_keys=[model_chosen],
backref=db.backref("chosen_votes", lazy=True),
)
rejected = db.relationship(
"Model",
foreign_keys=[model_rejected],
backref=db.backref("rejected_votes", lazy=True),
)
def __repr__(self):
return f"<Vote {self.id}: {self.model_chosen} over {self.model_rejected} ({self.model_type})>"
class EloHistory(db.Model):
id = db.Column(db.Integer, primary_key=True)
model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
timestamp = db.Column(db.DateTime, default=datetime.utcnow)
elo_score = db.Column(db.Float, nullable=False)
vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
model = db.relationship("Model", backref=db.backref("elo_history", lazy=True))
vote = db.relationship("Vote", backref=db.backref("elo_changes", lazy=True))
def __repr__(self):
return f"<EloHistory {self.model_id}: {self.elo_score} at {self.timestamp} ({self.model_type})>"
def calculate_elo_change(winner_elo, loser_elo, k_factor=32):
"""Calculate Elo rating changes for a match."""
expected_winner = 1 / (1 + math.pow(10, (loser_elo - winner_elo) / 400))
expected_loser = 1 / (1 + math.pow(10, (winner_elo - loser_elo) / 400))
winner_new_elo = winner_elo + k_factor * (1 - expected_winner)
loser_new_elo = loser_elo + k_factor * (0 - expected_loser)
return winner_new_elo, loser_new_elo
def record_vote(user_id, text, chosen_model_id, rejected_model_id, model_type):
"""Record a vote and update Elo ratings."""
# Create the vote
vote = Vote(
user_id=user_id, # Can be None for anonymous votes
text=text,
model_chosen=chosen_model_id,
model_rejected=rejected_model_id,
model_type=model_type,
)
db.session.add(vote)
db.session.flush() # Get the vote ID without committing
# Get the models
chosen_model = Model.query.filter_by(
id=chosen_model_id, model_type=model_type
).first()
rejected_model = Model.query.filter_by(
id=rejected_model_id, model_type=model_type
).first()
if not chosen_model or not rejected_model:
db.session.rollback()
return None, "One or both models not found for the specified model type"
# Calculate new Elo ratings
new_chosen_elo, new_rejected_elo = calculate_elo_change(
chosen_model.current_elo, rejected_model.current_elo
)
# Update model stats
chosen_model.current_elo = new_chosen_elo
chosen_model.win_count += 1
chosen_model.match_count += 1
rejected_model.current_elo = new_rejected_elo
rejected_model.match_count += 1
# Record Elo history
chosen_history = EloHistory(
model_id=chosen_model_id,
elo_score=new_chosen_elo,
vote_id=vote.id,
model_type=model_type,
)
rejected_history = EloHistory(
model_id=rejected_model_id,
elo_score=new_rejected_elo,
vote_id=vote.id,
model_type=model_type,
)
db.session.add_all([chosen_history, rejected_history])
db.session.commit()
return vote, None
def get_leaderboard_data(model_type):
"""
Get leaderboard data for the specified model type.
Args:
model_type (str): The model type ('tts' or 'conversational')
Returns:
list: List of dictionaries containing model data for the leaderboard
"""
query = Model.query.filter_by(model_type=model_type)
# Get models ordered by ELO score
models = query.order_by(Model.current_elo.desc()).all()
result = []
for rank, model in enumerate(models, 1):
# Determine tier based on rank
if rank <= 2:
tier = "tier-s"
elif rank <= 4:
tier = "tier-a"
elif rank <= 7:
tier = "tier-b"
else:
tier = ""
result.append(
{
"rank": rank,
"id": model.id,
"name": model.name,
"model_url": model.model_url,
"win_rate": f"{model.win_rate:.0f}%",
"total_votes": model.match_count,
"elo": int(model.current_elo),
"tier": tier,
"is_open": model.is_open,
}
)
return result
def get_user_leaderboard(user_id, model_type):
"""
Get personalized leaderboard data for a specific user.
Args:
user_id (int): The user ID
model_type (str): The model type ('tts' or 'conversational')
Returns:
list: List of dictionaries containing model data for the user's personal leaderboard
"""
# Get all models of the specified type
models = Model.query.filter_by(model_type=model_type).all()
# Get user's votes
user_votes = Vote.query.filter_by(user_id=user_id, model_type=model_type).all()
# Calculate win counts and match counts for each model based on user's votes
model_stats = {model.id: {"wins": 0, "matches": 0} for model in models}
for vote in user_votes:
model_stats[vote.model_chosen]["wins"] += 1
model_stats[vote.model_chosen]["matches"] += 1
model_stats[vote.model_rejected]["matches"] += 1
# Calculate win rates and prepare result
result = []
for model in models:
stats = model_stats[model.id]
win_rate = (
(stats["wins"] / stats["matches"] * 100) if stats["matches"] > 0 else 0
)
# Only include models the user has voted on
if stats["matches"] > 0:
result.append(
{
"id": model.id,
"name": model.name,
"model_url": model.model_url,
"win_rate": f"{win_rate:.0f}%",
"total_votes": stats["matches"],
"wins": stats["wins"],
"is_open": model.is_open,
}
)
# Sort by win rate descending
result.sort(key=lambda x: float(x["win_rate"].rstrip("%")), reverse=True)
# Add rank
for i, item in enumerate(result, 1):
item["rank"] = i
return result
def get_historical_leaderboard_data(model_type, target_date=None):
"""
Get leaderboard data at a specific date in history.
Args:
model_type (str): The model type ('tts' or 'conversational')
target_date (datetime): The target date for historical data, defaults to current time
Returns:
list: List of dictionaries containing model data for the historical leaderboard
"""
if not target_date:
target_date = datetime.utcnow()
# Get all models of the specified type
models = Model.query.filter_by(model_type=model_type).all()
# Create a result list for the models
result = []
for model in models:
# Get the most recent EloHistory entry for each model before the target date
elo_entry = (
EloHistory.query.filter(
EloHistory.model_id == model.id,
EloHistory.model_type == model_type,
EloHistory.timestamp <= target_date,
)
.order_by(EloHistory.timestamp.desc())
.first()
)
# Skip models that have no history before the target date
if not elo_entry:
continue
# Count wins and matches up to the target date
match_count = Vote.query.filter(
db.or_(Vote.model_chosen == model.id, Vote.model_rejected == model.id),
Vote.model_type == model_type,
Vote.vote_date <= target_date,
).count()
win_count = Vote.query.filter(
Vote.model_chosen == model.id,
Vote.model_type == model_type,
Vote.vote_date <= target_date,
).count()
# Calculate win rate
win_rate = (win_count / match_count * 100) if match_count > 0 else 0
# Add to result
result.append(
{
"id": model.id,
"name": model.name,
"model_url": model.model_url,
"win_rate": f"{win_rate:.0f}%",
"total_votes": match_count,
"elo": int(elo_entry.elo_score),
"is_open": model.is_open,
}
)
# Sort by ELO score descending
result.sort(key=lambda x: x["elo"], reverse=True)
# Add rank and tier
for i, item in enumerate(result, 1):
item["rank"] = i
# Determine tier based on rank
if i <= 2:
item["tier"] = "tier-s"
elif i <= 4:
item["tier"] = "tier-a"
elif i <= 7:
item["tier"] = "tier-b"
else:
item["tier"] = ""
return result
def get_key_historical_dates(model_type):
"""
Get a list of key dates in the leaderboard history.
Args:
model_type (str): The model type ('tts' or 'conversational')
Returns:
list: List of datetime objects representing key dates
"""
# Get first and most recent vote dates
first_vote = (
Vote.query.filter_by(model_type=model_type)
.order_by(Vote.vote_date.asc())
.first()
)
last_vote = (
Vote.query.filter_by(model_type=model_type)
.order_by(Vote.vote_date.desc())
.first()
)
if not first_vote or not last_vote:
return []
# Generate a list of key dates - first day of each month between the first and last vote
dates = []
current_date = first_vote.vote_date.replace(day=1)
end_date = last_vote.vote_date
while current_date <= end_date:
dates.append(current_date)
# Move to next month
if current_date.month == 12:
current_date = current_date.replace(year=current_date.year + 1, month=1)
else:
current_date = current_date.replace(month=current_date.month + 1)
# Add latest date
if dates and dates[-1].month != end_date.month or dates[-1].year != end_date.year:
dates.append(end_date)
return dates
def insert_initial_models():
"""Insert initial models into the database."""
tts_models = [
Model(
id="eleven-multilingual-v2",
name="Eleven Multilingual v2",
model_type=ModelType.TTS,
is_open=False,
model_url="https://elevenlabs.io/",
),
Model(
id="eleven-turbo-v2.5",
name="Eleven Turbo v2.5",
model_type=ModelType.TTS,
is_open=False,
model_url="https://elevenlabs.io/",
),
Model(
id="eleven-flash-v2.5",
name="Eleven Flash v2.5",
model_type=ModelType.TTS,
is_open=False,
model_url="https://elevenlabs.io/",
),
Model(
id="cartesia-sonic-2",
name="Cartesia Sonic 2",
model_type=ModelType.TTS,
is_open=False,
model_url="https://cartesia.ai/",
),
Model(
id="spark-tts",
name="Spark TTS",
model_type=ModelType.TTS,
is_open=False,
model_url="https://github.com/SparkAudio/Spark-TTS",
),
Model(
id="playht-2.0",
name="PlayHT 2.0",
model_type=ModelType.TTS,
is_open=False,
model_url="https://play.ht/",
),
Model(
id="styletts2",
name="StyleTTS 2",
model_type=ModelType.TTS,
is_open=True,
model_url="https://github.com/yl4579/StyleTTS2",
),
Model(
id="kokoro-v1",
name="Kokoro v1.0",
model_type=ModelType.TTS,
is_open=True,
model_url="https://huggingface.co/hexgrad/Kokoro-82M",
),
Model(
id="cosyvoice-2.0",
name="CosyVoice 2.0",
model_type=ModelType.TTS,
is_open=True,
model_url="https://github.com/FunAudioLLM/CosyVoice",
),
Model(
id="papla-p1",
name="Papla P1",
model_type=ModelType.TTS,
is_open=False,
model_url="https://papla.media/",
),
Model(
id="hume-octave",
name="Hume Octave",
model_type=ModelType.TTS,
is_open=False,
model_url="https://hume.ai/",
),
Model(
id="megatts3",
name="MegaTTS 3",
model_type=ModelType.TTS,
is_active=False,
is_open=True,
model_url="https://github.com/bytedance/MegaTTS3",
),
]
conversational_models = [
Model(
id="csm-1b",
name="CSM 1B",
model_type=ModelType.CONVERSATIONAL,
is_open=True,
model_url="https://huggingface.co/sesame/csm-1b",
),
Model(
id="playdialog-1.0",
name="PlayDialog 1.0",
model_type=ModelType.CONVERSATIONAL,
is_open=False,
model_url="https://play.ht/",
),
Model(
id="dia-1.6b",
name="Dia 1.6B",
model_type=ModelType.CONVERSATIONAL,
is_open=True,
model_url="https://huggingface.co/nari-labs/Dia-1.6B",
),
]
all_models = tts_models + conversational_models
for model in all_models:
existing = Model.query.filter_by(
id=model.id, model_type=model.model_type
).first()
if not existing:
db.session.add(model)
else:
# Update model attributes if they've changed, but preserve other data
existing.name = model.name
existing.is_open = model.is_open
if model.is_active is not None:
existing.is_active = model.is_active
db.session.commit()
def get_top_voters(limit=10):
"""
Get the top voters by number of votes.
Args:
limit (int): Number of users to return
Returns:
list: List of dictionaries containing user data and vote counts
"""
# Query users who have opted in to the leaderboard and have at least one vote
top_users = db.session.query(
User, func.count(Vote.id).label('vote_count')
).join(Vote).filter(
User.show_in_leaderboard == True
).group_by(User.id).order_by(
func.count(Vote.id).desc()
).limit(limit).all()
result = []
for i, (user, vote_count) in enumerate(top_users, 1):
result.append({
"rank": i,
"username": user.username,
"vote_count": vote_count,
"join_date": user.join_date.strftime("%b %d, %Y")
})
return result
def toggle_user_leaderboard_visibility(user_id):
"""
Toggle whether a user appears in the voters leaderboard
Args:
user_id (int): The user ID
Returns:
bool: New visibility state
"""
user = User.query.get(user_id)
if not user:
return None
user.show_in_leaderboard = not user.show_in_leaderboard
db.session.commit()
return user.show_in_leaderboard