# mramazan's picture
# Upload 41 files
# 0edbb0d verified
# raw
# history blame
# 12.6 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import json
import threading
from pathlib import Path
import torch.nn as nn
from .token import TokenEmbedding
from .position import PositionalEmbedding
import time
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence
import pickle
import json
import os
class BERTEmbedding(nn.Module):
    """Token embedding fused with a per-token genre embedding.

    Every token id is mapped, via the dataset's ``smap`` and the
    ``id_to_genres`` JSON file, to genre ids:

    * ``multi_genre=False``: to the first genre of its anime.
    * ``multi_genre=True``: to a fixed-length list of ``max_genres_per_anime``
      genre ids, zero-padded or truncated.

    The token embedding and the (aggregated) genre embedding are concatenated
    and projected back to ``embed_size`` by ``fusion_layer``
    (Linear -> LayerNorm -> ReLU), then dropout is applied.
    """

    # Process-wide cache of the mapping files, shared by all instances.
    _mappings_cache = None
    # Class-level lock guarding both ``_mappings_cache`` and every instance's
    # ``_genre_cache``. It deliberately lives only on the class: a lock stored
    # on the instance cannot be pickled, which breaks ``copy.deepcopy(model)``
    # and ``torch.save(model)``.
    _cache_lock = threading.Lock()

    @classmethod
    def _load_mappings(cls):
        """Load and memoize the anime-id -> token map and anime-id -> genres map.

        Both files are resolved relative to the current working directory.
        On any failure a warning is printed and *empty* mappings are cached;
        that fallback stays cached until ``clear_global_cache`` is called.

        Returns:
            dict with keys ``'dataset_smap'`` and ``'id_to_genres'``.
        """
        if cls._mappings_cache is None:
            with cls._cache_lock:
                # Re-check under the lock: another thread may have loaded it.
                if cls._mappings_cache is None:
                    try:
                        main_dir = os.getcwd()
                        relative_path_dataset = "Data/preprocessed/AnimeRatings_min_rating7-min_uc10-min_sc10-splitleave_one_out/dataset.pkl"
                        relative_path_genres = "Data/AnimeRatings/id_to_genreids.json"
                        full_path_dataset = Path(main_dir) / relative_path_dataset
                        full_path_genres = Path(main_dir) / relative_path_genres
                        # pickle executes arbitrary code on load: only point
                        # this at trusted, locally produced files.
                        with full_path_dataset.open('rb') as f:
                            dataset_smap = pickle.load(f)["smap"]
                        with full_path_genres.open('rb') as f:
                            id_to_genres = json.load(f)
                        cls._mappings_cache = {
                            'dataset_smap': dataset_smap,
                            'id_to_genres': id_to_genres
                        }
                    except Exception as e:
                        # Deliberate best-effort: the model still builds, every
                        # token then falls back to genre 0.
                        print(f"Warning: Could not load mappings: {e}")
                        cls._mappings_cache = {
                            'dataset_smap': {},
                            'id_to_genres': {}
                        }
        return cls._mappings_cache

    def __init__(self, vocab_size, embed_size, max_len, dropout=0.1, multi_genre=True, max_genres_per_anime=5):
        """Build the embedding layers and the token -> genre lookup buffers.

        Args:
            vocab_size: number of token ids. Larger ids are clamped in ``forward``.
            embed_size: width of the token/genre embeddings and of the output.
                With ``multi_genre=True`` it must be divisible by 4 (see
                ``genre_attention``).
            max_len: accepted for interface compatibility; not used here.
            dropout: dropout probability applied to the fused output.
            multi_genre: use all genres of an anime (True) or only the first.
            max_genres_per_anime: number of genre slots per token in
                multi-genre mode.
        """
        super().__init__()
        mappings = self._load_mappings()
        dataset_smap = mappings['dataset_smap']
        id_to_genres = mappings['id_to_genres']
        self.multi_genre = multi_genre
        self.max_genres_per_anime = max_genres_per_anime

        all_genres = set()
        for genres in id_to_genres.values():
            all_genres.update(genres)
        max_genre_id = max(all_genres) if all_genres else 0
        # Embedding table is indexed directly by genre id, so size = max id + 1.
        self.num_genres = max_genre_id + 1
        print(f"Detected {self.num_genres} unique genres (max_id: {max_genre_id})")
        self.vocab_size = vocab_size

        if multi_genre:
            self._create_multi_genre_mapping(dataset_smap, id_to_genres, vocab_size)
        else:
            self._create_single_genre_mapping(dataset_smap, id_to_genres, vocab_size)

        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        # Genre id 0 doubles as padding: its vector starts at zero and never
        # receives gradient.
        self.genre_embed = nn.Embedding(num_embeddings=self.num_genres, embedding_dim=embed_size, padding_idx=0)
        # Identical in both modes; created before the multi-genre parameters
        # so parameter initialization order (and RNG use) is unchanged.
        self.fusion_layer = nn.Sequential(
            nn.Linear(embed_size * 2, embed_size),
            nn.LayerNorm(embed_size),
            nn.ReLU(),
        )
        if multi_genre:
            # Learned per-slot weights, softmax-normalized in
            # _aggregate_genre_embeddings.
            self.genre_aggregation = nn.Parameter(torch.ones(max_genres_per_anime) / max_genres_per_anime)
            # NOTE(review): genre_attention is never used in forward(); it is
            # kept so existing checkpoints/state_dicts still load.
            self.genre_attention = nn.MultiheadAttention(embed_size, num_heads=4, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size
        # Lazily built, per-device sorted copies of the lookup buffers.
        # Guarded by the class-level ``_cache_lock`` (see class attributes).
        self._genre_cache = {}

    def _register_mapping_buffers(self, token_to_genre, empty_genre_shape):
        """Store a ``{token_id: genre(s)}`` dict as ``token_ids``/``genre_ids`` buffers.

        Args:
            token_to_genre: token id -> genre id (single mode) or list of genre
                ids (multi mode).
            empty_genre_shape: shape of ``genre_ids`` when the dict is empty.

        Sets ``has_mappings`` to whether any token was mapped.
        """
        if token_to_genre:
            token_ids = torch.tensor(list(token_to_genre.keys()), dtype=torch.long)
            genre_ids = torch.tensor(list(token_to_genre.values()), dtype=torch.long)
        else:
            token_ids = torch.empty(0, dtype=torch.long)
            genre_ids = torch.empty(empty_genre_shape, dtype=torch.long)
        self.register_buffer('token_ids', token_ids)
        self.register_buffer('genre_ids', genre_ids)
        self.has_mappings = bool(token_to_genre)

    def _create_single_genre_mapping(self, dataset_smap, id_to_genres, vocab_size):
        """Map each token id < vocab_size to the first genre of its anime.

        Animes missing from ``id_to_genres`` or with no genres map to 0;
        genre ids >= ``num_genres`` are replaced by 0 with a warning.
        """
        token_to_genre = {}
        for anime_id, token_id in dataset_smap.items():
            if token_id >= vocab_size:
                continue
            # JSON object keys are strings, hence str(anime_id).
            genre_list = id_to_genres.get(str(anime_id), [0])
            genre_id = genre_list[0] if genre_list else 0
            if genre_id >= self.num_genres:
                print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0")
                genre_id = 0
            token_to_genre[token_id] = genre_id
        self._register_mapping_buffers(token_to_genre, (0,))

    def _create_multi_genre_mapping(self, dataset_smap, id_to_genres, vocab_size):
        """Map each token id < vocab_size to exactly ``max_genres_per_anime`` genre ids.

        Genre lists are right-padded with 0 or truncated; genre ids
        >= ``num_genres`` are replaced by 0 with a warning.
        """
        width = self.max_genres_per_anime
        token_to_genres = {}
        for anime_id, token_id in dataset_smap.items():
            if token_id >= vocab_size:
                continue
            valid_genres = []
            for genre_id in id_to_genres.get(str(anime_id), [0]):
                if genre_id >= self.num_genres:
                    print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0")
                    genre_id = 0
                valid_genres.append(genre_id)
            # Pad with zeros, then cut to the fixed width.
            token_to_genres[token_id] = (valid_genres + [0] * width)[:width]
        self._register_mapping_buffers(token_to_genres, (0, width))

    def _sorted_lookup(self, cache_key):
        """Return ``(sorted_tokens, sorted_genres)``, cached under ``cache_key``.

        ``torch.searchsorted`` needs an ascending table. Sorting once per key
        avoids an argsort on every forward pass. The cache is not invalidated
        when the buffers are modified in place; call ``clear_cache`` after
        changing them.
        """
        with self._cache_lock:
            cached = self._genre_cache.get(cache_key)
            if cached is None:
                order = torch.argsort(self.token_ids)
                cached = {
                    'sorted_tokens': self.token_ids[order],
                    'sorted_genres': self.genre_ids[order],
                }
                self._genre_cache[cache_key] = cached
        return cached['sorted_tokens'], cached['sorted_genres']

    def _get_single_genre_mapping(self, sequence):
        """Return a ``(batch, seq_len)`` tensor holding the genre id of each token.

        Tokens are clamped to ``[0, vocab_size - 1]`` first. Tokens without a
        mapping get genre 0.
        """
        batch_size, seq_len = sequence.shape
        device = sequence.device
        if not self.has_mappings:
            return torch.zeros_like(sequence)
        flat_sequence = torch.clamp(sequence, 0, self.vocab_size - 1).flatten()
        flat_genre = torch.zeros_like(flat_sequence)
        token_mask = torch.isin(flat_sequence, self.token_ids)
        if token_mask.any():
            valid_tokens = flat_sequence[token_mask]
            sorted_tokens, sorted_genres = self._sorted_lookup((device, len(self.token_ids)))
            indices = torch.searchsorted(sorted_tokens, valid_tokens)
            # searchsorted may return len(table) for values past the end.
            indices = torch.clamp(indices, 0, len(sorted_tokens) - 1)
            exact_matches = sorted_tokens[indices] == valid_tokens
            flat_genre[token_mask] = torch.where(
                exact_matches,
                sorted_genres[indices],
                torch.tensor(0, device=device, dtype=self.genre_ids.dtype),
            )
        return flat_genre.view(batch_size, seq_len)

    def _get_multi_genre_mapping(self, sequence):
        """Return a ``(batch, seq_len, max_genres_per_anime)`` tensor of genre ids.

        Tokens are clamped to ``[0, vocab_size - 1]`` first. Tokens without a
        mapping get an all-zero genre row.
        """
        batch_size, seq_len = sequence.shape
        device = sequence.device
        width = self.max_genres_per_anime
        if not self.has_mappings:
            return torch.zeros(batch_size, seq_len, width, device=device, dtype=torch.long)
        flat_sequence = torch.clamp(sequence, 0, self.vocab_size - 1).flatten()
        flat_genres = torch.zeros(len(flat_sequence), width, device=device, dtype=torch.long)
        token_mask = torch.isin(flat_sequence, self.token_ids)
        if token_mask.any():
            valid_tokens = flat_sequence[token_mask]
            sorted_tokens, sorted_genres = self._sorted_lookup((device, len(self.token_ids), 'multi'))
            indices = torch.searchsorted(sorted_tokens, valid_tokens)
            indices = torch.clamp(indices, 0, len(sorted_tokens) - 1)
            exact_matches = sorted_tokens[indices] == valid_tokens
            genre_values = sorted_genres[indices]  # (num_valid_tokens, width)
            # Positions in flat_sequence of the tokens that were looked up.
            valid_positions = token_mask.nonzero(as_tuple=True)[0]
            flat_genres[valid_positions[exact_matches]] = genre_values[exact_matches]
        return flat_genres.view(batch_size, seq_len, width)

    def _aggregate_genre_embeddings(self, genre_embeddings):
        """Collapse the genre-slot axis with learned softmax weights.

        Args:
            genre_embeddings: ``(batch, seq_len, max_genres_per_anime, embed_size)``.

        Returns:
            ``(batch, seq_len, embed_size)`` weighted sum over the slot axis.
        """
        # NOTE(review): weights are per slot and include padded slots, whose
        # embedding is zero (padding_idx=0). An anime with fewer genres
        # therefore gets a smaller-norm vector. Kept as-is so trained
        # checkpoints behave identically.
        weights = F.softmax(self.genre_aggregation, dim=0)
        return torch.einsum('bsgd,g->bsd', genre_embeddings, weights)

    def forward(self, sequence):
        """Embed a ``(batch, seq_len)`` tensor of token ids.

        Token ids >= ``vocab_size`` are clamped, with a printed warning.
        Zero-length input is supported and yields an empty output.

        Returns:
            Fused, dropout-regularized embeddings of width ``embed_size``
            (``(batch, seq_len, embed_size)``, assuming ``TokenEmbedding``
            returns that shape).
        """
        # sequence.max() raises on a zero-element tensor, so guard it.
        if sequence.numel() > 0 and sequence.max() >= self.vocab_size:
            print(f"Warning: Input contains tokens >= vocab_size ({self.vocab_size})")
            sequence = torch.clamp(sequence, 0, self.vocab_size - 1)
        token_emb = self.token(sequence)
        if self.multi_genre:
            genre_idx = self._get_multi_genre_mapping(sequence)  # (batch, seq, max_genres)
            genre_idx = torch.clamp(genre_idx, 0, self.num_genres - 1)
            genre_emb = self._aggregate_genre_embeddings(self.genre_embed(genre_idx))
        else:
            genre_idx = self._get_single_genre_mapping(sequence)  # (batch, seq)
            genre_idx = torch.clamp(genre_idx, 0, self.num_genres - 1)
            genre_emb = self.genre_embed(genre_idx)
        combined = torch.cat([token_emb, genre_emb], dim=-1)
        return self.dropout(self.fusion_layer(combined))

    def clear_cache(self):
        """Drop this instance's cached sorted lookup tables (frees device memory)."""
        with self._cache_lock:
            self._genre_cache.clear()

    @classmethod
    def clear_global_cache(cls):
        """Forget the class-wide mapping cache so the next construction reloads the files."""
        with cls._cache_lock:
            cls._mappings_cache = None