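"""Gradio demo app for a BERT-based sequential anime recommender.

The app downloads a pretrained checkpoint and metadata files from Google Drive,
lets the user build a list of favorite animes, and runs the model to produce
filtered, personalized recommendations.
"""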
import gradio as gr
import sys
import pickle
import json
import gc
import torch
from pathlib import Path
import gdown
import os
import difflib
from datetime import datetime
import random

# Project-specific modules: shared helpers, argument defaults, and the model factory
from utils import *
from options import args
from models import model_factory
class LazyDict:
    """Dict-like wrapper that loads a JSON file only on first access."""

    def __init__(self, file_path):
        self.file_path = file_path
        self._data = None
        self._loaded = False

    def _load_data(self):
        if not self._loaded:
            try:
                with open(self.file_path, "r", encoding="utf-8") as file:
                    self._data = json.load(file)
                self._loaded = True
            except Exception as e:
                print(f"Warning: Could not load {self.file_path}: {str(e)}")
                self._data = {}
                self._loaded = True

    def get(self, key, default=None):
        self._load_data()
        return self._data.get(key, default)

    def __contains__(self, key):
        self._load_data()
        return key in self._data

    def items(self):
        self._load_data()
        return self._data.items()

    def keys(self):
        self._load_data()
        return self._data.keys()

    def __len__(self):
        self._load_data()
        return len(self._data)
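
# Illustrative note: LazyDict("animes.json") defers reading the file until the
# first .get()/.items() call, so startup stays light even though several large
# JSON metadata files are attached to the app.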
class AnimeRecommendationSystem:
    def __init__(self, checkpoint_path, dataset_path, animes_path, images_path, mal_urls_path, type_seq_path, genres_path):
        self.model = None
        self.dataset = None
        self.checkpoint_path = checkpoint_path
        self.dataset_path = dataset_path
        self.animes_path = animes_path
        # Lazy loading keeps memory usage low: metadata files are read on first access
        self.id_to_anime = LazyDict(animes_path)
        self.id_to_url = LazyDict(images_path)
        self.id_to_mal_url = LazyDict(mal_urls_path)
        self.id_to_type_seq = LazyDict(type_seq_path)
        self.id_to_genres = LazyDict(genres_path)
        # Simple in-memory cache for repeated lookups
        self._cache = {}
        self.load_model_and_data()
    def load_model_and_data(self):
        try:
            print("Loading model and data...")
            args.bert_max_len = 128
            # Load the dataset's item mapping (anime id -> model vocabulary index)
            dataset_path = Path(self.dataset_path)
            with dataset_path.open('rb') as f:
                self.dataset = pickle.load(f)["smap"]
            args.num_items = len(self.dataset)
            # Build the model and load the pretrained weights
            self.model = model_factory(args)
            self.load_checkpoint()
            # Free temporary objects
            gc.collect()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    def load_checkpoint(self):
        try:
            with open(self.checkpoint_path, 'rb') as f:
                checkpoint = torch.load(f, map_location='cpu', weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            # Free the checkpoint from memory
            del checkpoint
            gc.collect()
        except Exception as e:
            raise Exception(f"Failed to load checkpoint from {self.checkpoint_path}: {str(e)}") from e
    def get_anime_genres(self, anime_id):
        genres = self.id_to_genres.get(str(anime_id), [])
        return [genre.title() for genre in genres] if genres else []

    def get_anime_image_url(self, anime_id):
        return self.id_to_url.get(str(anime_id), None)

    def get_anime_mal_url(self, anime_id):
        return self.id_to_mal_url.get(str(anime_id), None)
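
    # Each value in anime_to_typenseq.json is assumed to be a list of the form
    # [type, is_sequel, is_hentai] (e.g. ["TV", False, False]); the helpers below
    # and _should_include_anime index into it accordingly.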
    def _is_hentai(self, anime_id):
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 3:
            return False
        return type_seq_info[2]

    def _get_type(self, anime_id):
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return "Unknown"
        return type_seq_info[0]
    def find_closest_anime(self, input_name):
        """Find the best-matching anime for the input name.

        Resolution order: exact title match, substring match, then fuzzy match.
        Returns an (anime_id, main_title) tuple, or None if nothing matches.
        """
        anime_names = {}
        # Collect all titles (main + alternative)
        for k, v in self.id_to_anime.items():
            anime_id = int(k)
            if isinstance(v, list) and len(v) > 0:
                # Main title
                main_title = v[0]
                anime_names[main_title.lower().strip()] = (anime_id, main_title)
                # Alternative titles
                if len(v) > 1:
                    for alt_title in v[1:]:
                        if alt_title and isinstance(alt_title, str):
                            alt_title_clean = alt_title.strip()
                            if alt_title_clean:
                                anime_names[alt_title_clean.lower()] = (anime_id, main_title)
            else:
                title = str(v).strip()
                anime_names[title.lower()] = (anime_id, title)

        input_lower = input_name.lower().strip()
        # 1. Exact match
        if input_lower in anime_names:
            return anime_names[input_lower]
        # 2. Substring search
        for anime_name_lower, (anime_id, main_title) in anime_names.items():
            if input_lower in anime_name_lower:
                return (anime_id, main_title)
        # 3. Fuzzy matching
        anime_name_list = list(anime_names.keys())
        close_matches = difflib.get_close_matches(input_lower, anime_name_list, n=1, cutoff=0.6)
        if close_matches:
            match = close_matches[0]
            return anime_names[match]
        return None
    def search_animes(self, query):
        """Search animes by query"""
        animes = []
        query_lower = query.lower() if query else ""
        count = 0
        for k, v in self.id_to_anime.items():
            if count >= 200:  # Limit for performance
                break
            anime_names = v if isinstance(v, list) else [v]
            match_found = False
            for name in anime_names:
                if not query or query_lower in name.lower():
                    match_found = True
                    break
            if match_found:
                main_name = anime_names[0] if anime_names else "Unknown"
                animes.append((int(k), main_name))
                count += 1
        animes.sort(key=lambda x: x[1])
        return animes
    def get_recommendations(self, favorite_anime_ids, num_recommendations=20, filters=None):
        try:
            if not favorite_anime_ids:
                return [], [], "Please add some favorite animes first!"

            smap = self.dataset
            inverted_smap = {v: k for k, v in smap.items()}

            # Map anime ids to the model's item indices, skipping unknown ids
            converted_ids = []
            for anime_id in favorite_anime_ids:
                if anime_id in smap:
                    converted_ids.append(smap[anime_id])
            if not converted_ids:
                return [], [], "None of the selected animes are in the model vocabulary!"

            # Right-pad the favorite sequence to the model's expected length
            target_len = 128
            converted_ids = converted_ids[-target_len:]  # keep at most the last 128 items
            padded = converted_ids + [0] * (target_len - len(converted_ids))
            input_tensor = torch.tensor(padded, dtype=torch.long).unsqueeze(0)
            max_predictions = min(75, len(inverted_smap))

            with torch.no_grad():
                logits = self.model(input_tensor)
                last_logits = logits[:, -1, :]
                top_scores, top_indices = torch.topk(last_logits, k=max_predictions, dim=1)

            recommendations = []
            scores = []
            for idx, score in zip(top_indices.numpy()[0], top_scores.detach().numpy()[0]):
                if idx in inverted_smap:
                    anime_id = inverted_smap[idx]
                    if anime_id in favorite_anime_ids:
                        continue
                    if str(anime_id) in self.id_to_anime:
                        # Filter check
                        if filters and not self._should_include_anime(anime_id, filters):
                            continue
                        anime_data = self.id_to_anime.get(str(anime_id))
                        anime_name = anime_data[0] if isinstance(anime_data, list) and len(anime_data) > 0 else str(anime_data)
                        image_url = self.get_anime_image_url(anime_id)
                        mal_url = self.get_anime_mal_url(anime_id)
                        recommendations.append({
                            'id': anime_id,
                            'name': anime_name,
                            'score': float(score),
                            'image_url': image_url,
                            'mal_url': mal_url,
                            'genres': self.get_anime_genres(anime_id),
                            'type': self._get_type(anime_id)
                        })
                        scores.append(float(score))
                        if len(recommendations) >= num_recommendations:
                            break

            # Memory cleanup
            del logits, last_logits, top_scores, top_indices
            gc.collect()

            return recommendations, scores, f"Found {len(recommendations)} recommendations!"
        except Exception as e:
            return [], [], f"Error during prediction: {str(e)}"
    def _should_include_anime(self, anime_id, filters):
        """Check if anime should be included based on filters"""
        if not filters:
            return True
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return True
        anime_type = type_seq_info[0]
        is_sequel = type_seq_info[1] if len(type_seq_info) > 1 else False
        is_hentai = type_seq_info[2] if len(type_seq_info) > 2 else False
        # Hentai filter
        if not filters.get('show_hentai', True) and is_hentai:
            return False
        # Sequel filter
        if not filters.get('show_sequels', True) and is_sequel:
            return False
        # Type filters
        if not filters.get('show_movies', True) and anime_type == 'MOVIE':
            return False
        if not filters.get('show_tv', True) and anime_type == 'TV':
            return False
        if not filters.get('show_ova', True) and anime_type in ['ONA', 'OVA', 'SPECIAL']:
            return False
        return True
# Global recommendation system (created once per process)
recommendation_system = None

def initialize_system():
    global recommendation_system
    if recommendation_system is None:
        try:
            args.num_items = 12689
            # Google Drive file ids -> local filenames for the checkpoint and metadata
            file_ids = {
                "1C6mdjblhiWGhRgbIk5DP2XCc4ElS9x8p": "pretrained_bert.pth",
                "1U42cFrdLFT8NVNikT9C5SD9aAux7a5U2": "animes.json",
                "1s-8FM1Wi2wOWJ9cstvm-O1_6XculTcTG": "dataset.pkl",
                "1SOm1llcTKfhr-RTHC0dhaZ4AfWPs8wRx": "id_to_url.json",
                "1vwJEMEOIYwvCKCCbbeaP0U_9L3NhvBzg": "anime_to_malurl.json",
                "1_TyzON6ie2CqvzVNvPyc9prMTwLMefdu": "anime_to_typenseq.json",
                "1G9O_ahyuJ5aO0cwoVnIXrlzMqjKrf2aw": "id_to_genres.json"
            }

            def download_from_gdrive(file_id, output_path):
                url = f"https://drive.google.com/uc?id={file_id}"
                try:
                    print(f"Downloading: {output_path}")
                    gdown.download(url, output_path, quiet=False)
                    print(f"Downloaded: {output_path}")
                    return True
                except Exception as e:
                    print(f"Error downloading {output_path}: {e}")
                    return False

            # Download any file that is not already present locally
            for file_id, filename in file_ids.items():
                if not os.path.isfile(filename):
                    download_from_gdrive(file_id, filename)

            recommendation_system = AnimeRecommendationSystem(
                "pretrained_bert.pth",
                "dataset.pkl",
                "animes.json",
                "id_to_url.json",
                "anime_to_malurl.json",
                "anime_to_typenseq.json",
                "id_to_genres.json"
            )
            print("Recommendation system initialized successfully!")
        except Exception as e:
            print(f"Failed to initialize recommendation system: {e}")
            return f"Error: {str(e)}"
    return "System ready!"
def search_and_add_anime(query, favorites_state):
    """Search for an anime by name and add the best match to the favorites list"""
    if not recommendation_system:
        return "System not initialized", favorites_state, ""
    if not query.strip():
        return "Please enter an anime name to search", favorites_state, ""

    # Search for anime
    result = recommendation_system.find_closest_anime(query.strip())
    if result:
        anime_id, anime_name = result
        # Check if already in favorites
        if anime_id in favorites_state:
            return f"'{anime_name}' is already in your favorites", favorites_state, ""
        # Add to favorites (capped at 15)
        if len(favorites_state) >= 15:
            return "Maximum 15 favorite animes allowed", favorites_state, ""
        favorites_state.append(anime_id)
        return f"Added '{anime_name}' to favorites", favorites_state, ""
    else:
        return f"No anime found matching '{query}'", favorites_state, ""
def get_favorites_display(favorites_state):
    """Get display string for favorites"""
    if not favorites_state or not recommendation_system:
        return "No favorites added yet"
    display = "Your Favorite Animes:\n"
    for i, anime_id in enumerate(favorites_state, 1):
        anime_data = recommendation_system.id_to_anime.get(str(anime_id))
        if anime_data:
            anime_name = anime_data[0] if isinstance(anime_data, list) else str(anime_data)
            display += f"{i}. {anime_name}\n"
    return display

def clear_favorites(favorites_state):
    """Clear all favorites"""
    return "Favorites cleared", [], ""
def get_recommendations_gradio(favorites_state, num_recs, show_hentai, show_sequels, show_movies, show_tv, show_ova):
    """Get recommendations for Gradio interface"""
    if not recommendation_system:
        return "System not initialized"
    if not favorites_state:
        return "Please add some favorite animes first!"

    # Prepare filters
    filters = {
        'show_hentai': show_hentai,
        'show_sequels': show_sequels,
        'show_movies': show_movies,
        'show_tv': show_tv,
        'show_ova': show_ova
    }
    recommendations, scores, message = recommendation_system.get_recommendations(
        favorites_state,
        num_recommendations=int(num_recs),
        filters=filters
    )
    if not recommendations:
        return f"No recommendations found. {message}"

    # Format recommendations
    result = f"**{message}**\n\n"
    for i, rec in enumerate(recommendations, 1):
        result += f"**{i}. {rec['name']}**\n"
        result += f"Score: {rec['score']:.4f}\n"
        result += f"Type: {rec.get('type', 'Unknown')}\n"
        if rec['genres']:
            result += f"Genres: {', '.join(rec['genres'])}\n"
        if rec.get('mal_url'):
            result += f"[MyAnimeList Link]({rec['mal_url']})\n"
        result += "\n" + "-" * 50 + "\n\n"
    return result
def create_interface():
    # Initialize system
    init_status = initialize_system()
    print(init_status)

    with gr.Blocks(title="Anime Recommendation System", theme=gr.themes.Soft()) as demo:
        # State for favorites
        favorites_state = gr.State([])
        gr.HTML("""
            <div style="text-align: center; margin-bottom: 20px;">
                <h1>🎌 Anime Recommendation System</h1>
                <p>Add your favorite animes and get personalized recommendations!</p>
            </div>
        """)
with gr.Tab("Add Favorites"): | |
with gr.Row(): | |
with gr.Column(scale=2): | |
search_input = gr.Textbox( | |
label="Search Anime", | |
placeholder="Enter anime name (e.g., 'Mushoku Tensei', 'Attack on Titan')", | |
lines=1 | |
) | |
with gr.Row(): | |
add_btn = gr.Button("Add to Favorites", variant="primary") | |
clear_btn = gr.Button("Clear All Favorites", variant="secondary") | |
with gr.Column(scale=2): | |
status_output = gr.Textbox(label="Status", lines=2) | |
favorites_display = gr.Textbox( | |
label="Your Favorites", | |
lines=10, | |
interactive=False, | |
value="No favorites added yet" | |
) | |
with gr.Tab("Get Recommendations"): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.Markdown("### Recommendation Settings") | |
num_recs = gr.Slider( | |
minimum=5, | |
maximum=50, | |
value=20, | |
step=5, | |
label="Number of Recommendations" | |
) | |
gr.Markdown("### Filters") | |
show_movies = gr.Checkbox(label="Include Movies", value=True) | |
show_tv = gr.Checkbox(label="Include TV Series", value=True) | |
show_ova = gr.Checkbox(label="Include OVA/ONA/Special", value=True) | |
show_sequels = gr.Checkbox(label="Include Sequels", value=True) | |
show_hentai = gr.Checkbox(label="Include Hentai", value=False) | |
recommend_btn = gr.Button("Get Recommendations", variant="primary") | |
with gr.Column(scale=2): | |
recommendations_output = gr.Markdown( | |
label="Recommendations", | |
value="Add some favorite animes and click 'Get Recommendations'" | |
) | |
        # Event handlers
        add_btn.click(
            fn=search_and_add_anime,
            inputs=[search_input, favorites_state],
            outputs=[status_output, favorites_state, search_input]
        ).then(
            fn=get_favorites_display,
            inputs=[favorites_state],
            outputs=[favorites_display]
        )

        clear_btn.click(
            fn=clear_favorites,
            inputs=[favorites_state],
            outputs=[status_output, favorites_state, search_input]
        ).then(
            fn=get_favorites_display,
            inputs=[favorites_state],
            outputs=[favorites_display]
        )

        recommend_btn.click(
            fn=get_recommendations_gradio,
            inputs=[
                favorites_state, num_recs, show_hentai, show_sequels,
                show_movies, show_tv, show_ova
            ],
            outputs=[recommendations_output]
        )
        # Examples
        with gr.Tab("Examples"):
            gr.Markdown("""
            ### How to use:
            1. **Add Favorites**: Search and add your favorite animes
            2. **Set Filters**: Choose what types of anime to include
            3. **Get Recommendations**: Click to get personalized suggestions

            ### Example Searches:
            - Mushoku Tensei
            - Attack on Titan
            - Demon Slayer
            - Your Name
            - Spirited Away
            - One Piece
            - Naruto
            """)

    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)
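    # Note: binding to 0.0.0.0 on port 7860 matches what a Hugging Face Space
    # expects; for purely local testing, a plain demo.launch() with default
    # Gradio settings is assumed to work as well.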