# Hugging Face Space app (Space status when this page was captured: Sleeping).
import math

import gradio as gr
import numpy as np
import torch

from utils.data_processing import get_data
from modules.tgn import TGN
# Load the 'wikipedia' dataset. Every value get_data returns is bound at module
# level; node_features and edge_features are the ones fed to the TGN model.
(
    node_features,
    edge_features,
    full_data,
    train_data,
    val_data,
    test_data,
    new_node_val_data,
    new_node_test_data,
) = get_data('wikipedia')
# Initialize model (must match training args): every hyperparameter below has to
# agree with the run that produced the checkpoint, otherwise load_state_dict
# (strict by default) raises on missing/unexpected keys or shape mismatches.
# NOTE(review): the upstream twitter-research TGN constructor uses
# `neighbor_finder`, `memory_dimension`, `message_dimension` and
# `memory_update_at_start` instead of some of the names below — confirm these
# keyword arguments match modules.tgn.TGN.
DEVICE = "cpu"
CHECKPOINT_PATH = "tgn-attn-wikipedia.pth"

tgn = TGN(
    raw_message_dim=edge_features.shape[1],
    memory_dim=172,
    message_dim=100,
    memory_update_at_end=False,
    embedding_module_type='graph_attention',
    message_function='identity',
    memory_updater_type='gru',
    n_heads=2,
    dropout=0.1,
    n_layers=1,
    n_neighbors=10,
    aggregator_type='last',
    edge_features=edge_features,
    node_features=node_features,
    device=DEVICE,
    use_memory=True,
)
# Load weights onto the same device the model was built for, so the two
# settings cannot drift apart.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13).
tgn.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
# Inference mode: disables the dropout configured above.
tgn.eval()
def predict(u_id, i_id, timestamp):
    """Score one candidate interaction with the loaded TGN model.

    Args:
        u_id: Source node ID as delivered by the Gradio ``Number`` input
            (int or float, or ``None`` when the field is left empty).
        i_id: Destination node ID, same conventions as ``u_id``.
        timestamp: Interaction time; converted with ``float``.

    Returns:
        ``"Predicted interaction probability: X.XXXX"`` on success, otherwise a
        short message explaining why the input was rejected.
    """
    # Empty Gradio Number fields arrive as None; int()/float() would raise an
    # opaque TypeError, so report it to the user instead.
    if u_id is None or i_id is None or timestamp is None:
        return "Please provide a source node ID, destination node ID and timestamp."
    # int() silently truncates (3.7 -> 3), which would score a different node
    # than the one requested, so require whole-number IDs.
    if not (float(u_id).is_integer() and float(i_id).is_integer()):
        return "Node IDs must be whole numbers."
    ts_value = float(timestamp)
    if not math.isfinite(ts_value):
        return "Timestamp must be a finite number."
    src, dst = int(u_id), int(i_id)
    # Assumes node IDs are row indices into node_features (the table passed to
    # TGN) — confirm against modules.tgn. Negative indices would otherwise wrap
    # around silently and too-large ones would raise IndexError.
    n_nodes = node_features.shape[0]
    if not (0 <= src < n_nodes and 0 <= dst < n_nodes):
        return f"Node IDs must be between 0 and {n_nodes - 1}."
    u = np.array([src])
    i = np.array([dst])
    ts = np.array([ts_value])
    # Inference only: avoid building an autograd graph on every request.
    # NOTE(review): with use_memory=True, confirm this call does not persist
    # memory updates across requests.
    with torch.no_grad():
        prob = tgn.predict_edge_probabilities(u, i, ts)
    return f"Predicted interaction probability: {prob[0]:.4f}"
# Gradio UI: the three inputs map positionally onto predict(u_id, i_id, timestamp)
# and the returned string is shown as plain text.
# precision=0 makes Gradio round node-ID inputs to the nearest integer and pass
# them as int, so fractional node IDs never reach predict.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="Source Node ID", precision=0),
        gr.Number(label="Destination Node ID", precision=0),
        gr.Number(label="Timestamp"),
    ],
    outputs="text",
    title="🧠 TGN Playground (Wikipedia)",
    description="Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).",
)
demo.launch()