import difflib
import gc
import json
import os
import pickle
import random
import sys
from datetime import datetime
from pathlib import Path

import gdown
import gradio as gr
import torch

# Import your existing modules
from utils import *
from options import args
from models import model_factory


class LazyDict:
    """Read-only, dict-like view of a JSON file that is parsed on first access.

    If the file cannot be read or parsed (or does not contain a JSON object),
    a warning is printed and the instance behaves like an empty dict.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self._data = None
        self._loaded = False

    def _load_data(self):
        """Parse the backing file once; later calls are no-ops."""
        if self._loaded:
            return
        try:
            with open(self.file_path, "r", encoding="utf-8") as file:
                self._data = json.load(file)
        except Exception as e:
            # Deliberate best-effort: a missing metadata file must not take
            # the whole application down.
            print(f"Warning: Could not load {self.file_path}: {str(e)}")
            self._data = {}
        else:
            if not isinstance(self._data, dict):
                # Every accessor below assumes a mapping.
                print(f"Warning: Could not load {self.file_path}: top-level JSON value is not an object")
                self._data = {}
        self._loaded = True

    def get(self, key, default=None):
        self._load_data()
        return self._data.get(key, default)

    def __contains__(self, key):
        self._load_data()
        return key in self._data

    def items(self):
        self._load_data()
        return self._data.items()

    def keys(self):
        self._load_data()
        return self._data.keys()

    def __len__(self):
        self._load_data()
        return len(self._data)


class AnimeRecommendationSystem:
    """Wraps the trained model plus the JSON metadata needed to present results.

    Constructing an instance immediately loads the dataset mapping and the
    model checkpoint; the JSON metadata files are loaded lazily on first use.
    """

    # Sequence length fed to the model (also written to ``args.bert_max_len``).
    MAX_SEQ_LEN = 128
    # Upper bound on the number of top-scoring candidates inspected per request.
    MAX_CANDIDATES = 75
    # Upper bound on the number of results returned by ``search_animes``.
    SEARCH_LIMIT = 200

    def __init__(self, checkpoint_path, dataset_path, animes_path, images_path,
                 mal_urls_path, type_seq_path, genres_path):
        self.model = None
        self.dataset = None
        self.checkpoint_path = checkpoint_path
        self.dataset_path = dataset_path
        self.animes_path = animes_path

        # Lazy loading keeps memory usage low until the data is needed.
        self.id_to_anime = LazyDict(animes_path)
        self.id_to_url = LazyDict(images_path)
        self.id_to_mal_url = LazyDict(mal_urls_path)
        self.id_to_type_seq = LazyDict(type_seq_path)
        self.id_to_genres = LazyDict(genres_path)

        # Per-instance cache for derived lookup tables (e.g. the title index).
        self._cache = {}

        self.load_model_and_data()

    def load_model_and_data(self):
        """Load the item-id mapping from the dataset pickle, then build and load the model.

        Raises:
            Exception: whatever the underlying load step raised (after logging it).
        """
        try:
            print("Loading model and data...")

            args.bert_max_len = self.MAX_SEQ_LEN

            # Load the dataset.
            # NOTE(review): pickle.load can execute arbitrary code; this file is
            # downloaded at start-up, so it must come from a trusted source.
            with Path(self.dataset_path).open('rb') as f:
                self.dataset = pickle.load(f)["smap"]

            args.num_items = len(self.dataset)
            print(args.num_items)

            # Load the model.
            self.model = model_factory(args)
            self.load_checkpoint()

            # Garbage collection
            gc.collect()

            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def load_checkpoint(self):
        """Load ``model_state_dict`` from the checkpoint file and switch the model to eval mode.

        Raises:
            Exception: if the checkpoint cannot be read or applied; the original
                error is attached as ``__cause__``.
        """
        try:
            with open(self.checkpoint_path, 'rb') as f:
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only safe for trusted checkpoints.
                checkpoint = torch.load(f, map_location='cpu', weights_only=False)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()

            # Release the checkpoint from memory.
            del checkpoint
            gc.collect()
        except Exception as e:
            raise Exception(f"Failed to load checkpoint from {self.checkpoint_path}: {str(e)}") from e

    def get_anime_genres(self, anime_id):
        """Return the title-cased genre names for ``anime_id`` (empty list if none).

        Reads the first element of the stored entry.
        # presumably the entry is a list whose first element is the list of
        # genre names; verify against id_to_genres.json
        """
        entry = self.id_to_genres.get(str(anime_id))
        if not entry:
            # Unknown id or empty entry: indexing [0] here used to raise IndexError.
            return []
        genres = entry[0]
        return [genre.title() for genre in genres] if genres else []

    def get_anime_image_url(self, anime_id):
        """Return the image URL stored for ``anime_id``, or None."""
        return self.id_to_url.get(str(anime_id), None)

    def get_anime_mal_url(self, anime_id):
        """Return the MAL URL stored for ``anime_id``, or None."""
        return self.id_to_mal_url.get(str(anime_id), None)

    def _get_type(self, anime_id):
        """Return the first element of the type/sequel entry, or "Unknown" if absent or too short."""
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return "Unknown"
        return type_seq_info[0]

    def _title_index(self):
        """Return (building once) a map of lower-cased title -> (anime_id, main_title).

        Both main and alternative titles are indexed. Insertion order follows
        ``id_to_anime`` so the substring search keeps a stable "first hit".
        """
        index = self._cache.get("title_index")
        if index is not None:
            return index

        index = {}
        for k, v in self.id_to_anime.items():
            if not v:
                # Nothing usable to index (empty list / empty string / None).
                continue
            anime_id = int(k)
            if isinstance(v, list):
                # Main title
                main_title = v[0]
                index[main_title.lower().strip()] = (anime_id, main_title)
                # Alternative titles
                for alt_title in v[1:]:
                    if alt_title and isinstance(alt_title, str):
                        alt_title_clean = alt_title.strip()
                        if alt_title_clean:
                            index[alt_title_clean.lower()] = (anime_id, main_title)
            else:
                title = str(v).strip()
                index[title.lower()] = (anime_id, title)

        self._cache["title_index"] = index
        return index

    def find_closest_anime(self, input_name):
        """Finds the closest matching anime to the input name.

        Tries, in order: exact (case-insensitive) match, substring match, then
        fuzzy matching with ``difflib`` (cutoff 0.6).

        Returns:
            ``(anime_id, main_title)`` or None when nothing matches or the
            input is empty.
        """
        if not input_name:
            return None
        input_lower = input_name.lower().strip()
        if not input_lower:
            # An empty string is a substring of every title; do not "match" it.
            return None

        anime_names = self._title_index()

        # 1. Exact match
        if input_lower in anime_names:
            return anime_names[input_lower]

        # 2. Substring search
        for anime_name_lower, match in anime_names.items():
            if input_lower in anime_name_lower:
                return match

        # 3. Fuzzy matching
        close_matches = difflib.get_close_matches(input_lower, list(anime_names), n=1, cutoff=0.6)
        if close_matches:
            return anime_names[close_matches[0]]

        return None

    def search_animes(self, query):
        """Search animes by query.

        Returns up to ``SEARCH_LIMIT`` ``(anime_id, main_name)`` tuples whose main
        or alternative title contains ``query`` (case-insensitive), sorted by
        name. An empty query matches every entry that has at least one name.
        """
        query_lower = query.lower() if query else ""
        animes = []

        for k, v in self.id_to_anime.items():
            if len(animes) >= self.SEARCH_LIMIT:  # Limit for performance
                break

            anime_names = v if isinstance(v, list) else [v]
            if not anime_names:
                continue

            if query_lower and not any(
                isinstance(name, str) and query_lower in name.lower()
                for name in anime_names
            ):
                continue

            animes.append((int(k), anime_names[0]))

        # str() keeps the sort from failing if a main name is not a string.
        animes.sort(key=lambda item: str(item[1]))
        return animes

    def get_recommendations(self, favorite_anime_ids, num_recommendations=20, filters=None):
        """Score all items for the given favourites and return the best candidates.

        Args:
            favorite_anime_ids: anime ids chosen by the user.
            num_recommendations: stop once this many recommendations are collected.
            filters: optional dict understood by ``_should_include_anime``.

        Returns:
            ``(recommendations, scores, message)`` where ``recommendations`` is a
            list of dicts (id, name, score, image_url, mal_url, genres, type) and
            ``scores`` the matching floats. On failure both lists are empty and
            ``message`` describes the error.
        """
        try:
            if not favorite_anime_ids:
                return [], [], "Please add some favorite animes first!"

            smap = self.dataset
            inverted_smap = {v: k for k, v in smap.items()}

            converted_ids = [smap[anime_id] for anime_id in favorite_anime_ids if anime_id in smap]

            if not converted_ids:
                return [], [], "None of the selected animes are in the model vocabulary!"

            # Normal recommendations
            target_len = self.MAX_SEQ_LEN
            # Keep only the most recent items if the history exceeds the model
            # window; previously the over-long sequence was passed through as is.
            converted_ids = converted_ids[-target_len:]
            # NOTE(review): the sequence is right-padded with 0 and the logits of
            # the LAST position are read below — confirm this matches how the
            # model was trained.
            padded = converted_ids + [0] * (target_len - len(converted_ids))
            input_tensor = torch.tensor(padded, dtype=torch.long).unsqueeze(0)

            with torch.no_grad():
                logits = self.model(input_tensor)
                last_logits = logits[:, -1, :]
                # k may not exceed the number of scored classes.
                max_predictions = min(self.MAX_CANDIDATES, len(inverted_smap), last_logits.size(-1))
                top_scores, top_indices = torch.topk(last_logits, k=max_predictions, dim=1)

            favorites = set(favorite_anime_ids)
            recommendations = []
            scores = []

            for idx, score in zip(top_indices[0].tolist(), top_scores[0].tolist()):
                if idx not in inverted_smap:
                    continue
                anime_id = inverted_smap[idx]

                if anime_id in favorites:
                    continue
                if str(anime_id) not in self.id_to_anime:
                    continue

                # Filter check
                if filters and not self._should_include_anime(anime_id, filters):
                    continue

                anime_data = self.id_to_anime.get(str(anime_id))
                if isinstance(anime_data, list) and len(anime_data) > 0:
                    anime_name = anime_data[0]
                else:
                    anime_name = str(anime_data)

                recommendations.append({
                    'id': anime_id,
                    'name': anime_name,
                    'score': float(score),
                    'image_url': self.get_anime_image_url(anime_id),
                    'mal_url': self.get_anime_mal_url(anime_id),
                    'genres': self.get_anime_genres(anime_id),
                    'type': self._get_type(anime_id),
                })
                scores.append(float(score))

                if len(recommendations) >= num_recommendations:
                    break

            # Memory cleanup
            del logits, last_logits, top_scores, top_indices
            gc.collect()

            return recommendations, scores, f"Found {len(recommendations)} recommendations!"

        except Exception as e:
            return [], [], f"Error during prediction: {str(e)}"

    def _should_include_anime(self, anime_id, filters):
        """Check if anime should be included based on filters.

        ``filters`` may contain the boolean keys ``show_sequels``, ``show_movies``,
        ``show_tv`` and ``show_ova`` (each defaults to True). Animes without a
        usable type/sequel entry are always included.
        """
        if not filters:
            return True

        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return True

        anime_type = type_seq_info[0]
        is_sequel = type_seq_info[1]

        # Sequel filter
        if not filters.get('show_sequels', True) and is_sequel:
            return False

        # Type filters
        if not filters.get('show_movies', True) and anime_type == 'MOVIE':
            return False
        if not filters.get('show_tv', True) and anime_type == 'TV':
            return False
        if not filters.get('show_ova', True) and anime_type in ['ONA', 'OVA', 'SPECIAL']:
            return False

        return True


# Global recommendation system
recommendation_system = None

# Google Drive file id -> local file name for every asset fetched at start-up.
_GDRIVE_FILES = {
    "1X1jUSbE4x6DbccP7mHz-nAeGcfOjSHwe": "pretrained_bert.pth",
    "1J1RmuJE5OjZUO0z1irVb2M-xnvuVvvHR": "animes.json",
    "1xGxUCbCDUnbdnJa6Ab8wgM9cpInpeQnN": "dataset.pkl",
    "1PtB6o_91tNWAb4zN0xj-Kf8SKvVAJp1c": "id_to_url.json",
    "1xVfTB_CmeYEqq6-l_BkQXo-QAUEyBfbW": "anime_to_malurl.json",
    "1zMbL9TpCbODKfVT5ahiaYILlnwBZNJc1": "anime_to_typenseq.json",
    "1LLMRhYyw82GOz3d8SUDZF9YRJdybgAFA": "id_to_genres.json",
    "1bW-UlKiGplb2jTt7uD-dfIx3CMXD3iWT": "id_to_genreids.json",
}

# Maximum number of favourites a user may collect.
_MAX_FAVORITES = 15


def _download_from_gdrive(file_id, output_path):
    """Download one Google Drive file to ``output_path``; return True on success."""
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        print(f"Downloading: {output_path}")
        downloaded = gdown.download(url, output_path, quiet=False)
    except Exception as e:
        print(f"Error downloading {output_path}: {e}")
        return False
    if downloaded is None:
        # Some gdown versions signal failure by returning None instead of raising.
        print(f"Error downloading {output_path}: no file was downloaded")
        return False
    print(f"Downloaded: {output_path}")
    return True


def initialize_system():
    """Download missing assets and create the global recommendation system once.

    Returns:
        "System ready!" on success (or if already initialized), otherwise an
        "Error: ..." string describing the failure.
    """
    global recommendation_system
    if recommendation_system is None:
        try:
            args.num_items = 15687

            for file_id, filename in _GDRIVE_FILES.items():
                if not os.path.isfile(filename):
                    # Failures are logged by the helper; a missing required file
                    # surfaces below when the system tries to load it.
                    _download_from_gdrive(file_id, filename)

            recommendation_system = AnimeRecommendationSystem(
                "pretrained_bert.pth",
                "dataset.pkl",
                "animes.json",
                "id_to_url.json",
                "anime_to_malurl.json",
                "anime_to_typenseq.json",
                "id_to_genres.json"
            )
            print("Recommendation system initialized successfully!")
        except Exception as e:
            print(f"Failed to initialize recommendation system: {e}")
            return f"Error: {str(e)}"

    return "System ready!"


def search_and_add_anime(query, favorites_state):
    """Search anime and add the best match to the favorites.

    Returns:
        ``(status_message, favorites_state, "")`` — the trailing empty string is
        returned on every path.
        # presumably it clears the search textbox; verify against the UI wiring
    """
    if not recommendation_system:
        return "System not initialized", favorites_state, ""

    if not query or not query.strip():
        return "Please enter an anime name to search", favorites_state, ""

    if favorites_state is None:
        favorites_state = []

    # Search for anime
    result = recommendation_system.find_closest_anime(query.strip())
    if not result:
        return f"No anime found matching '{query}'", favorites_state, ""

    anime_id, anime_name = result

    # Check if already in favorites
    if anime_id in favorites_state:
        return f"'{anime_name}' is already in your favorites", favorites_state, ""

    # Add to favorites
    if len(favorites_state) >= _MAX_FAVORITES:
        return f"Maximum {_MAX_FAVORITES} favorite animes allowed", favorites_state, ""

    favorites_state.append(anime_id)
    return f"Added '{anime_name}' to favorites", favorites_state, ""


def get_favorites_display(favorites_state):
    """Get display string for favorites.

    Entries whose id has no data in ``id_to_anime`` are skipped; numbering
    follows the position in ``favorites_state``.
    """
    if not favorites_state or not recommendation_system:
        return "No favorites added yet"

    lines = ["Your Favorite Animes:\n"]
    for i, anime_id in enumerate(favorites_state, 1):
        anime_data = recommendation_system.id_to_anime.get(str(anime_id))
        if anime_data:
            anime_name = anime_data[0] if isinstance(anime_data, list) else str(anime_data)
            lines.append(f"{i}. {anime_name}\n")

    return "".join(lines)


def clear_favorites(favorites_state):
    """Clear all favorites: returns a status message, a fresh empty list and an empty string."""
    return "Favorites cleared", [], ""


def get_recommendations_gradio(favorites_state, num_recs, show_sequels, show_movies, show_tv, show_ova):
    """Get recommendations for Gradio interface with HTML formatting for images"""
    if not recommendation_system:
        return "System not initialized"

    if not favorites_state:
        return "Please add some favorite animes first!"
# Prepare filters filters = { 'show_sequels': show_sequels, 'show_movies': show_movies, 'show_tv': show_tv, 'show_ova': show_ova } recommendations, scores, message = recommendation_system.get_recommendations( favorites_state, num_recommendations=int(num_recs), filters=filters ) if not recommendations: return f"No recommendations found. {message}" # Format recommendations with HTML and images result = f"
Add your favorite animes and get personalized recommendations!