Create inference.py
Browse files- inference.py +156 -0
inference.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Data
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
|
6 |
+
class GNN(torch.nn.Module):
|
7 |
+
"""
|
8 |
+
Overall graph neural network. Consists of learnable user/item (i.e., playlist/song) embeddings
|
9 |
+
and LightGCN layers.
|
10 |
+
"""
|
11 |
+
def __init__(self, embedding_dim, num_nodes, num_playlists, num_layers):
|
12 |
+
super(GNN, self).__init__()
|
13 |
+
|
14 |
+
self.embedding_dim = embedding_dim
|
15 |
+
self.num_nodes = num_nodes # total number of nodes (songs + playlists) in dataset
|
16 |
+
self.num_playlists = num_playlists # total number of playlists in dataset
|
17 |
+
self.num_layers = num_layers
|
18 |
+
|
19 |
+
# Initialize embeddings for all playlists and songs. Playlists will have indices from 0...num_playlists-1,
|
20 |
+
# songs will have indices from num_playlists...num_nodes-1
|
21 |
+
self.embeddings = torch.nn.Embedding(num_embeddings=self.num_nodes, embedding_dim=self.embedding_dim)
|
22 |
+
torch.nn.init.normal_(self.embeddings.weight, std=0.1)
|
23 |
+
|
24 |
+
self.layers = torch.nn.ModuleList() # LightGCN layers
|
25 |
+
for _ in range(self.num_layers):
|
26 |
+
self.layers.append(LightGCN())
|
27 |
+
|
28 |
+
self.sigmoid = torch.sigmoid
|
29 |
+
|
30 |
+
def forward(self):
|
31 |
+
raise NotImplementedError("forward() has not been implemented for the GNN class. Do not use")
|
32 |
+
|
33 |
+
def gnn_propagation(self, edge_index_mp):
|
34 |
+
"""
|
35 |
+
Performs the linear embedding propagation (using the LightGCN layers) and calculates final (multi-scale) embeddings
|
36 |
+
for each user/item, which are calculated as a weighted sum of that user/item's embeddings at each layer (from
|
37 |
+
0 to self.num_layers). Technically, the weighted sum here is the average, which is what the LightGCN authors recommend.
|
38 |
+
|
39 |
+
args:
|
40 |
+
edge_index_mp: a tensor of all (undirected) edges in the graph, which is used for message passing/propagation and
|
41 |
+
calculating the multi-scale embeddings. (In contrast to the evaluation/supervision edges, which are distinct
|
42 |
+
from the message passing edges and will be used for calculating loss/performance metrics).
|
43 |
+
returns:
|
44 |
+
final multi-scale embeddings for all users/items
|
45 |
+
"""
|
46 |
+
x = self.embeddings.weight # layer-0 embeddings
|
47 |
+
|
48 |
+
x_at_each_layer = [x] # stores embeddings from each layer. Start with layer-0 embeddings
|
49 |
+
for i in range(self.num_layers): # now performing the GNN propagation
|
50 |
+
x = self.layers[i](x, edge_index_mp)
|
51 |
+
x_at_each_layer.append(x)
|
52 |
+
final_embs = torch.stack(x_at_each_layer, dim=0).mean(dim=0) # take average to calculate multi-scale embeddings
|
53 |
+
return final_embs
|
54 |
+
|
55 |
+
def predict_scores(self, edge_index, embs):
|
56 |
+
"""
|
57 |
+
Calculates predicted scores for each playlist/song pair in the list of edges. Uses dot product of their embeddings.
|
58 |
+
|
59 |
+
args:
|
60 |
+
edge_index: tensor of edges (between playlists and songs) whose scores we will calculate.
|
61 |
+
embs: node embeddings for calculating predicted scores (typically the multi-scale embeddings from gnn_propagation())
|
62 |
+
returns:
|
63 |
+
predicted scores for each playlist/song pair in edge_index
|
64 |
+
"""
|
65 |
+
scores = embs[edge_index[0,:], :] * embs[edge_index[1,:], :] # taking dot product for each playlist/song pair
|
66 |
+
scores = scores.sum(dim=1)
|
67 |
+
scores = self.sigmoid(scores)
|
68 |
+
return scores
|
69 |
+
|
70 |
+
def calc_loss(self, data_mp, data_pos, data_neg):
|
71 |
+
"""
|
72 |
+
The main training step. Performs GNN propagation on message passing edges, to get multi-scale embeddings.
|
73 |
+
Then predicts scores for each training example, and calculates Bayesian Personalized Ranking (BPR) loss.
|
74 |
+
|
75 |
+
args:
|
76 |
+
data_mp: tensor of edges used for message passing / calculating multi-scale embeddings
|
77 |
+
data_pos: set of positive edges that will be used during loss calculation
|
78 |
+
data_neg: set of negative edges that will be used during loss calculation
|
79 |
+
returns:
|
80 |
+
loss calculated on the positive/negative training edges
|
81 |
+
"""
|
82 |
+
# Perform GNN propagation on message passing edges to get final embeddings
|
83 |
+
final_embs = self.gnn_propagation(data_mp.edge_index)
|
84 |
+
|
85 |
+
# Get edge prediction scores for all positive and negative evaluation edges
|
86 |
+
pos_scores = self.predict_scores(data_pos.edge_index, final_embs)
|
87 |
+
neg_scores = self.predict_scores(data_neg.edge_index, final_embs)
|
88 |
+
|
89 |
+
# # Calculate loss (binary cross-entropy). Commenting out, but can use instead of BPR if desired.
|
90 |
+
# all_scores = torch.cat([pos_scores, neg_scores], dim=0)
|
91 |
+
# all_labels = torch.cat([torch.ones(pos_scores.shape[0]), torch.zeros(neg_scores.shape[0])], dim=0)
|
92 |
+
# loss_fn = torch.nn.BCELoss()
|
93 |
+
# loss = loss_fn(all_scores, all_labels)
|
94 |
+
|
95 |
+
# Calculate loss (using variation of Bayesian Personalized Ranking loss, similar to the one used in official
|
96 |
+
# LightGCN implementation at https://github.com/gusye1234/LightGCN-PyTorch/blob/master/code/model.py#L202)
|
97 |
+
loss = -torch.log(self.sigmoid(pos_scores - neg_scores)).mean()
|
98 |
+
return loss
|
99 |
+
|
100 |
+
def evaluation(self, data_mp, data_pos, k):
|
101 |
+
"""
|
102 |
+
Performs evaluation on validation or test set. Calculates recall@k.
|
103 |
+
|
104 |
+
args:
|
105 |
+
data_mp: message passing edges to use for propagation/calculating multi-scale embeddings
|
106 |
+
data_pos: positive edges to use for scoring metrics. Should be no overlap between these edges and data_mp's edges
|
107 |
+
k: value of k to use for recall@k
|
108 |
+
returns:
|
109 |
+
dictionary mapping playlist ID -> recall@k on that playlist
|
110 |
+
"""
|
111 |
+
# Run propagation on the message-passing edges to get multi-scale embeddings
|
112 |
+
final_embs = self.gnn_propagation(data_mp.edge_index)
|
113 |
+
|
114 |
+
# Get embeddings of all unique playlists in the batch of evaluation edges
|
115 |
+
unique_playlists = torch.unique_consecutive(data_pos.edge_index[0,:])
|
116 |
+
playlist_emb = final_embs[unique_playlists, :] # has shape [number of playlists in batch, 64]
|
117 |
+
|
118 |
+
# Get embeddings of ALL songs in dataset
|
119 |
+
song_emb = final_embs[self.num_playlists:, :] # has shape [total number of songs in dataset, 64]
|
120 |
+
|
121 |
+
# All ratings for each playlist in batch to each song in entire dataset (using dot product as the scoring function)
|
122 |
+
ratings = self.sigmoid(torch.matmul(playlist_emb, song_emb.t())) # shape: [# playlists in batch, # songs in dataset]
|
123 |
+
# where entry i,j is rating of song j for playlist i
|
124 |
+
# Calculate recall@k
|
125 |
+
result = recall_at_k(ratings.cpu(), k, self.num_playlists, data_pos.edge_index.cpu(),
|
126 |
+
unique_playlists.cpu(), data_mp.edge_index.cpu())
|
127 |
+
return result
|
128 |
+
|
129 |
+
|
130 |
+
# Carga el modelo previamente entrenado
|
131 |
+
data = torch.load(os.path.join(base_dir, "data_object.pt"))
|
132 |
+
with open(os.path.join(base_dir, "dataset_stats.json"), 'r') as f:
|
133 |
+
stats = json.load(f)
|
134 |
+
num_playlists, num_nodes = stats["num_playlists"], stats["num_nodes"]
|
135 |
+
model = GNN(embedding_dim=64, num_nodes=data.num_nodes, num_playlists=num_playlists, num_layers=3)
|
136 |
+
model.load_state_dict(torch.load("pesos_modelo.pth")) # Reemplaza "pesos_modelo.pth" con el nombre de tu archivo de pesos
|
137 |
+
|
138 |
+
# Define la función de inferencia
|
139 |
+
def predict(edge_index):
|
140 |
+
# Convierte la entrada en un objeto PyG Data
|
141 |
+
data = Data(edge_index=edge_index)
|
142 |
+
|
143 |
+
# Realiza la inferencia con el modelo
|
144 |
+
model.eval()
|
145 |
+
with torch.no_grad():
|
146 |
+
output = model.gnn_propagation(data.edge_index)
|
147 |
+
|
148 |
+
# Aquí puedes realizar cualquier postprocesamiento necesario de las predicciones
|
149 |
+
return output
|
150 |
+
|
151 |
+
# Ejemplo de uso
|
152 |
+
if __name__ == "__main__":
|
153 |
+
# Aquí puedes realizar pruebas con datos de ejemplo
|
154 |
+
edge_index = np.array([[0, 1, 2], [1, 2, 0]]) # Ejemplo de datos de entrada (lista de aristas)
|
155 |
+
predictions = predict(edge_index)
|
156 |
+
print(predictions)
|