# tgn-playground / app.py
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from utils.data_processing import get_data
from utils.utils import NeighborFinder
from modules.tgn import TGN
import logging
import random
import os
logging.basicConfig(filename="debug.log", level=logging.INFO)
def download_log():
    return "debug.log" if os.path.exists("debug.log") else None
# Load data
node_features, edge_features, full_data, train_data, val_data, test_data, \
    new_node_val_data, new_node_test_data = get_data('wikipedia')
try:
    logging.info(f"Full data dict: {vars(full_data)}")
except TypeError:
    logging.info("vars(full_data) failed. Trying manual extraction...")
    logging.info(f"Sources: {full_data.sources[:10]}")
    logging.info(f"Timestamps: {full_data.timestamps[:10]}")
# Extract edge information
edge_sources = full_data.sources
edge_destinations = full_data.destinations
edge_timestamps = full_data.timestamps
edge_idxs = full_data.edge_idxs
# Construct adjacency list
all_nodes = list(full_data.sources) + list(full_data.destinations)
n_nodes = max(all_nodes) + 1 # assuming node IDs are 0-indexed
adj_list = [[] for _ in range(n_nodes)]
for src, dst, ts, eidx in zip(edge_sources, edge_destinations, edge_timestamps, edge_idxs):
    adj_list[src].append((dst, eidx, ts))
    adj_list[dst].append((src, eidx, ts))  # add the reverse edge, treating the graph as undirected
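
# Optional sanity check: the adjacency list is sized from the largest node ID seen
# in full_data; if this does not line up with the node-feature matrix, the model
# will index out of range later, so log both counts for easy debugging.
logging.info(f"Adjacency list covers {n_nodes} nodes; node feature rows: {node_features.shape[0]}")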
# Initialize the neighbor finder
nf = NeighborFinder(adj_list=adj_list, uniform=False)
# Initialize model (must match training args)
tgn = TGN(
    neighbor_finder=nf,
    node_features=node_features,
    edge_features=edge_features,
    device='cpu',
    n_layers=1,
    n_heads=2,
    dropout=0.1,
    use_memory=True,
    message_dimension=100,
    memory_dimension=172,
    memory_update_at_start=True,  # should match training
    embedding_module_type='graph_attention',
    message_function='identity',
    aggregator_type='last',
    memory_updater_type='gru',
    n_neighbors=10,
    use_destination_embedding_in_message=False,
    use_source_embedding_in_message=False,
    dyrep=False
)
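
# Assumption: these settings mirror the run that produced tgn-attn-wikipedia.pth;
# any mismatch in dimensions or module types will show up as shape or key errors
# when the checkpoint is loaded below.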
# Initialize memory buffers with the correct dimensions
new_num_nodes = 9228 # Current number of nodes
memory_dim = 172 # Your defined memory dimension
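
# The hard-coded node count is assumed to match node_features.shape[0] for the
# Wikipedia dataset; deriving it (new_num_nodes = node_features.shape[0]) would
# avoid drift if the dataset ever changes.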
# Manually reinitialize memory buffers to the correct size as Parameters
tgn.memory.memory = nn.Parameter(torch.randn(new_num_nodes, memory_dim))
tgn.memory.last_update = nn.Parameter(torch.zeros(new_num_nodes))
tgn.memory_updater.memory.memory = nn.Parameter(torch.randn(new_num_nodes, memory_dim))
tgn.memory_updater.memory.last_update = nn.Parameter(torch.zeros(new_num_nodes))
# Load the state_dict from the pretrained model
state_dict = torch.load("tgn-attn-wikipedia.pth", map_location="cpu")
# Remove any conflicting memory-related parameters
state_dict.pop("memory.memory", None)
state_dict.pop("memory.last_update", None)
state_dict.pop("memory_updater.memory.memory", None)
state_dict.pop("memory_updater.memory.last_update", None)
# Now load the state_dict into the model (strict=False to ignore the missing memory buffers)
tgn.load_state_dict(state_dict, strict=False)
tgn.eval()
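
# Optional: log how many tensors remain in the checkpoint after removing the
# memory buffers; with strict=False, key mismatches would otherwise be silent.
logging.info(f"Checkpoint tensors remaining after removing memory buffers: {len(state_dict)}")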
# Earlier single-output version, kept for reference (not wired into the UI)
def predict_old(u_id, i_id, timestamp):
    u = np.array([int(u_id)])
    i = np.array([int(i_id)])
    ts = np.array([float(timestamp)])
    # prob = tgn.predict_edge_probabilities(u, i, ts)
    prob = tgn.compute_edge_probabilities(u, i, ts)
    return f"Predicted interaction probability: {prob[0]:.4f}"
def get_random_negative_node(source_node, destination_node, total_nodes):
    # Sample a random node that is neither the source nor the destination
    candidates = list(range(total_nodes))
    if destination_node in candidates:
        candidates.remove(destination_node)
    if source_node in candidates:
        candidates.remove(source_node)
    return random.choice(candidates)
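
# Note: a single uniformly random negative node per query is a rough stand-in for
# the negative sampler used during TGN evaluation; negative_probs is computed but
# not shown in this demo.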
def predict(u_id, i_id, timestamp, mem):
    # Optionally re-initialize memory before predicting so that earlier
    # predictions do not influence this one
    if mem:
        tgn.memory.__init_memory__()

    u_id, i_id, timestamp = int(u_id), int(i_id), float(timestamp)
    u_array = np.array([u_id])        # source nodes
    i_array = np.array([i_id])        # destination nodes
    ts_array = np.array([timestamp])  # timestamps

    # Look for a matching entry in full_data to recover the edge index
    edge_idx = None
    for idx in range(len(full_data.sources)):
        if (full_data.sources[idx] == u_id and
                full_data.destinations[idx] == i_id and
                full_data.timestamps[idx] == timestamp):
            edge_idx = idx
            break

    # Edge index for this source-destination pair (falls back to 0 if no match)
    edge_idx_array = np.array([0])
    EDGE_FEAT_DIM = 172  # or detect dynamically from edge_features if needed
    if edge_idx is not None:
        edge_idx_array = np.array([int(edge_idx)])
    logging.info(f"Edge index array: {edge_idx_array}")

    # Negative node sampling: pick a random negative node for this query
    # (in a full evaluation pipeline this would come from the dataset sampler)
    total_nodes = len(full_data.unique_nodes)
    random_negative_node = get_random_negative_node(u_id, i_id, total_nodes)
    negative_nodes_array = np.array([random_negative_node])

    # Call compute_edge_probabilities with the edge indices; fall back to dummy
    # edge features if passing the indices raises internally (the TGN must
    # support this case)
    try:
        positive_probs, negative_probs = tgn.compute_edge_probabilities(
            u_array,
            i_array,
            negative_nodes_array,
            ts_array,
            edge_idx_array
        )
    except RuntimeError:
        dummy_edge_features = np.zeros((1, EDGE_FEAT_DIM), dtype=np.float32)
        positive_probs, negative_probs = tgn.compute_edge_probabilities(
            u_array,
            i_array,
            negative_nodes_array,
            ts_array,
            dummy_edge_features
        )

    # positive_probs is a tensor; report its top-k values
    topk = torch.topk(positive_probs, k=min(5, positive_probs.shape[0]))
    top_values = [v.item() for v in topk.values]
    formatted_result = "\n".join(f"{val:.4f}" for val in top_values)
    return f"Predicted interaction probability: {formatted_result}"
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 TGN Playground (Wikipedia)")
    gr.Markdown("Enter node IDs and a timestamp to predict the future interaction probability using Temporal Graph Networks (TGN).")

    with gr.Row():
        src = gr.Number(label="Source Node ID")
        dst = gr.Number(label="Destination Node ID")
        ts = gr.Number(label="Timestamp")
    mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!")  # when checked, memory is reset before predicting

    predict_btn = gr.Button("Predict")
    result = gr.Textbox(label="Prediction Output")
    predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)

    gr.Markdown("### Download Log File")
    log_file = gr.File(label="Log File", interactive=False)
    download_btn = gr.Button("Download Logs")
    download_btn.click(fn=download_log, inputs=[], outputs=log_file)

demo.launch()