"""Gradio playground for temporal link prediction with a pretrained TGN.

Loads the 'wikipedia' dataset via ``get_data``, rebuilds a temporal neighbor
finder from every interaction, restores the ``tgn-attn-wikipedia.pth``
checkpoint and serves a small UI that scores one
(source, destination, timestamp) interaction. Runtime details are logged to
``debug.log``, which the UI also offers for download.
"""
import logging
import os
import random

import gradio as gr
import numpy as np
import torch
import torch.nn as nn

from modules.tgn import TGN
from utils.data_processing import get_data
from utils.utils import NeighborFinder

LOG_FILE = "debug.log"
CHECKPOINT_PATH = "tgn-attn-wikipedia.pth"
MEMORY_DIM = 172  # must equal the memory_dimension used at training time
EDGE_FEAT_DIM = 172  # width of the dummy edge-feature row used by predict()'s fallback

logging.basicConfig(filename=LOG_FILE, level=logging.INFO)


def download_log():
    """Return the log file path for the download widget, or None if it does not exist yet."""
    return LOG_FILE if os.path.exists(LOG_FILE) else None


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
(node_features, edge_features, full_data, train_data, val_data, test_data,
 new_node_val_data, new_node_test_data) = get_data('wikipedia')

try:
    logging.info(f"Full data dict: {vars(full_data)}")
except TypeError:
    # vars() raises TypeError for objects without a __dict__.
    logging.info("vars(full_data) failed. Trying manual extraction...")
    logging.info(f"Sources: {full_data.sources[:10]}")
    logging.info(f"Timestamps: {full_data.timestamps[:10]}")

# Per-interaction columns (parallel sequences, one entry per interaction).
edge_sources = full_data.sources
edge_destinations = full_data.destinations
edge_timestamps = full_data.timestamps
edge_idxs = full_data.edge_idxs

# NumPy views that let predict() find an interaction without a Python loop.
_SOURCES = np.asarray(edge_sources)
_DESTINATIONS = np.asarray(edge_destinations)
_TIMESTAMPS = np.asarray(edge_timestamps)
_EDGE_IDXS = np.asarray(edge_idxs)

# Node ids index adj_list directly, so it needs max_id + 1 slots
# (assumes ids are non-negative integers).
all_nodes = list(full_data.sources) + list(full_data.destinations)
n_nodes = max(all_nodes) + 1

# adj_list[node] holds (neighbor, edge_idx, timestamp) tuples. Every
# interaction is recorded in both directions, i.e. treated as undirected.
adj_list = [[] for _ in range(n_nodes)]
for src, dst, ts, eidx in zip(edge_sources, edge_destinations, edge_timestamps, edge_idxs):
    adj_list[src].append((dst, eidx, ts))
    adj_list[dst].append((src, eidx, ts))

nf = NeighborFinder(adj_list=adj_list, uniform=False)

# ---------------------------------------------------------------------------
# Model (hyper-parameters must match the ones used for training)
# ---------------------------------------------------------------------------
tgn = TGN(
    neighbor_finder=nf,
    node_features=node_features,
    edge_features=edge_features,
    device='cpu',
    n_layers=1,
    n_heads=2,
    dropout=0.1,
    use_memory=True,
    message_dimension=100,
    memory_dimension=MEMORY_DIM,
    memory_update_at_start=True,
    embedding_module_type='graph_attention',
    message_function='identity',
    aggregator_type='last',
    memory_updater_type='gru',
    n_neighbors=10,
    use_destination_embedding_in_message=False,
    use_source_embedding_in_message=False,
    dyrep=False,
)

# NOTE(review): hard-coded node count; presumably equals node_features.shape[0] — confirm.
new_num_nodes = 9228
memory_dim = MEMORY_DIM

# Start from an all-zero memory (no interactions observed yet). Zeros, rather
# than unseeded random values, keep predictions reproducible across restarts.
# requires_grad=False because memory is state, not a trainable weight.
# NOTE(review): tgn.memory_updater.memory may be the very same object as
# tgn.memory; assigning both is harmless either way.
tgn.memory.memory = nn.Parameter(torch.zeros(new_num_nodes, memory_dim), requires_grad=False)
tgn.memory.last_update = nn.Parameter(torch.zeros(new_num_nodes), requires_grad=False)
tgn.memory_updater.memory.memory = nn.Parameter(
    torch.zeros(new_num_nodes, memory_dim), requires_grad=False
)
tgn.memory_updater.memory.last_update = nn.Parameter(
    torch.zeros(new_num_nodes), requires_grad=False
)

# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Drop the checkpoint's memory buffers so the freshly sized ones above are kept.
for _memory_key in (
    "memory.memory",
    "memory.last_update",
    "memory_updater.memory.memory",
    "memory_updater.memory.last_update",
):
    state_dict.pop(_memory_key, None)

# strict=False tolerates the dropped memory keys; log what did not match so a
# hyper-parameter mismatch with the checkpoint is visible instead of silent.
_load_report = tgn.load_state_dict(state_dict, strict=False)
logging.info(f"Missing keys when loading checkpoint: {_load_report.missing_keys}")
logging.info(f"Unexpected keys when loading checkpoint: {_load_report.unexpected_keys}")
tgn.eval()


def predict_old(u_id, i_id, timestamp):
    """Legacy single-edge predictor; not wired into the UI.

    NOTE(review): calls compute_edge_probabilities with three arguments,
    whereas predict() passes five — this path likely no longer works.
    """
    u = np.array([int(u_id)])
    i = np.array([int(i_id)])
    ts = np.array([float(timestamp)])
    prob = tgn.compute_edge_probabilities(u, i, ts)
    return f"Predicted interaction probability: {prob[0]:.4f}"


def get_random_negative_node(source_node, destination_node, total_nodes):
    """Return a random node id in ``range(total_nodes)`` other than the source and destination.

    Ids outside the range are simply not excluded (instead of raising).
    Raises IndexError (from random.choice) if no candidate remains.
    """
    excluded = {source_node, destination_node}
    candidates = [node for node in range(total_nodes) if node not in excluded]
    return random.choice(candidates)


def predict(u_id, i_id, timestamp, mem):
    """Score a single (source, destination, timestamp) interaction.

    Args:
        u_id: Source node id (from a Gradio Number; float or None).
        i_id: Destination node id.
        timestamp: Interaction time in the dataset's timestamp units.
        mem: When truthy, reset the TGN memory before predicting.

    Returns:
        A display string with the predicted probability of the positive edge.

    If the interaction exists in ``full_data``, its stored edge index is used.
    Otherwise edge index 0 is passed. One random negative destination is
    sampled because ``compute_edge_probabilities`` also scores a negative edge.
    """
    if u_id is None or i_id is None or timestamp is None:
        return "Please enter a source node ID, destination node ID and timestamp."

    if mem:
        tgn.memory.__init_memory__()  # reset memory before this prediction

    source = int(u_id)
    destination = int(i_id)
    u_array = np.array([source])
    i_array = np.array([destination])
    # Keep fractional timestamps; the dataset timestamps are not truncated either.
    ts_array = np.array([float(timestamp)])

    # Use the dataset's own edge id (edge_idxs) — the same id space given to the
    # NeighborFinder above — rather than the row position inside full_data.
    matches = np.flatnonzero(
        (_SOURCES == u_id) & (_DESTINATIONS == i_id) & (_TIMESTAMPS == timestamp)
    )
    edge_idx = int(_EDGE_IDXS[matches[0]]) if matches.size else 0
    edge_idx_array = np.array([edge_idx])
    logging.info(
        f"Edge index used for prediction: {edge_idx_array} (found in data: {matches.size > 0})"
    )

    # NOTE(review): ids are drawn from range(len(unique_nodes)); if node ids are
    # 1-based this can yield 0 and never the largest id — confirm.
    total_nodes = len(full_data.unique_nodes)
    negative_nodes_array = np.array(
        [get_random_negative_node(source, destination, total_nodes)]
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        try:
            positive_probs, _ = tgn.compute_edge_probabilities(
                u_array, i_array, negative_nodes_array, ts_array, edge_idx_array
            )
        except RuntimeError:
            logging.exception(
                "compute_edge_probabilities failed; retrying with dummy edge features"
            )
            # NOTE(review): passes a float feature row where edge indices were
            # passed above; this only works if TGN accepts that — confirm.
            dummy_edge_features = np.zeros((1, EDGE_FEAT_DIM), dtype=np.float32)
            positive_probs, _ = tgn.compute_edge_probabilities(
                u_array, i_array, negative_nodes_array, ts_array, dummy_edge_features
            )

    # Up to five highest probabilities, flattened so any result shape works.
    top_values = sorted(positive_probs.reshape(-1).tolist(), reverse=True)[:5]
    formatted_result = "\n".join(f"{val:.4f}" for val in top_values)
    return f"Predicted interaction probability: {formatted_result}"


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 TGN Playground (Wikipedia)")
    gr.Markdown("Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).")

    with gr.Row():
        src = gr.Number(label="Source Node ID")
        dst = gr.Number(label="Destination Node ID")
        ts = gr.Number(label="Timestamp")
        mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!")

    predict_btn = gr.Button("Predict")
    result = gr.Textbox(label="Prediction Output")
    predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)

    gr.Markdown("### Download Log File")
    log_file = gr.File(label="Log File", interactive=False)
    download_btn = gr.Button("Download Logs")
    download_btn.click(fn=download_log, inputs=[], outputs=log_file)

demo.launch()