import torch import torch.nn as nn import torch.nn.functional as F import pickle import json import threading from pathlib import Path import torch.nn as nn from .token import TokenEmbedding from .position import PositionalEmbedding import time from pathlib import Path from torch.nn.utils.rnn import pad_sequence import pickle import json import os class BERTEmbedding(nn.Module): _mappings_cache = None _cache_lock = threading.Lock() @classmethod def _load_mappings(cls): if cls._mappings_cache is None: with cls._cache_lock: if cls._mappings_cache is None: # Double-checked locking try: main_dir = os.getcwd() relative_path_dataset = "Data/preprocessed/AnimeRatings_min_rating7-min_uc10-min_sc10-splitleave_one_out/dataset.pkl" relative_path_genres = "Data/AnimeRatings/id_to_genreids.json" full_path_dataset = Path(main_dir) / relative_path_dataset full_path_genres = Path(main_dir) / relative_path_genres with full_path_dataset.open('rb') as f: dataset_smap = pickle.load(f)["smap"] with full_path_genres.open('rb') as f: id_to_genres = json.load(f) cls._mappings_cache = { 'dataset_smap': dataset_smap, 'id_to_genres': id_to_genres } except Exception as e: print(f"Warning: Could not load mappings: {e}") cls._mappings_cache = { 'dataset_smap': {}, 'id_to_genres': {} } return cls._mappings_cache def __init__(self, vocab_size, embed_size, max_len, dropout=0.1, multi_genre=True, max_genres_per_anime=5): super().__init__() mappings = self._load_mappings() dataset_smap = mappings['dataset_smap'] id_to_genres = mappings['id_to_genres'] self.multi_genre = multi_genre self.max_genres_per_anime = max_genres_per_anime all_genres = set() for anime_id, genres in id_to_genres.items(): all_genres.update(genres) max_genre_id = max(all_genres) if all_genres else 0 self.num_genres = max_genre_id + 1 print(f"Detected {self.num_genres} unique genres (max_id: {max_genre_id})") self.vocab_size = vocab_size if multi_genre: self._create_multi_genre_mapping(dataset_smap, id_to_genres, vocab_size) else: self._create_single_genre_mapping(dataset_smap, id_to_genres, vocab_size) self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) self.genre_embed = nn.Embedding(num_embeddings=self.num_genres, embedding_dim=embed_size, padding_idx=0) if multi_genre: self.fusion_layer = nn.Sequential( nn.Linear(embed_size * 2, embed_size), nn.LayerNorm(embed_size), nn.ReLU() ) self.genre_aggregation = nn.Parameter(torch.ones(max_genres_per_anime) / max_genres_per_anime) self.genre_attention = nn.MultiheadAttention(embed_size, num_heads=4, batch_first=True) else: self.fusion_layer = nn.Sequential( nn.Linear(embed_size * 2, embed_size), nn.LayerNorm(embed_size), nn.ReLU() ) self.dropout = nn.Dropout(p=dropout) self.embed_size = embed_size self._genre_cache = {} self._cache_lock = threading.Lock() def _create_single_genre_mapping(self, dataset_smap, id_to_genres, vocab_size): token_to_genre = {} for anime_id, token_id in dataset_smap.items(): if token_id < vocab_size: genre_list = id_to_genres.get(str(anime_id), [0]) genre_id = genre_list[0] if genre_list else 0 if genre_id >= self.num_genres: print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0") genre_id = 0 token_to_genre[token_id] = genre_id if token_to_genre: token_ids = torch.tensor(list(token_to_genre.keys()), dtype=torch.long) genre_ids = torch.tensor(list(token_to_genre.values()), dtype=torch.long) self.register_buffer('token_ids', token_ids) self.register_buffer('genre_ids', genre_ids) self.has_mappings = True else: self.register_buffer('token_ids', torch.empty(0, dtype=torch.long)) self.register_buffer('genre_ids', torch.empty(0, dtype=torch.long)) self.has_mappings = False def _create_multi_genre_mapping(self, dataset_smap, id_to_genres, vocab_size): token_to_genres = {} for anime_id, token_id in dataset_smap.items(): if token_id < vocab_size: genre_list = id_to_genres.get(str(anime_id), [0]) valid_genres = [] for genre_id in genre_list: if genre_id >= self.num_genres: print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0") genre_id = 0 valid_genres.append(genre_id) if len(valid_genres) < self.max_genres_per_anime: valid_genres.extend([0] * (self.max_genres_per_anime - len(valid_genres))) else: valid_genres = valid_genres[:self.max_genres_per_anime] token_to_genres[token_id] = valid_genres if token_to_genres: token_ids = torch.tensor(list(token_to_genres.keys()), dtype=torch.long) genre_ids = torch.tensor(list(token_to_genres.values()), dtype=torch.long) self.register_buffer('token_ids', token_ids) self.register_buffer('genre_ids', genre_ids) self.has_mappings = True else: self.register_buffer('token_ids', torch.empty(0, dtype=torch.long)) self.register_buffer('genre_ids', torch.empty(0, self.max_genres_per_anime, dtype=torch.long)) self.has_mappings = False def _get_single_genre_mapping(self, sequence): """Original single genre mapping with improved bounds checking""" batch_size, seq_len = sequence.shape device = sequence.device if not self.has_mappings: return torch.zeros_like(sequence) sequence = torch.clamp(sequence, 0, self.vocab_size - 1) genre_sequence = torch.zeros_like(sequence) flat_sequence = sequence.flatten() flat_genre = torch.zeros_like(flat_sequence) token_mask = torch.isin(flat_sequence, self.token_ids) if token_mask.any(): valid_tokens = flat_sequence[token_mask] with self._cache_lock: cache_key = (device, len(self.token_ids)) if cache_key not in self._genre_cache: sorted_indices = torch.argsort(self.token_ids) self._genre_cache[cache_key] = { 'sorted_tokens': self.token_ids[sorted_indices], 'sorted_genres': self.genre_ids[sorted_indices] } cached_data = self._genre_cache[cache_key] indices = torch.searchsorted(cached_data['sorted_tokens'], valid_tokens) indices = torch.clamp(indices, 0, len(cached_data['sorted_tokens']) - 1) exact_matches = cached_data['sorted_tokens'][indices] == valid_tokens genre_values = torch.where( exact_matches, cached_data['sorted_genres'][indices], torch.tensor(0, device=device, dtype=self.genre_ids.dtype) ) flat_genre[token_mask] = genre_values return flat_genre.view(batch_size, seq_len) def _get_multi_genre_mapping(self, sequence): """Get multiple genres for each anime in sequence with bounds checking""" batch_size, seq_len = sequence.shape device = sequence.device if not self.has_mappings: return torch.zeros(batch_size, seq_len, self.max_genres_per_anime, device=device, dtype=torch.long) sequence = torch.clamp(sequence, 0, self.vocab_size - 1) genre_sequences = torch.zeros(batch_size, seq_len, self.max_genres_per_anime, device=device, dtype=torch.long) flat_sequence = sequence.flatten() flat_genres = torch.zeros(len(flat_sequence), self.max_genres_per_anime, device=device, dtype=torch.long) token_mask = torch.isin(flat_sequence, self.token_ids) if token_mask.any(): valid_tokens = flat_sequence[token_mask] with self._cache_lock: cache_key = (device, len(self.token_ids), 'multi') if cache_key not in self._genre_cache: sorted_indices = torch.argsort(self.token_ids) self._genre_cache[cache_key] = { 'sorted_tokens': self.token_ids[sorted_indices], 'sorted_genres': self.genre_ids[sorted_indices] # Shape: (num_tokens, max_genres_per_anime) } cached_data = self._genre_cache[cache_key] indices = torch.searchsorted(cached_data['sorted_tokens'], valid_tokens) indices = torch.clamp(indices, 0, len(cached_data['sorted_tokens']) - 1) exact_matches = cached_data['sorted_tokens'][indices] == valid_tokens genre_values = cached_data['sorted_genres'][indices] # Shape: (num_valid_tokens, max_genres_per_anime) valid_mask = token_mask.nonzero(as_tuple=True)[0] exact_valid_mask = valid_mask[exact_matches] flat_genres[exact_valid_mask] = genre_values[exact_matches] return flat_genres.view(batch_size, seq_len, self.max_genres_per_anime) def _aggregate_genre_embeddings(self, genre_embeddings): """Aggregate multiple genre embeddings per anime""" # genre_embeddings shape: (batch_size, seq_len, max_genres_per_anime, embed_size) batch_size, seq_len, max_genres, embed_size = genre_embeddings.shape weights = F.softmax(self.genre_aggregation, dim=0) weighted_genres = torch.einsum('bsgd,g->bsd', genre_embeddings, weights) return weighted_genres def forward(self, sequence): """ Enhanced forward pass with per-anime genre processing """ if sequence.max() >= self.vocab_size: print(f"Warning: Input contains tokens >= vocab_size ({self.vocab_size})") sequence = torch.clamp(sequence, 0, self.vocab_size - 1) token_emb = self.token(sequence) if self.multi_genre: genre_sequences = self._get_multi_genre_mapping(sequence) # (batch, seq, max_genres) genre_sequences = torch.clamp(genre_sequences, 0, self.num_genres - 1) genre_embeddings = self.genre_embed(genre_sequences) # (batch, seq, max_genres, embed_size) aggregated_genre_emb = self._aggregate_genre_embeddings(genre_embeddings) # (batch, seq, embed_size) combined = torch.cat([token_emb, aggregated_genre_emb], dim=-1) else: genre_sequence = self._get_single_genre_mapping(sequence) genre_sequence = torch.clamp(genre_sequence, 0, self.num_genres - 1) genre_emb = self.genre_embed(genre_sequence) combined = torch.cat([token_emb, genre_emb], dim=-1) x = self.fusion_layer(combined) return self.dropout(x) def clear_cache(self): """Clear internal caches to free GPU memory""" with self._cache_lock: self._genre_cache.clear() @classmethod def clear_global_cache(cls): """Clear global mappings cache""" with cls._cache_lock: cls._mappings_cache = None