"""Gradio playground for a Temporal Graph Network (TGN) trained on Wikipedia.

Loads the ``wikipedia`` dataset and the trained checkpoint
``tgn-attn-wikipedia.pth``, then serves a small web form that scores a single
(source node, destination node, timestamp) interaction with the model.
"""
import gradio as gr
import numpy as np
import torch

from modules.tgn import TGN
from utils.data_processing import get_data

# Load data. Only node_features and edge_features are used below; the other
# values are unpacked to match get_data's return signature.
(
    node_features,
    edge_features,
    full_data,
    train_data,
    val_data,
    test_data,
    new_node_val_data,
    new_node_test_data,
) = get_data('wikipedia')

# Initialize model (must match training args): load_state_dict is strict by
# default and raises on missing/unexpected keys or tensor shape mismatches.
# NOTE(review): the upstream twitter-research/tgn TGN constructor uses other
# keyword names (e.g. neighbor_finder, message_dimension, memory_dimension,
# memory_update_at_start) and exposes compute_edge_probabilities rather than
# predict_edge_probabilities -- confirm the local modules.tgn accepts these
# kwargs and provides the method called in predict().
tgn = TGN(
    raw_message_dim=edge_features.shape[1],
    memory_dim=172,
    message_dim=100,
    memory_update_at_end=False,
    embedding_module_type='graph_attention',
    message_function='identity',
    memory_updater_type='gru',
    n_heads=2,
    dropout=0.1,
    n_layers=1,
    n_neighbors=10,
    aggregator_type='last',
    edge_features=edge_features,
    node_features=node_features,
    device='cpu',
    use_memory=True,
)
tgn.load_state_dict(torch.load("tgn-attn-wikipedia.pth", map_location="cpu"))
# eval() switches off dropout; it does not disable autograd (see predict).
tgn.eval()


def predict(u_id, i_id, timestamp):
    """Score one candidate interaction with the loaded TGN.

    Args:
        u_id: Source node ID from the form, or None if the field is empty.
        i_id: Destination node ID from the form, or None if the field is empty.
        timestamp: Interaction time from the form, or None if the field is empty.

    Returns:
        A human-readable string containing the predicted probability, or a
        prompt asking the user to fill in the missing fields.
    """
    # An empty gr.Number field can arrive as None; int(None)/float(None)
    # would raise TypeError and show only a generic error in the UI.
    if u_id is None or i_id is None or timestamp is None:
        return "Please provide a source node ID, destination node ID, and timestamp."

    # The model is called with batch-of-one arrays.
    u = np.array([int(u_id)])
    i = np.array([int(i_id)])
    ts = np.array([float(timestamp)])

    # Inference only: avoid building an autograd graph for this call.
    with torch.no_grad():
        prob = tgn.predict_edge_probabilities(u, i, ts)

    # float() accepts a Python/NumPy scalar or a one-element tensor/array,
    # whereas formatting a 1-D tensor directly with ":.4f" raises TypeError.
    return f"Predicted interaction probability: {float(prob[0]):.4f}"


demo = gr.Interface(
    fn=predict,
    inputs=[
        # precision=0 makes Gradio round node IDs and pass them as ints.
        gr.Number(label="Source Node ID", precision=0),
        gr.Number(label="Destination Node ID", precision=0),
        gr.Number(label="Timestamp"),
    ],
    outputs="text",
    title="🧠 TGN Playground (Wikipedia)",
    description="Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).",
)

if __name__ == "__main__":
    # Launch only when run as a script, not when the module is imported.
    demo.launch()