# tgn-playground / app.py
# Uploaded by ashu316 ("Upload 3 files", commit 155336d, verified); file size 1.5 kB.
import gradio as gr
import numpy as np
import torch
from utils.data_processing import get_data
from modules.tgn import TGN
# --- Data -------------------------------------------------------------------
# get_data yields the feature matrices followed by the temporal splits
# (full/train/val/test plus the "new node" val/test splits). This script only
# feeds the two feature matrices to the model; the splits are unpacked anyway
# so the expected return arity stays explicit.
(
    node_features,
    edge_features,
    full_data,
    train_data,
    val_data,
    test_data,
    new_node_val_data,
    new_node_test_data,
) = get_data('wikipedia')

# --- Model ------------------------------------------------------------------
# Every hyper-parameter here has to mirror the training run that produced the
# checkpoint: a mismatch either makes load_state_dict fail or silently changes
# predictions.
# NOTE(review): keyword names must match modules.tgn.TGN in this repo — verify.
tgn = TGN(
    # inputs
    node_features=node_features,
    edge_features=edge_features,
    raw_message_dim=edge_features.shape[1],
    device='cpu',
    # memory
    use_memory=True,
    memory_dim=172,
    memory_updater_type='gru',
    memory_update_at_end=False,
    # messages
    message_dim=100,
    message_function='identity',
    aggregator_type='last',
    # temporal embedding
    embedding_module_type='graph_attention',
    n_layers=1,
    n_heads=2,
    n_neighbors=10,
    dropout=0.1,
)

# --- Weights ----------------------------------------------------------------
# map_location="cpu" remaps tensors saved from a GPU run onto the CPU; eval()
# switches the module to inference mode (e.g. disables dropout).
tgn.load_state_dict(
    torch.load("tgn-attn-wikipedia.pth", map_location="cpu")
)
tgn.eval()
def predict(u_id, i_id, timestamp):
    """Predict how likely two nodes are to interact at a given time.

    Gradio callback for the interface defined in this module. The three
    arguments arrive from ``gr.Number`` inputs, so they are numbers (usually
    floats) or ``None`` when a field is left empty.

    Args:
        u_id: Source node ID; must be a non-negative whole number.
        i_id: Destination node ID; must be a non-negative whole number.
        timestamp: Time of the queried interaction. Presumably in the same
            units as the dataset's edge timestamps — TODO confirm.

    Returns:
        ``"Predicted interaction probability: <p>"`` with ``p`` to four
        decimals, or a short explanatory message when the inputs are invalid.
    """
    if u_id is None or i_id is None or timestamp is None:
        return "Please provide a source node ID, a destination node ID and a timestamp."

    # int() would silently truncate 3.7 -> 3, and a negative ID would wrap
    # around under array indexing; reject both instead of predicting for the
    # wrong node. float(...).is_integer() is also False for NaN/inf.
    for label, value in (("Source", u_id), ("Destination", i_id)):
        if not float(value).is_integer() or value < 0:
            return f"{label} node ID must be a non-negative integer, got {value}."

    sources = np.array([int(u_id)])
    destinations = np.array([int(i_id)])
    times = np.array([float(timestamp)])

    # Inference only: skip autograd bookkeeping so no computation graphs are
    # built (or retained in model state) across requests.
    with torch.no_grad():
        prob = tgn.predict_edge_probabilities(sources, destinations, times)

    # The exact return type/shape of predict_edge_probabilities is not visible
    # here (tensor or ndarray, shape (1,) or (1, 1)). Flattening first yields a
    # plain float either way; formatting a non-0-d tensor with ":.4f" would
    # raise TypeError.
    score = float(torch.as_tensor(prob).reshape(-1)[0])
    return f"Predicted interaction probability: {score:.4f}"
# Web UI: the three numeric inputs map positionally onto predict()'s
# (u_id, i_id, timestamp) parameters; the result is shown as plain text.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="Source Node ID"),
        gr.Number(label="Destination Node ID"),
        gr.Number(label="Timestamp"),
    ],
    outputs="text",
    title="🧠 TGN Playground (Wikipedia)",
    description="Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, embedding in another app) does not block.
if __name__ == "__main__":
    demo.launch()