# AnimeRecBERT / app.py
# Committed by mramazan in d9ab129 ("Create app.py"); original size 20.2 kB.
# (Page header from the hosting site, converted to comments so the file parses.)
import gradio as gr
import sys
import pickle
import json
import gc
import torch
from pathlib import Path
import gdown
import os
import difflib
from datetime import datetime
import random
# Import your existing modules
from utils import *
from options import args
from models import model_factory
class LazyDict:
    """Read-only mapping view over a JSON file that is parsed on first access.

    Nothing is read from ``file_path`` until one of the accessors below is
    called. If opening or parsing fails, a warning is printed and the mapping
    behaves as an empty dict from then on (the load is not retried).
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self._data = None
        self._loaded = False

    def _load_data(self):
        """Parse the backing JSON file once; fall back to ``{}`` on any error."""
        if self._loaded:
            return
        try:
            with open(self.file_path, "r", encoding="utf-8") as handle:
                parsed = json.load(handle)
        except Exception as exc:
            print(f"Warning: Could not load {self.file_path}: {str(exc)}")
            parsed = {}
        self._data = parsed
        self._loaded = True

    def _mapping(self):
        """Return the parsed data, loading it first if necessary."""
        self._load_data()
        return self._data

    def get(self, key, default=None):
        return self._mapping().get(key, default)

    def __contains__(self, key):
        return key in self._mapping()

    def items(self):
        return self._mapping().items()

    def keys(self):
        return self._mapping().keys()

    def __len__(self):
        return len(self._mapping())
class AnimeRecommendationSystem:
    """BERT-based anime recommender plus the metadata needed to present results.

    Construction eagerly loads the item-id map (``smap``) from the dataset
    pickle, builds the model via ``model_factory(args)`` and restores its
    weights. Per-anime metadata (titles, image/MAL URLs, type/sequel info,
    genres) is wrapped in ``LazyDict`` and parsed on first use. All metadata
    lookups use the anime id converted to ``str`` as the key.
    """

    # Length of the token sequence fed to the model. Also written to
    # ``args.bert_max_len`` before the model is built so both agree.
    _MAX_SEQ_LEN = 128

    def __init__(self, checkpoint_path, dataset_path, animes_path, images_path, mal_urls_path, type_seq_path, genres_path):
        self.model = None
        self.dataset = None
        self.checkpoint_path = checkpoint_path
        self.dataset_path = dataset_path
        self.animes_path = animes_path
        # Metadata files are parsed lazily to keep start-up memory low.
        self.id_to_anime = LazyDict(animes_path)
        self.id_to_url = LazyDict(images_path)
        self.id_to_mal_url = LazyDict(mal_urls_path)
        self.id_to_type_seq = LazyDict(type_seq_path)
        self.id_to_genres = LazyDict(genres_path)
        # Plain dict memoizing derived lookup tables (inverse smap, title
        # index). Holds strong references; reset whenever the dataset reloads.
        self._cache = {}
        self.load_model_and_data()

    def load_model_and_data(self):
        """Load ``smap`` from the dataset pickle, build the model and restore its weights.

        Any exception is printed and re-raised unchanged.
        """
        try:
            print("Loading model and data...")
            args.bert_max_len = self._MAX_SEQ_LEN
            # Load the dataset. NOTE(security): pickle.load can execute
            # arbitrary code; dataset_path must point at a trusted file.
            with Path(self.dataset_path).open('rb') as f:
                self.dataset = pickle.load(f)["smap"]
            args.num_items = len(self.dataset)
            # Tables derived from the dataset are now stale.
            self._cache = {}
            # Build the model and load its weights.
            self.model = model_factory(args)
            self.load_checkpoint()
            gc.collect()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def load_checkpoint(self):
        """Restore weights from ``checkpoint_path`` into ``self.model`` and set eval mode.

        Raises:
            Exception: wrapping (and chained to) the underlying error, with the
                checkpoint path in the message.
        """
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # the checkpoint must come from a trusted source.
            with open(self.checkpoint_path, 'rb') as f:
                checkpoint = torch.load(f, map_location='cpu', weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            # Drop our reference to the loaded checkpoint right away.
            del checkpoint
            gc.collect()
        except Exception as e:
            raise Exception(f"Failed to load checkpoint from {self.checkpoint_path}: {str(e)}") from e

    def get_anime_genres(self, anime_id):
        """Return the anime's genres title-cased, or ``[]`` if none are recorded."""
        genres = self.id_to_genres.get(str(anime_id), [])
        return [genre.title() for genre in genres] if genres else []

    def get_anime_image_url(self, anime_id):
        """Return the stored image URL for the anime, or None."""
        return self.id_to_url.get(str(anime_id), None)

    def get_anime_mal_url(self, anime_id):
        """Return the stored MyAnimeList URL for the anime, or None."""
        return self.id_to_mal_url.get(str(anime_id), None)

    def _is_hentai(self, anime_id):
        """Return element 2 of the type/sequel record (the hentai flag), or False if absent."""
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 3:
            return False
        return type_seq_info[2]

    def _get_type(self, anime_id):
        """Return element 0 of the type/sequel record, or "Unknown" if the record is missing or short."""
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return "Unknown"
        return type_seq_info[0]

    def _title_index(self):
        """Return (and memoize) ``lowercase title -> (anime_id, main_title)``.

        Includes every main title plus all non-empty string alternative
        titles. A non-list (or empty-list) entry is indexed by its ``str()``.
        """
        index = self._cache.get("title_index")
        if index is not None:
            return index
        index = {}
        for k, v in self.id_to_anime.items():
            anime_id = int(k)
            if isinstance(v, list) and len(v) > 0:
                main_title = v[0]
                index[main_title.lower().strip()] = (anime_id, main_title)
                for alt_title in v[1:]:
                    if alt_title and isinstance(alt_title, str):
                        alt_title_clean = alt_title.strip()
                        if alt_title_clean:
                            index[alt_title_clean.lower()] = (anime_id, main_title)
            else:
                title = str(v).strip()
                index[title.lower()] = (anime_id, title)
        self._cache["title_index"] = index
        return index

    def find_closest_anime(self, input_name):
        """Find the anime best matching *input_name*.

        Tries, in order: an exact (case-insensitive) title match, the first
        title containing the input as a substring, then a difflib fuzzy match
        (cutoff 0.6). Alternative titles resolve to their main title.

        Returns:
            ``(anime_id, main_title)``, or None for empty input or no match.
        """
        input_lower = (input_name or "").lower().strip()
        if not input_lower:
            # An empty needle is a substring of every title; don't pick one arbitrarily.
            return None
        anime_names = self._title_index()
        # 1. Exact match
        if input_lower in anime_names:
            return anime_names[input_lower]
        # 2. Substring search (first hit in file order)
        for anime_name_lower, match in anime_names.items():
            if input_lower in anime_name_lower:
                return match
        # 3. Fuzzy matching
        close_matches = difflib.get_close_matches(input_lower, list(anime_names), n=1, cutoff=0.6)
        return anime_names[close_matches[0]] if close_matches else None

    def search_animes(self, query, limit=200):
        """Return up to *limit* ``(anime_id, main_name)`` pairs whose titles contain *query*.

        Matching is case-insensitive over main and alternative titles; an
        empty/None query matches every entry with at least one title. The
        first *limit* matches in file order are collected (for performance)
        and then sorted by name. Non-string titles are ignored for matching.
        """
        query_lower = query.lower() if query else ""
        animes = []
        for k, v in self.id_to_anime.items():
            if len(animes) >= limit:
                break
            anime_names = v if isinstance(v, list) else [v]
            if not anime_names:
                continue
            if query and not any(
                isinstance(name, str) and query_lower in name.lower() for name in anime_names
            ):
                continue
            animes.append((int(k), anime_names[0]))
        animes.sort(key=lambda x: x[1])
        return animes

    def _inverted_smap(self):
        """Return (and memoize) the model-index -> anime-id inverse of ``self.dataset``."""
        inverted = self._cache.get("inverted_smap")
        if inverted is None:
            inverted = {v: k for k, v in self.dataset.items()}
            self._cache["inverted_smap"] = inverted
        return inverted

    def get_recommendations(self, favorite_anime_ids, num_recommendations=20, filters=None):
        """Recommend animes for a list of favorite anime ids.

        Args:
            favorite_anime_ids: anime ids (keys of ``self.dataset``) the user
                likes; ids unknown to the model are ignored and never recommended.
            num_recommendations: maximum number of results to return.
            filters: optional dict of ``show_*`` flags passed to
                ``_should_include_anime``; falsy means no filtering.

        Returns:
            ``(recommendations, scores, message)``. Each recommendation is a
            dict with keys ``id``, ``name``, ``score``, ``image_url``,
            ``mal_url``, ``genres`` and ``type``. ``scores`` lists the same
            scores as floats, and ``message`` is a status string. Errors are
            caught and reported through ``message`` with empty lists.
        """
        try:
            if not favorite_anime_ids:
                return [], [], "Please add some favorite animes first!"
            smap = self.dataset
            inverted_smap = self._inverted_smap()
            converted_ids = [smap[anime_id] for anime_id in favorite_anime_ids if anime_id in smap]
            if not converted_ids:
                return [], [], "None of the selected animes are in the model vocabulary!"
            # Right-pad with 0 to the fixed length and read the prediction at
            # the last position (kept exactly as the original inference setup;
            # presumably matches how the model was trained — TODO confirm).
            target_len = self._MAX_SEQ_LEN
            padded = converted_ids + [0] * (target_len - len(converted_ids))
            input_tensor = torch.tensor(padded, dtype=torch.long).unsqueeze(0)
            with torch.no_grad():
                logits = self.model(input_tensor)
                last_logits = logits[:, -1, :]
                # Rank the whole output vocabulary rather than a fixed top-75:
                # favorites, unmapped indices and filtered items are skipped
                # below, so a small pool could leave the request under-filled.
                top_scores, top_indices = torch.topk(last_logits, k=last_logits.size(-1), dim=1)
            favorites = set(favorite_anime_ids)
            recommendations = []
            scores = []
            for idx, score in zip(top_indices[0].tolist(), top_scores[0].tolist()):
                if len(recommendations) >= num_recommendations:
                    break
                # Output indices that map to no anime id are skipped.
                if idx not in inverted_smap:
                    continue
                anime_id = inverted_smap[idx]
                key = str(anime_id)
                if anime_id in favorites or key not in self.id_to_anime:
                    continue
                if filters and not self._should_include_anime(anime_id, filters):
                    continue
                anime_data = self.id_to_anime.get(key)
                if isinstance(anime_data, list) and len(anime_data) > 0:
                    anime_name = anime_data[0]
                else:
                    anime_name = str(anime_data)
                recommendations.append({
                    'id': anime_id,
                    'name': anime_name,
                    'score': score,
                    'image_url': self.get_anime_image_url(anime_id),
                    'mal_url': self.get_anime_mal_url(anime_id),
                    'genres': self.get_anime_genres(anime_id),
                    'type': self._get_type(anime_id)
                })
                scores.append(score)
            # Memory cleanup
            del logits, last_logits, top_scores, top_indices
            gc.collect()
            return recommendations, scores, f"Found {len(recommendations)} recommendations!"
        except Exception as e:
            return [], [], f"Error during prediction: {str(e)}"

    def _should_include_anime(self, anime_id, filters):
        """Return False if any ``show_*`` filter set to a falsy value excludes this anime.

        Animes whose type/sequel record is missing or has fewer than two
        entries are always included. Missing filter keys default to "show".
        """
        if not filters:
            return True
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return True
        anime_type, is_sequel = type_seq_info[0], type_seq_info[1]
        is_hentai = type_seq_info[2] if len(type_seq_info) > 2 else False
        if is_hentai and not filters.get('show_hentai', True):
            return False
        if is_sequel and not filters.get('show_sequels', True):
            return False
        if anime_type == 'MOVIE' and not filters.get('show_movies', True):
            return False
        if anime_type == 'TV' and not filters.get('show_tv', True):
            return False
        if anime_type in ('ONA', 'OVA', 'SPECIAL') and not filters.get('show_ova', True):
            return False
        return True
# Global recommendation system: process-wide instance created by
# initialize_system(). It stays None until initialization succeeds, and the
# UI handlers check it before doing any work.
recommendation_system: "AnimeRecommendationSystem | None" = None
def initialize_system():
    """Create the global ``recommendation_system`` once, downloading data if needed.

    Each required file that is missing from the working directory is fetched
    from Google Drive with gdown. After the download pass every file is
    verified, so a failed download is reported by name instead of surfacing
    later as an unrelated load error.

    Returns:
        "System ready!" if the system is (or already was) initialized,
        otherwise "Error: <reason>". Errors are printed, never raised.
    """
    global recommendation_system
    if recommendation_system is not None:
        return "System ready!"
    try:
        # Initial value; AnimeRecommendationSystem overwrites it with len(smap)
        # once the dataset is loaded.
        args.num_items = 12689
        # Google Drive file id -> local filename expected by the loader below.
        file_ids = {
            "1C6mdjblhiWGhRgbIk5DP2XCc4ElS9x8p": "pretrained_bert.pth",
            "1U42cFrdLFT8NVNikT9C5SD9aAux7a5U2": "animes.json",
            "1s-8FM1Wi2wOWJ9cstvm-O1_6XculTcTG": "dataset.pkl",
            "1SOm1llcTKfhr-RTHC0dhaZ4AfWPs8wRx": "id_to_url.json",
            "1vwJEMEOIYwvCKCCbbeaP0U_9L3NhvBzg": "anime_to_malurl.json",
            "1_TyzON6ie2CqvzVNvPyc9prMTwLMefdu": "anime_to_typenseq.json",
            "1G9O_ahyuJ5aO0cwoVnIXrlzMqjKrf2aw": "id_to_genres.json"
        }

        def download_from_gdrive(file_id, output_path):
            """Best-effort download of one Drive file; return True if the file now exists."""
            url = f"https://drive.google.com/uc?id={file_id}"
            try:
                print(f"Downloading: {output_path}")
                gdown.download(url, output_path, quiet=False)
            except Exception as e:
                print(f"Error downloading {output_path}: {e}")
                return False
            # Guard against a download that returned without raising but left no file.
            if not os.path.isfile(output_path):
                print(f"Error downloading {output_path}: file was not created")
                return False
            print(f"Downloaded: {output_path}")
            return True

        for file_id, filename in file_ids.items():
            if not os.path.isfile(filename):
                download_from_gdrive(file_id, filename)
        missing = [filename for filename in file_ids.values() if not os.path.isfile(filename)]
        if missing:
            raise FileNotFoundError(f"Missing required files: {', '.join(missing)}")

        recommendation_system = AnimeRecommendationSystem(
            "pretrained_bert.pth",
            "dataset.pkl",
            "animes.json",
            "id_to_url.json",
            "anime_to_malurl.json",
            "anime_to_typenseq.json",
            "id_to_genres.json"
        )
        print("Recommendation system initialized successfully!")
    except Exception as e:
        print(f"Failed to initialize recommendation system: {e}")
        return f"Error: {str(e)}"
    return "System ready!"
def search_and_add_anime(query, favorites_state):
    """Resolve a free-text query to one anime and add its id to the favorites.

    Returns a tuple matching the Gradio outputs
    ``(status_message, favorites, search_box_value)``. The search box value is
    always "" so the box is cleared. The incoming ``favorites_state`` list is
    never mutated: when an anime is added, a new list is returned. At most 15
    favorites are kept.
    """
    max_favorites = 15
    if not recommendation_system:
        return "System not initialized", favorites_state, ""
    raw_query = query or ""
    cleaned = raw_query.strip()
    if not cleaned:
        return "Please enter an anime name to search", favorites_state, ""
    result = recommendation_system.find_closest_anime(cleaned)
    if not result:
        return f"No anime found matching '{raw_query}'", favorites_state, ""
    anime_id, anime_name = result
    # Work on a copy so the caller's list is not modified behind its back.
    current = list(favorites_state or [])
    if anime_id in current:
        return f"'{anime_name}' is already in your favorites", current, ""
    if len(current) >= max_favorites:
        return f"Maximum {max_favorites} favorite animes allowed", current, ""
    return f"Added '{anime_name}' to favorites", current + [anime_id], ""
def get_favorites_display(favorites_state):
    """Render the favorites list as numbered plain text for the UI.

    Ids without a title in the anime metadata are left out but still use up
    their number, so numbering follows list position.
    """
    if not (favorites_state and recommendation_system):
        return "No favorites added yet"
    lines = ["Your Favorite Animes:\n"]
    for position, anime_id in enumerate(favorites_state, start=1):
        data = recommendation_system.id_to_anime.get(str(anime_id))
        if not data:
            continue
        title = data[0] if isinstance(data, list) else str(data)
        lines.append(f"{position}. {title}\n")
    return "".join(lines)
def clear_favorites(favorites_state):
    """Reset the favorites.

    ``favorites_state`` is accepted only to match the Gradio event wiring and
    is ignored. Returns (status message, new empty favorites list, "" to
    clear the search box).
    """
    status = "Favorites cleared"
    emptied = []
    return status, emptied, ""
def get_recommendations_gradio(favorites_state, num_recs, show_hentai, show_sequels, show_movies, show_tv, show_ova):
    """Run the recommender for the UI and format the results as Markdown.

    The checkbox values become the ``show_*`` filter flags. Each result
    lists name, score (4 decimals), type, optional genres and an optional
    MyAnimeList link, followed by a horizontal separator line.
    """
    if not recommendation_system:
        return "System not initialized"
    if not favorites_state:
        return "Please add some favorite animes first!"

    recommendations, _scores, message = recommendation_system.get_recommendations(
        favorites_state,
        num_recommendations=int(num_recs),
        filters=dict(
            show_hentai=show_hentai,
            show_sequels=show_sequels,
            show_movies=show_movies,
            show_tv=show_tv,
            show_ova=show_ova,
        ),
    )
    if not recommendations:
        return f"No recommendations found. {message}"

    separator = "\n" + "-" * 50 + "\n\n"
    sections = [f"**{message}**\n\n"]
    for rank, rec in enumerate(recommendations, start=1):
        entry = [
            f"**{rank}. {rec['name']}**\n",
            f"Score: {rec['score']:.4f}\n",
            f"Type: {rec.get('type', 'Unknown')}\n",
        ]
        if rec['genres']:
            entry.append(f"Genres: {', '.join(rec['genres'])}\n")
        if rec.get('mal_url'):
            entry.append(f"[MyAnimeList Link]({rec['mal_url']})\n")
        entry.append(separator)
        sections.append("".join(entry))
    return "".join(sections)
def create_interface():
    """Build and return the Gradio Blocks UI for the recommender (not launched).

    Calls initialize_system() first and only prints its status message. If
    initialization failed, the UI is still built and the event handlers
    report "System not initialized".

    Returns:
        The ``gr.Blocks`` app.
    """
    # Initialize system
    init_status = initialize_system()
    print(init_status)
    with gr.Blocks(title="Anime Recommendation System", theme=gr.themes.Soft()) as demo:
        # State for favorites: list of favorite anime ids, initially empty.
        favorites_state = gr.State([])
        gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎌 Anime Recommendation System</h1>
<p>Add your favorite animes and get personalized recommendations!</p>
</div>
""")
        # Tab 1: search box and buttons on the left, status and favorites on the right.
        with gr.Tab("Add Favorites"):
            with gr.Row():
                with gr.Column(scale=2):
                    search_input = gr.Textbox(
                        label="Search Anime",
                        placeholder="Enter anime name (e.g., 'Mushoku Tensei', 'Attack on Titan')",
                        lines=1
                    )
                    with gr.Row():
                        add_btn = gr.Button("Add to Favorites", variant="primary")
                        clear_btn = gr.Button("Clear All Favorites", variant="secondary")
                with gr.Column(scale=2):
                    status_output = gr.Textbox(label="Status", lines=2)
                    favorites_display = gr.Textbox(
                        label="Your Favorites",
                        lines=10,
                        interactive=False,
                        value="No favorites added yet"
                    )
        # Tab 2: count slider and type filters on the left, Markdown results on the right.
        with gr.Tab("Get Recommendations"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Recommendation Settings")
                    num_recs = gr.Slider(
                        minimum=5,
                        maximum=50,
                        value=20,
                        step=5,
                        label="Number of Recommendations"
                    )
                    gr.Markdown("### Filters")
                    show_movies = gr.Checkbox(label="Include Movies", value=True)
                    show_tv = gr.Checkbox(label="Include TV Series", value=True)
                    show_ova = gr.Checkbox(label="Include OVA/ONA/Special", value=True)
                    show_sequels = gr.Checkbox(label="Include Sequels", value=True)
                    # Hentai is excluded unless the user opts in.
                    show_hentai = gr.Checkbox(label="Include Hentai", value=False)
                    recommend_btn = gr.Button("Get Recommendations", variant="primary")
                with gr.Column(scale=2):
                    recommendations_output = gr.Markdown(
                        label="Recommendations",
                        value="Add some favorite animes and click 'Get Recommendations'"
                    )
        # Event handlers
        # Add: update status, favorites state and (cleared) search box, then
        # refresh the favorites text from the updated state.
        add_btn.click(
            fn=search_and_add_anime,
            inputs=[search_input, favorites_state],
            outputs=[status_output, favorites_state, search_input]
        ).then(
            fn=get_favorites_display,
            inputs=[favorites_state],
            outputs=[favorites_display]
        )
        # Clear: same outputs as Add, followed by the same favorites refresh.
        clear_btn.click(
            fn=clear_favorites,
            inputs=[favorites_state],
            outputs=[status_output, favorites_state, search_input]
        ).then(
            fn=get_favorites_display,
            inputs=[favorites_state],
            outputs=[favorites_display]
        )
        # The input order must match get_recommendations_gradio's parameter order.
        recommend_btn.click(
            fn=get_recommendations_gradio,
            inputs=[
                favorites_state, num_recs, show_hentai, show_sequels,
                show_movies, show_tv, show_ova
            ],
            outputs=[recommendations_output]
        )
        # Examples
        with gr.Tab("Examples"):
            gr.Markdown("""
### How to use:
1. **Add Favorites**: Search and add your favorite animes
2. **Set Filters**: Choose what types of anime to include
3. **Get Recommendations**: Click to get personalized suggestions
### Example Searches:
- Mushoku Tensei
- Attack on Titan
- Demon Slayer
- Your Name
- Spirited Away
- One Piece
- Naruto
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    demo = create_interface()
    demo.launch(server_port=7860, server_name="0.0.0.0")