mramazan committed
Commit 0edbb0d · verified · 1 Parent(s): 16f3f17

Upload 41 files

Files changed (41)
  1. models/__init__.py +14 -0
  2. models/__pycache__/__init__.cpython-312.pyc +0 -0
  3. models/__pycache__/base.cpython-312.pyc +0 -0
  4. models/__pycache__/bert.cpython-312.pyc +0 -0
  5. models/__pycache__/dae.cpython-312.pyc +0 -0
  6. models/__pycache__/vae.cpython-312.pyc +0 -0
  7. models/base.py +15 -0
  8. models/bert.py +19 -0
  9. models/bert_modules/__init__.py +1 -0
  10. models/bert_modules/__pycache__/__init__.cpython-312.pyc +0 -0
  11. models/bert_modules/__pycache__/bert.cpython-312.pyc +0 -0
  12. models/bert_modules/__pycache__/transformer.cpython-312.pyc +0 -0
  13. models/bert_modules/attention/__init__.py +2 -0
  14. models/bert_modules/attention/__pycache__/__init__.cpython-312.pyc +0 -0
  15. models/bert_modules/attention/__pycache__/multi_head.cpython-312.pyc +0 -0
  16. models/bert_modules/attention/__pycache__/single.cpython-312.pyc +0 -0
  17. models/bert_modules/attention/multi_head.py +37 -0
  18. models/bert_modules/attention/single.py +25 -0
  19. models/bert_modules/bert.py +45 -0
  20. models/bert_modules/embedding/__init__.py +1 -0
  21. models/bert_modules/embedding/__pycache__/__init__.cpython-312.pyc +0 -0
  22. models/bert_modules/embedding/__pycache__/bert.cpython-312.pyc +0 -0
  23. models/bert_modules/embedding/__pycache__/position.cpython-312.pyc +0 -0
  24. models/bert_modules/embedding/__pycache__/token.cpython-312.pyc +0 -0
  25. models/bert_modules/embedding/bert.py +304 -0
  26. models/bert_modules/embedding/position.py +16 -0
  27. models/bert_modules/embedding/segment.py +6 -0
  28. models/bert_modules/embedding/token.py +6 -0
  29. models/bert_modules/transformer.py +31 -0
  30. models/bert_modules/utils/__init__.py +4 -0
  31. models/bert_modules/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  32. models/bert_modules/utils/__pycache__/feed_forward.cpython-312.pyc +0 -0
  33. models/bert_modules/utils/__pycache__/gelu.cpython-312.pyc +0 -0
  34. models/bert_modules/utils/__pycache__/layer_norm.cpython-312.pyc +0 -0
  35. models/bert_modules/utils/__pycache__/sublayer.cpython-312.pyc +0 -0
  36. models/bert_modules/utils/feed_forward.py +16 -0
  37. models/bert_modules/utils/gelu.py +12 -0
  38. models/bert_modules/utils/layer_norm.py +17 -0
  39. models/bert_modules/utils/sublayer.py +18 -0
  40. models/dae.py +54 -0
  41. models/vae.py +69 -0
models/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .bert import BERTModel
+ from .dae import DAEModel
+ from .vae import VAEModel
+
+ MODELS = {
+     BERTModel.code(): BERTModel,
+     DAEModel.code(): DAEModel,
+     VAEModel.code(): VAEModel
+ }
+
+
+ def model_factory(args):
+     model = MODELS[args.model_code]
+     return model(args)
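A minimal usage sketch of the factory above (not part of this commit): it assumes the surrounding training repo (including its utils module) is importable, and it borrows the argument names read by the DAE model later in this commit; the values are illustrative only.

# Hypothetical usage sketch, not part of this commit.
from argparse import Namespace
from models import model_factory

# 'dae' matches DAEModel.code(); the remaining fields are the ones
# DAEModel.__init__ reads, with illustrative values.
args = Namespace(model_code='dae', num_items=1000,
                 dae_dropout=0.5, dae_hidden_dim=600,
                 dae_num_hidden=1, dae_latent_dim=200)
model = model_factory(args)   # MODELS['dae'] -> DAEModel(args)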
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (656 Bytes). View file
 
models/__pycache__/base.cpython-312.pyc ADDED
Binary file (884 Bytes). View file
 
models/__pycache__/bert.cpython-312.pyc ADDED
Binary file (1.31 kB). View file
 
models/__pycache__/dae.cpython-312.pyc ADDED
Binary file (3.34 kB). View file
 
models/__pycache__/vae.cpython-312.pyc ADDED
Binary file (4.03 kB). View file
 
models/base.py ADDED
@@ -0,0 +1,15 @@
+ import torch.nn as nn
+
+ from abc import *
+
+
+ class BaseModel(nn.Module, metaclass=ABCMeta):
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+
+     @classmethod
+     @abstractmethod
+     def code(cls):
+         pass
+
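For reference, a hypothetical sketch of how a new model would plug into this base class and the MODELS registry above; the class name and its code string are made up for illustration and are not part of this commit.

# Hypothetical sketch, not part of this commit: a new model subclasses
# BaseModel, stores args, and exposes a unique code() used as its MODELS key.
import torch.nn as nn
from models.base import BaseModel

class IdentityModel(BaseModel):          # illustrative name only
    def __init__(self, args):
        super().__init__(args)
        self.layer = nn.Linear(args.num_items, args.num_items)

    @classmethod
    def code(cls):
        return 'identity'                # key that would go into MODELS

    def forward(self, x):
        return self.layer(x)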
models/bert.py ADDED
@@ -0,0 +1,19 @@
+ from .base import BaseModel
+ from .bert_modules.bert import BERT
+
+ import torch.nn as nn
+
+
+ class BERTModel(BaseModel):
+     def __init__(self, args):
+         super().__init__(args)
+         self.bert = BERT(args)
+         self.out = nn.Linear(self.bert.hidden, args.num_items + 1)
+
+     @classmethod
+     def code(cls):
+         return 'bert'
+
+     def forward(self, x):
+         x = self.bert(x)
+         return self.out(x)
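A hedged shape check for the model above, assuming it is run inside the training repo (where utils is importable; if the preprocessed data paths do not resolve, the embedding below falls back to empty genre mappings with a warning). The head maps each position to num_items + 1 scores; argument names follow BERT.__init__ further down, values are illustrative.

# Hypothetical shape check, not part of this commit.
import torch
from argparse import Namespace
from models.bert import BERTModel

args = Namespace(model_code='bert', model_init_seed=0, num_items=1000,
                 bert_max_len=100, bert_num_blocks=2, bert_num_heads=4,
                 bert_hidden_units=256, bert_dropout=0.1)
model = BERTModel(args)
tokens = torch.randint(1, args.num_items + 1, (8, args.bert_max_len))  # 0 is padding
logits = model(tokens)
print(logits.shape)   # expected: torch.Size([8, 100, 1001])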
models/bert_modules/__init__.py ADDED
@@ -0,0 +1 @@
+
models/bert_modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (178 Bytes). View file
 
models/bert_modules/__pycache__/bert.cpython-312.pyc ADDED
Binary file (3.08 kB). View file
 
models/bert_modules/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (2.26 kB). View file
 
models/bert_modules/attention/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .multi_head import MultiHeadedAttention
+ from .single import Attention
models/bert_modules/attention/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (287 Bytes). View file
 
models/bert_modules/attention/__pycache__/multi_head.cpython-312.pyc ADDED
Binary file (2.44 kB). View file
 
models/bert_modules/attention/__pycache__/single.cpython-312.pyc ADDED
Binary file (1.31 kB). View file
 
models/bert_modules/attention/multi_head.py ADDED
@@ -0,0 +1,37 @@
+ import torch.nn as nn
+ from .single import Attention
+
+
+ class MultiHeadedAttention(nn.Module):
+     """
+     Take in model size and number of heads.
+     """
+
+     def __init__(self, h, d_model, dropout=0.1):
+         super().__init__()
+         assert d_model % h == 0
+
+         # We assume d_v always equals d_k
+         self.d_k = d_model // h
+         self.h = h
+
+         self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
+         self.output_linear = nn.Linear(d_model, d_model)
+         self.attention = Attention()
+
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, query, key, value, mask=None):
+         batch_size = query.size(0)
+
+         # 1) Do all the linear projections in batch from d_model => h x d_k
+         query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
+                              for l, x in zip(self.linear_layers, (query, key, value))]
+
+         # 2) Apply attention on all the projected vectors in batch.
+         x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
+
+         # 3) "Concat" using a view and apply a final linear.
+         x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
+
+         return self.output_linear(x)
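A small, hedged shape sketch for the module above, assuming it is run inside the package so the relative import of Attention resolves: the projections split d_model into h heads of size d_k, and the output keeps the input shape.

# Hypothetical shape check, not part of this commit.
import torch
from models.bert_modules.attention import MultiHeadedAttention

mha = MultiHeadedAttention(h=8, d_model=64)
x = torch.randn(2, 10, 64)                          # (batch, seq_len, d_model)
mask = torch.ones(2, 1, 10, 10, dtype=torch.bool)   # same layout BERT builds below
out = mha(x, x, x, mask=mask)
print(out.shape)   # torch.Size([2, 10, 64]); internally 8 heads of d_k = 8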
models/bert_modules/attention/single.py ADDED
@@ -0,0 +1,25 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch
+
+ import math
+
+
+ class Attention(nn.Module):
+     """
+     Compute 'Scaled Dot Product Attention'
+     """
+
+     def forward(self, query, key, value, mask=None, dropout=None):
+         scores = torch.matmul(query, key.transpose(-2, -1)) \
+                  / math.sqrt(query.size(-1))
+
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e9)
+
+         p_attn = F.softmax(scores, dim=-1)
+
+         if dropout is not None:
+             p_attn = dropout(p_attn)
+
+         return torch.matmul(p_attn, value), p_attn
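To make the masking convention above concrete, a standalone sketch (plain torch, no repo imports): positions where the mask is 0 are pushed to -1e9 before the softmax, so they receive near-zero attention weight.

# Standalone illustration of the mask convention, not part of this commit.
import math
import torch
import torch.nn.functional as F

q = k = torch.randn(1, 4, 8)                     # (batch, seq_len, d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
mask = torch.tensor([[1, 1, 1, 0]])              # last position is padding
scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
print(p_attn[0, :, -1])                          # ~0 for every query position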
models/bert_modules/bert.py ADDED
@@ -0,0 +1,45 @@
+ from torch import nn as nn
+
+ from models.bert_modules.embedding import BERTEmbedding
+ from models.bert_modules.transformer import TransformerBlock
+ from utils import fix_random_seed_as
+
+ from pathlib import Path
+ import pickle
+ class BERT(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+
+         fix_random_seed_as(args.model_init_seed)
+         # self.init_weights()
+
+         max_len = args.bert_max_len
+         num_items = args.num_items
+         n_layers = args.bert_num_blocks
+         heads = args.bert_num_heads
+         vocab_size = num_items + 2
+         hidden = args.bert_hidden_units
+         self.hidden = hidden
+         dropout = args.bert_dropout
+
+         # embedding for BERT, sum of positional, segment, token embeddings
+         self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden, max_len=max_len, dropout=dropout)
+
+         # multi-layers transformer blocks, deep network
+         self.transformer_blocks = nn.ModuleList(
+             [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])
+
+     def forward(self, x):
+         mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
+
+         # embedding the indexed sequence to sequence of vectors
+         x = self.embedding(x)
+
+         # running over multiple transformer blocks
+         for transformer in self.transformer_blocks:
+             x = transformer.forward(x, mask)
+
+         return x
+
+     def init_weights(self):
+         pass
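A brief sketch of the padding mask built in forward() above: (x > 0) marks real tokens (0 is padding), and the repeat/unsqueeze reshapes it to (batch, 1, seq_len, seq_len) so it broadcasts over the attention heads.

# Standalone illustration of the mask shape, not part of this commit.
import torch

x = torch.tensor([[5, 3, 9, 0, 0]])                            # 0 = padding
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
print(mask.shape)       # torch.Size([1, 1, 5, 5])
print(mask[0, 0])       # columns for the two padded positions are False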
models/bert_modules/embedding/__init__.py ADDED
@@ -0,0 +1 @@
+ from .bert import BERTEmbedding
models/bert_modules/embedding/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (233 Bytes). View file
 
models/bert_modules/embedding/__pycache__/bert.cpython-312.pyc ADDED
Binary file (29.8 kB). View file
 
models/bert_modules/embedding/__pycache__/position.cpython-312.pyc ADDED
Binary file (1.19 kB). View file
 
models/bert_modules/embedding/__pycache__/token.cpython-312.pyc ADDED
Binary file (729 Bytes). View file
 
models/bert_modules/embedding/bert.py ADDED
@@ -0,0 +1,304 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import pickle
+ import json
+ import threading
+ from pathlib import Path
+ import torch.nn as nn
+ from .token import TokenEmbedding
+ from .position import PositionalEmbedding
+ import time
+ from pathlib import Path
+ from torch.nn.utils.rnn import pad_sequence
+ import pickle
+ import json
+ import os
+
+ class BERTEmbedding(nn.Module):
+     _mappings_cache = None
+     _cache_lock = threading.Lock()
+
+     @classmethod
+     def _load_mappings(cls):
+         if cls._mappings_cache is None:
+             with cls._cache_lock:
+                 if cls._mappings_cache is None:  # Double-checked locking
+                     try:
+
+                         main_dir = os.getcwd()
+
+                         relative_path_dataset = "Data/preprocessed/AnimeRatings_min_rating7-min_uc10-min_sc10-splitleave_one_out/dataset.pkl"
+                         relative_path_genres = "Data/AnimeRatings/id_to_genreids.json"
+
+                         full_path_dataset = Path(main_dir) / relative_path_dataset
+                         full_path_genres = Path(main_dir) / relative_path_genres
+
+
+                         with full_path_dataset.open('rb') as f:
+                             dataset_smap = pickle.load(f)["smap"]
+
+                         with full_path_genres.open('rb') as f:
+                             id_to_genres = json.load(f)
+
+                         cls._mappings_cache = {
+                             'dataset_smap': dataset_smap,
+                             'id_to_genres': id_to_genres
+                         }
+
+                     except Exception as e:
+                         print(f"Warning: Could not load mappings: {e}")
+                         cls._mappings_cache = {
+                             'dataset_smap': {},
+                             'id_to_genres': {}
+                         }
+         return cls._mappings_cache
+
+     def __init__(self, vocab_size, embed_size, max_len, dropout=0.1, multi_genre=True, max_genres_per_anime=5):
+         super().__init__()
+
+         mappings = self._load_mappings()
+         dataset_smap = mappings['dataset_smap']
+         id_to_genres = mappings['id_to_genres']
+
+         self.multi_genre = multi_genre
+         self.max_genres_per_anime = max_genres_per_anime
+
+         all_genres = set()
+         for anime_id, genres in id_to_genres.items():
+             all_genres.update(genres)
+
+         max_genre_id = max(all_genres) if all_genres else 0
+         self.num_genres = max_genre_id + 1
+
+         print(f"Detected {self.num_genres} unique genres (max_id: {max_genre_id})")
+
+         self.vocab_size = vocab_size
+
+         if multi_genre:
+             self._create_multi_genre_mapping(dataset_smap, id_to_genres, vocab_size)
+         else:
+             self._create_single_genre_mapping(dataset_smap, id_to_genres, vocab_size)
+
+         self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
+         self.genre_embed = nn.Embedding(num_embeddings=self.num_genres, embedding_dim=embed_size, padding_idx=0)
+
+         if multi_genre:
+             self.fusion_layer = nn.Sequential(
+                 nn.Linear(embed_size * 2, embed_size),
+                 nn.LayerNorm(embed_size),
+                 nn.ReLU()
+             )
+
+             self.genre_aggregation = nn.Parameter(torch.ones(max_genres_per_anime) / max_genres_per_anime)
+             self.genre_attention = nn.MultiheadAttention(embed_size, num_heads=4, batch_first=True)
+         else:
+             self.fusion_layer = nn.Sequential(
+                 nn.Linear(embed_size * 2, embed_size),
+                 nn.LayerNorm(embed_size),
+                 nn.ReLU()
+             )
+
+         self.dropout = nn.Dropout(p=dropout)
+         self.embed_size = embed_size
+
+         self._genre_cache = {}
+         self._cache_lock = threading.Lock()
+
+     def _create_single_genre_mapping(self, dataset_smap, id_to_genres, vocab_size):
+         token_to_genre = {}
+         for anime_id, token_id in dataset_smap.items():
+             if token_id < vocab_size:
+                 genre_list = id_to_genres.get(str(anime_id), [0])
+                 genre_id = genre_list[0] if genre_list else 0
+                 if genre_id >= self.num_genres:
+                     print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0")
+                     genre_id = 0
+                 token_to_genre[token_id] = genre_id
+
+         if token_to_genre:
+             token_ids = torch.tensor(list(token_to_genre.keys()), dtype=torch.long)
+             genre_ids = torch.tensor(list(token_to_genre.values()), dtype=torch.long)
+
+             self.register_buffer('token_ids', token_ids)
+             self.register_buffer('genre_ids', genre_ids)
+             self.has_mappings = True
+         else:
+             self.register_buffer('token_ids', torch.empty(0, dtype=torch.long))
+             self.register_buffer('genre_ids', torch.empty(0, dtype=torch.long))
+             self.has_mappings = False
+
+     def _create_multi_genre_mapping(self, dataset_smap, id_to_genres, vocab_size):
+         token_to_genres = {}
+         for anime_id, token_id in dataset_smap.items():
+             if token_id < vocab_size:
+                 genre_list = id_to_genres.get(str(anime_id), [0])
+
+                 valid_genres = []
+                 for genre_id in genre_list:
+                     if genre_id >= self.num_genres:
+                         print(f"Warning: Genre ID {genre_id} >= {self.num_genres}, setting to 0")
+                         genre_id = 0
+                     valid_genres.append(genre_id)
+
+                 if len(valid_genres) < self.max_genres_per_anime:
+                     valid_genres.extend([0] * (self.max_genres_per_anime - len(valid_genres)))
+                 else:
+                     valid_genres = valid_genres[:self.max_genres_per_anime]
+
+                 token_to_genres[token_id] = valid_genres
+
+         if token_to_genres:
+             token_ids = torch.tensor(list(token_to_genres.keys()), dtype=torch.long)
+             genre_ids = torch.tensor(list(token_to_genres.values()), dtype=torch.long)
+
+             self.register_buffer('token_ids', token_ids)
+             self.register_buffer('genre_ids', genre_ids)
+             self.has_mappings = True
+         else:
+             self.register_buffer('token_ids', torch.empty(0, dtype=torch.long))
+             self.register_buffer('genre_ids', torch.empty(0, self.max_genres_per_anime, dtype=torch.long))
+             self.has_mappings = False
+
+     def _get_single_genre_mapping(self, sequence):
+         """Original single genre mapping with improved bounds checking"""
+         batch_size, seq_len = sequence.shape
+         device = sequence.device
+
+         if not self.has_mappings:
+             return torch.zeros_like(sequence)
+
+         sequence = torch.clamp(sequence, 0, self.vocab_size - 1)
+
+         genre_sequence = torch.zeros_like(sequence)
+         flat_sequence = sequence.flatten()
+         flat_genre = torch.zeros_like(flat_sequence)
+
+         token_mask = torch.isin(flat_sequence, self.token_ids)
+
+         if token_mask.any():
+             valid_tokens = flat_sequence[token_mask]
+
+             with self._cache_lock:
+                 cache_key = (device, len(self.token_ids))
+                 if cache_key not in self._genre_cache:
+                     sorted_indices = torch.argsort(self.token_ids)
+                     self._genre_cache[cache_key] = {
+                         'sorted_tokens': self.token_ids[sorted_indices],
+                         'sorted_genres': self.genre_ids[sorted_indices]
+                     }
+
+                 cached_data = self._genre_cache[cache_key]
+
+             indices = torch.searchsorted(cached_data['sorted_tokens'], valid_tokens)
+             indices = torch.clamp(indices, 0, len(cached_data['sorted_tokens']) - 1)
+             exact_matches = cached_data['sorted_tokens'][indices] == valid_tokens
+
+             genre_values = torch.where(
+                 exact_matches,
+                 cached_data['sorted_genres'][indices],
+                 torch.tensor(0, device=device, dtype=self.genre_ids.dtype)
+             )
+
+             flat_genre[token_mask] = genre_values
+
+         return flat_genre.view(batch_size, seq_len)
+
+     def _get_multi_genre_mapping(self, sequence):
+         """Get multiple genres for each anime in sequence with bounds checking"""
+         batch_size, seq_len = sequence.shape
+         device = sequence.device
+
+         if not self.has_mappings:
+             return torch.zeros(batch_size, seq_len, self.max_genres_per_anime, device=device, dtype=torch.long)
+
+         sequence = torch.clamp(sequence, 0, self.vocab_size - 1)
+
+         genre_sequences = torch.zeros(batch_size, seq_len, self.max_genres_per_anime, device=device, dtype=torch.long)
+
+         flat_sequence = sequence.flatten()
+         flat_genres = torch.zeros(len(flat_sequence), self.max_genres_per_anime, device=device, dtype=torch.long)
+
+         token_mask = torch.isin(flat_sequence, self.token_ids)
+
+         if token_mask.any():
+             valid_tokens = flat_sequence[token_mask]
+
+             with self._cache_lock:
+                 cache_key = (device, len(self.token_ids), 'multi')
+                 if cache_key not in self._genre_cache:
+                     sorted_indices = torch.argsort(self.token_ids)
+                     self._genre_cache[cache_key] = {
+                         'sorted_tokens': self.token_ids[sorted_indices],
+                         'sorted_genres': self.genre_ids[sorted_indices]  # Shape: (num_tokens, max_genres_per_anime)
+                     }
+
+                 cached_data = self._genre_cache[cache_key]
+
+             indices = torch.searchsorted(cached_data['sorted_tokens'], valid_tokens)
+             indices = torch.clamp(indices, 0, len(cached_data['sorted_tokens']) - 1)
+             exact_matches = cached_data['sorted_tokens'][indices] == valid_tokens
+
+             genre_values = cached_data['sorted_genres'][indices]  # Shape: (num_valid_tokens, max_genres_per_anime)
+
+             valid_mask = token_mask.nonzero(as_tuple=True)[0]
+             exact_valid_mask = valid_mask[exact_matches]
+
+             flat_genres[exact_valid_mask] = genre_values[exact_matches]
+
+         return flat_genres.view(batch_size, seq_len, self.max_genres_per_anime)
+
+     def _aggregate_genre_embeddings(self, genre_embeddings):
+         """Aggregate multiple genre embeddings per anime"""
+         # genre_embeddings shape: (batch_size, seq_len, max_genres_per_anime, embed_size)
+         batch_size, seq_len, max_genres, embed_size = genre_embeddings.shape
+
+         weights = F.softmax(self.genre_aggregation, dim=0)
+         weighted_genres = torch.einsum('bsgd,g->bsd', genre_embeddings, weights)
+
+         return weighted_genres
+
+     def forward(self, sequence):
+         """
+         Enhanced forward pass with per-anime genre processing
+         """
+         if sequence.max() >= self.vocab_size:
+             print(f"Warning: Input contains tokens >= vocab_size ({self.vocab_size})")
+
+         sequence = torch.clamp(sequence, 0, self.vocab_size - 1)
+
+         token_emb = self.token(sequence)
+
+         if self.multi_genre:
+             genre_sequences = self._get_multi_genre_mapping(sequence)  # (batch, seq, max_genres)
+
+             genre_sequences = torch.clamp(genre_sequences, 0, self.num_genres - 1)
+
+             genre_embeddings = self.genre_embed(genre_sequences)  # (batch, seq, max_genres, embed_size)
+
+             aggregated_genre_emb = self._aggregate_genre_embeddings(genre_embeddings)  # (batch, seq, embed_size)
+
+             combined = torch.cat([token_emb, aggregated_genre_emb], dim=-1)
+         else:
+             genre_sequence = self._get_single_genre_mapping(sequence)
+
+             genre_sequence = torch.clamp(genre_sequence, 0, self.num_genres - 1)
+
+             genre_emb = self.genre_embed(genre_sequence)
+             combined = torch.cat([token_emb, genre_emb], dim=-1)
+
+         x = self.fusion_layer(combined)
+
+         return self.dropout(x)
+
+
+     def clear_cache(self):
+         """Clear internal caches to free GPU memory"""
+         with self._cache_lock:
+             self._genre_cache.clear()
+
+     @classmethod
+     def clear_global_cache(cls):
+         """Clear global mappings cache"""
+         with cls._cache_lock:
+             cls._mappings_cache = None
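A standalone sketch of the learned genre pooling used in _aggregate_genre_embeddings above: softmax over a per-slot weight vector, then an einsum that collapses the max_genres_per_anime axis into one genre vector per position. Shapes are illustrative only.

# Standalone illustration of the weighted genre pooling, not part of this commit.
import torch
import torch.nn.functional as F

batch, seq_len, max_genres, embed_size = 2, 7, 5, 16
genre_embeddings = torch.randn(batch, seq_len, max_genres, embed_size)
genre_aggregation = torch.ones(max_genres) / max_genres     # the nn.Parameter above

weights = F.softmax(genre_aggregation, dim=0)               # (max_genres,)
pooled = torch.einsum('bsgd,g->bsd', genre_embeddings, weights)
print(pooled.shape)   # torch.Size([2, 7, 16]); ready to concat with token_emb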
models/bert_modules/embedding/position.py ADDED
@@ -0,0 +1,16 @@
+ import torch.nn as nn
+ import torch
+ import math
+
+
+ class PositionalEmbedding(nn.Module):
+
+     def __init__(self, max_len, d_model):
+         super().__init__()
+
+         # Learned positional embeddings (one trainable vector per position).
+         self.pe = nn.Embedding(max_len, d_model)
+
+     def forward(self, x):
+         batch_size = x.size(0)
+         return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)
models/bert_modules/embedding/segment.py ADDED
@@ -0,0 +1,6 @@
+ import torch.nn as nn
+
+
+ class SegmentEmbedding(nn.Embedding):
+     def __init__(self, embed_size=512):
+         super().__init__(3, embed_size, padding_idx=0)
models/bert_modules/embedding/token.py ADDED
@@ -0,0 +1,6 @@
+ import torch.nn as nn
+
+
+ class TokenEmbedding(nn.Embedding):
+     def __init__(self, vocab_size, embed_size=512):
+         super().__init__(vocab_size, embed_size, padding_idx=0)
models/bert_modules/transformer.py ADDED
@@ -0,0 +1,31 @@
+ import torch.nn as nn
+
+ from .attention import MultiHeadedAttention
+ from .utils import SublayerConnection, PositionwiseFeedForward
+
+
+ class TransformerBlock(nn.Module):
+     """
+     Bidirectional Encoder = Transformer (self-attention)
+     Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
+     """
+
+     def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
+         """
+         :param hidden: hidden size of transformer
+         :param attn_heads: number of attention heads
+         :param feed_forward_hidden: feed-forward hidden size, usually 4*hidden
+         :param dropout: dropout rate
+         """
+
+         super().__init__()
+         self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
+         self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
+         self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
+         self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x, mask):
+         x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
+         x = self.output_sublayer(x, self.feed_forward)
+         return self.dropout(x)
models/bert_modules/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .feed_forward import PositionwiseFeedForward
+ from .layer_norm import LayerNorm
+ from .sublayer import SublayerConnection
+ from .gelu import GELU
models/bert_modules/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (378 Bytes). View file
 
models/bert_modules/utils/__pycache__/feed_forward.cpython-312.pyc ADDED
Binary file (1.43 kB). View file
 
models/bert_modules/utils/__pycache__/gelu.cpython-312.pyc ADDED
Binary file (1 kB). View file
 
models/bert_modules/utils/__pycache__/layer_norm.cpython-312.pyc ADDED
Binary file (1.49 kB). View file
 
models/bert_modules/utils/__pycache__/sublayer.cpython-312.pyc ADDED
Binary file (1.34 kB). View file
 
models/bert_modules/utils/feed_forward.py ADDED
@@ -0,0 +1,16 @@
+ import torch.nn as nn
+ from .gelu import GELU
+
+
+ class PositionwiseFeedForward(nn.Module):
+     "Implements FFN equation."
+
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super(PositionwiseFeedForward, self).__init__()
+         self.w_1 = nn.Linear(d_model, d_ff)
+         self.w_2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = GELU()
+
+     def forward(self, x):
+         return self.w_2(self.dropout(self.activation(self.w_1(x))))
models/bert_modules/utils/gelu.py ADDED
@@ -0,0 +1,12 @@
+ import torch.nn as nn
+ import torch
+ import math
+
+
+ class GELU(nn.Module):
+     """
+     Paper Section 3.4, last paragraph: note that BERT uses GELU instead of ReLU.
+     """
+
+     def forward(self, x):
+         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
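The formula above is the tanh approximation of GELU; as a sanity check (assuming PyTorch >= 1.12, where F.gelu accepts approximate='tanh'), it should match the built-in closely.

# Hypothetical sanity check, not part of this commit (requires PyTorch >= 1.12).
import math
import torch
import torch.nn.functional as F

x = torch.randn(1000)
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print(torch.allclose(approx, F.gelu(x, approximate='tanh'), atol=1e-6))  # True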
models/bert_modules/utils/layer_norm.py ADDED
@@ -0,0 +1,17 @@
+ import torch.nn as nn
+ import torch
+
+
+ class LayerNorm(nn.Module):
+     "Construct a layernorm module (See citation for details)."
+
+     def __init__(self, features, eps=1e-6):
+         super(LayerNorm, self).__init__()
+         self.a_2 = nn.Parameter(torch.ones(features))
+         self.b_2 = nn.Parameter(torch.zeros(features))
+         self.eps = eps
+
+     def forward(self, x):
+         mean = x.mean(-1, keepdim=True)
+         std = x.std(-1, keepdim=True)
+         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
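Note that this hand-rolled LayerNorm differs slightly from torch.nn.LayerNorm: it uses the unbiased std and adds eps to the std rather than to the variance, so the two agree only approximately. A small, hedged comparison (assumed to be run inside the package):

# Hypothetical comparison, not part of this commit.
import torch
import torch.nn as nn
from models.bert_modules.utils import LayerNorm

x = torch.randn(4, 10, 32)
custom = LayerNorm(32)
builtin = nn.LayerNorm(32, eps=1e-6)
diff = (custom(x) - builtin(x)).abs().max()
print(diff)   # small but nonzero (unbiased vs. biased std, eps placement)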
models/bert_modules/utils/sublayer.py ADDED
@@ -0,0 +1,18 @@
+ import torch.nn as nn
+ from .layer_norm import LayerNorm
+
+
+ class SublayerConnection(nn.Module):
+     """
+     A residual connection followed by a layer norm.
+     Note for code simplicity the norm is first as opposed to last.
+     """
+
+     def __init__(self, size, dropout):
+         super(SublayerConnection, self).__init__()
+         self.norm = LayerNorm(size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, sublayer):
+         "Apply residual connection to any sublayer with the same size."
+         return x + self.dropout(sublayer(self.norm(x)))
models/dae.py ADDED
@@ -0,0 +1,54 @@
+ from .base import BaseModel
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class DAEModel(BaseModel):
+     def __init__(self, args):
+         super().__init__(args)
+
+         # Input dropout
+         self.input_dropout = nn.Dropout(p=args.dae_dropout)
+
+         # Construct a list of dimensions for the encoder and the decoder
+         dims = [args.dae_hidden_dim] * 2 * args.dae_num_hidden
+         dims = [args.num_items] + dims + [args.dae_latent_dim]
+
+         # Stack encoders and decoders
+         encoder_modules, decoder_modules = [], []
+         for i in range(len(dims)//2):
+             encoder_modules.append(nn.Linear(dims[2*i], dims[2*i+1]))
+             decoder_modules.append(nn.Linear(dims[-2*i-1], dims[-2*i-2]))
+         self.encoder = nn.ModuleList(encoder_modules)
+         self.decoder = nn.ModuleList(decoder_modules)
+
+         # Initialize weights
+         self.encoder.apply(self.weight_init)
+         self.decoder.apply(self.weight_init)
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.kaiming_normal_(m.weight)
+             m.bias.data.normal_(0.0, 0.001)
+
+     @classmethod
+     def code(cls):
+         return 'dae'
+
+     def forward(self, x):
+         x = F.normalize(x)
+         x = self.input_dropout(x)
+
+         for i, layer in enumerate(self.encoder):
+             x = layer(x)
+             x = torch.tanh(x)
+
+         for i, layer in enumerate(self.decoder):
+             x = layer(x)
+             if i != len(self.decoder)-1:
+                 x = torch.tanh(x)
+
+         return x
+
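A hedged usage sketch for the autoencoder above (assumed to be run inside the repo): the input is a per-user multi-hot interaction vector of length num_items, L2-normalized and dropped out before encoding, and the output has the same size. Argument names mirror __init__; values are illustrative.

# Hypothetical usage sketch, not part of this commit.
import torch
from argparse import Namespace
from models.dae import DAEModel

args = Namespace(num_items=1000, dae_dropout=0.5,
                 dae_hidden_dim=600, dae_num_hidden=1, dae_latent_dim=200)
model = DAEModel(args)

ratings = torch.zeros(4, args.num_items)
ratings[:, :20] = 1.0                     # toy multi-hot interaction vectors
recon = model(ratings)
print(recon.shape)   # torch.Size([4, 1000])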
models/vae.py ADDED
@@ -0,0 +1,69 @@
+ from .base import BaseModel
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class VAEModel(BaseModel):
+     def __init__(self, args):
+         super().__init__(args)
+         self.latent_dim = args.vae_latent_dim
+
+         # Input dropout
+         self.input_dropout = nn.Dropout(p=args.vae_dropout)
+
+         # Construct a list of dimensions for the encoder and the decoder
+         dims = [args.vae_hidden_dim] * 2 * args.vae_num_hidden
+         dims = [args.num_items] + dims + [args.vae_latent_dim * 2]
+
+         # Stack encoders and decoders
+         encoder_modules, decoder_modules = [], []
+         for i in range(len(dims)//2):
+             encoder_modules.append(nn.Linear(dims[2*i], dims[2*i+1]))
+             if i == 0:
+                 decoder_modules.append(nn.Linear(dims[-1]//2, dims[-2]))
+             else:
+                 decoder_modules.append(nn.Linear(dims[-2*i-1], dims[-2*i-2]))
+         self.encoder = nn.ModuleList(encoder_modules)
+         self.decoder = nn.ModuleList(decoder_modules)
+
+         # Initialize weights
+         self.encoder.apply(self.weight_init)
+         self.decoder.apply(self.weight_init)
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.kaiming_normal_(m.weight)
+             m.bias.data.zero_()
+
+     @classmethod
+     def code(cls):
+         return 'vae'
+
+     def forward(self, x):
+         x = F.normalize(x)
+         x = self.input_dropout(x)
+
+         for i, layer in enumerate(self.encoder):
+             x = layer(x)
+             if i != len(self.encoder) - 1:
+                 x = torch.tanh(x)
+
+         mu, logvar = x[:, :self.latent_dim], x[:, self.latent_dim:]
+
+         if self.training:
+             # since log(var) = log(sigma^2) = 2*log(sigma)
+             sigma = torch.exp(0.5 * logvar)
+             eps = torch.randn_like(sigma)
+             x = mu + eps * sigma
+         else:
+             x = mu
+
+         for i, layer in enumerate(self.decoder):
+             x = layer(x)
+             if i != len(self.decoder) - 1:
+                 x = torch.tanh(x)
+
+         return x, mu, logvar
+
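For context, the forward above returns (reconstruction, mu, logvar) and applies the reparameterization trick only in training mode. A typical Mult-VAE-style objective (not part of this commit) combines a softmax reconstruction term with the closed-form Gaussian KL; the function below is a sketch with an assumed beta weight.

# Hypothetical loss sketch, not part of this commit.
import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, logvar, beta=0.2):
    # multinomial log-likelihood over items the user actually interacted with
    recon_loss = -(F.log_softmax(recon, dim=-1) * x).sum(dim=-1).mean()
    # closed-form KL( N(mu, sigma^2) || N(0, I) )
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1).mean()
    return recon_loss + beta * kl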