import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from utils.data_processing import get_data
from utils.utils import NeighborFinder
from modules.tgn import TGN
import logging
import random
import os

logging.basicConfig(filename="debug.log", level=logging.INFO)
def download_log():
    return "debug.log" if os.path.exists("debug.log") else None
# Load data
node_features, edge_features, full_data, train_data, val_data, test_data, \
    new_node_val_data, new_node_test_data = get_data('wikipedia')
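# Log basic dataset statistics, using the same Data attributes this script
# relies on below (sources, unique_nodes).
logging.info(f"Loaded {len(full_data.sources)} interactions "
             f"over {len(full_data.unique_nodes)} unique nodes")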
try:
    logging.info(f"Full data dict: {vars(full_data)}")
except TypeError:
    logging.info("vars(full_data) failed. Trying manual extraction...")
    logging.info(f"Sources: {full_data.sources[:10]}")
    logging.info(f"Timestamps: {full_data.timestamps[:10]}")
# Extract edge information
edge_sources = full_data.sources
edge_destinations = full_data.destinations
edge_timestamps = full_data.timestamps
edge_idxs = full_data.edge_idxs

# Construct adjacency list
all_nodes = list(full_data.sources) + list(full_data.destinations)
n_nodes = max(all_nodes) + 1  # assuming node IDs are 0-indexed
adj_list = [[] for _ in range(n_nodes)]
for src, dst, ts, eidx in zip(edge_sources, edge_destinations, edge_timestamps, edge_idxs):
    adj_list[src].append((dst, eidx, ts))
    adj_list[dst].append((src, eidx, ts))  # treat the graph as undirected
# Initialize the neighbor finder
nf = NeighborFinder(adj_list=adj_list, uniform=False)
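# Optional sanity check, assuming NeighborFinder follows the reference TGN
# implementation and exposes get_temporal_neighbor(sources, timestamps,
# n_neighbors) returning (neighbors, edge_idxs, edge_times):
# neighbors, _, _ = nf.get_temporal_neighbor(
#     np.array([edge_sources[0]]), np.array([edge_timestamps[0]]), n_neighbors=10)
# logging.info(f"Sample temporal neighbors: {neighbors}")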
# Initialize model (must match training args)
tgn = TGN(
    neighbor_finder=nf,
    node_features=node_features,
    edge_features=edge_features,
    device='cpu',
    n_layers=1,
    n_heads=2,
    dropout=0.1,
    use_memory=True,
    message_dimension=100,
    memory_dimension=172,
    memory_update_at_start=True,  # should match training
    embedding_module_type='graph_attention',
    message_function='identity',
    aggregator_type='last',
    memory_updater_type='gru',
    n_neighbors=10,
    use_destination_embedding_in_message=False,
    use_source_embedding_in_message=False,
    dyrep=False
)
# Reinitialize memory buffers with the correct dimensions. In the reference
# TGN implementation the memory updater holds a reference to the same Memory
# object, so the last two assignments may be redundant; reassigning both
# keeps them in sync either way.
new_num_nodes = 9228  # current number of nodes
memory_dim = 172      # must match memory_dimension above
tgn.memory.memory = nn.Parameter(torch.randn(new_num_nodes, memory_dim))
tgn.memory.last_update = nn.Parameter(torch.zeros(new_num_nodes))
tgn.memory_updater.memory.memory = nn.Parameter(torch.randn(new_num_nodes, memory_dim))
tgn.memory_updater.memory.last_update = nn.Parameter(torch.zeros(new_num_nodes))
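# Consistency check: the reinitialized buffers should have the shapes the
# rest of the model expects.
assert tgn.memory.memory.shape == (new_num_nodes, memory_dim)
assert tgn.memory.last_update.shape == (new_num_nodes,)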
# Load the state_dict from the pretrained model
state_dict = torch.load("tgn-attn-wikipedia.pth", map_location="cpu")

# Remove any conflicting memory-related parameters
state_dict.pop("memory.memory", None)
state_dict.pop("memory.last_update", None)
state_dict.pop("memory_updater.memory.memory", None)
state_dict.pop("memory_updater.memory.last_update", None)

# Load the state_dict into the model (strict=False to ignore the popped memory buffers)
tgn.load_state_dict(state_dict, strict=False)
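# Note: load_state_dict returns a NamedTuple of missing/unexpected keys;
# capturing it confirms that only the popped memory buffers were skipped:
# result = tgn.load_state_dict(state_dict, strict=False)
# logging.info(f"Missing: {result.missing_keys}, unexpected: {result.unexpected_keys}")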
tgn.eval()
def predict_old(u_id, i_id, timestamp):
    # NOTE: legacy helper, not wired to the UI. compute_edge_probabilities
    # also expects negative nodes and edge indices; see predict() below.
    u = np.array([int(u_id)])
    i = np.array([int(i_id)])
    ts = np.array([float(timestamp)])
    # prob = tgn.predict_edge_probabilities(u, i, ts)
    prob = tgn.compute_edge_probabilities(u, i, ts)
    return f"Predicted interaction probability: {prob[0]:.4f}"
def get_random_negative_node(source_node, destination_node, total_nodes):
    candidates = list(range(total_nodes))
    if destination_node in candidates:
        candidates.remove(destination_node)
    if source_node in candidates:
        candidates.remove(source_node)  # optional
    return random.choice(candidates)
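# Example: get_random_negative_node(0, 1, total_nodes) returns a random node
# id in [0, total_nodes) that differs from both endpoints.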
def predict(u_id, i_id, timestamp, mem):
    # Gradio Number inputs arrive as floats; cast the node IDs once up front
    u_id, i_id = int(u_id), int(i_id)
    if mem:
        tgn.memory.__init_memory__()  # re-initialize memory before prediction
    u_array = np.array([u_id])               # source nodes
    i_array = np.array([i_id])               # destination nodes
    ts_array = np.array([float(timestamp)])  # timestamps (keep float precision)
    # Look for a matching entry in full_data to recover the edge_idx
    edge_idx = None
    for idx in range(len(full_data.sources)):
        if (full_data.sources[idx] == u_id and
                full_data.destinations[idx] == i_id and
                full_data.timestamps[idx] == timestamp):
            edge_idx = idx
            break
    # Edge index for this source-destination pair (0 if no match was found)
    edge_idx_array = np.array([0])
    EDGE_FEAT_DIM = 172  # or detect dynamically if needed
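    # The "detect dynamically" option mentioned above could simply read the
    # dimension from the loaded features, e.g.:
    # EDGE_FEAT_DIM = edge_features.shape[1]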
    if edge_idx is not None:
        edge_idx_array = np.array([int(edge_idx)])
    logging.info(f"Edge index array: {edge_idx_array}")
    # Negative node sampling: choose a random negative node for this example
    # (in a real evaluation this should come from the dataset's sampler)
    total_nodes = len(full_data.unique_nodes)  # total node count
    random_negative_node = get_random_negative_node(u_id, i_id, total_nodes)
    negative_nodes_array = np.array([random_negative_node])
    # Call compute_edge_probabilities. Depending on how your TGN is set up,
    # it may expect edge indices or raw edge features here.
    try:
        positive_probs, negative_probs = tgn.compute_edge_probabilities(
            u_array,
            i_array,
            negative_nodes_array,
            ts_array,
            edge_idx_array
        )
    except RuntimeError as e:
        # Fallback in case the edge indices cause an internal error
        logging.warning(f"compute_edge_probabilities failed with edge_idxs: {e}")
        dummy_edge_features = np.zeros((1, EDGE_FEAT_DIM), dtype=np.float32)
        positive_probs, negative_probs = tgn.compute_edge_probabilities(
            u_array,
            i_array,
            negative_nodes_array,
            ts_array,
            dummy_edge_features  # your TGN must support this case
        )
    # positive_probs is a tensor; flatten it and take the top-k values
    flat_probs = positive_probs.flatten()
    topk = torch.topk(flat_probs, k=min(5, flat_probs.shape[0]))
    top_values = [v.item() for v in topk.values]
    # Format the top-k values
    # formatted_result = "\n".join([f"{i+1}. {val:.4f}" for i, val in enumerate(top_values)])
    formatted_result = "\n".join([f"{val:.4f}" for val in top_values])
    # return f"Top {len(top_values)} predicted interaction probabilities:\n{formatted_result}"
    return f"Predicted interaction probability: {formatted_result}"
with gr.Blocks() as demo:
    gr.Markdown("## TGN Playground (Wikipedia)")
    gr.Markdown("Enter node IDs and a timestamp to predict the future interaction probability using Temporal Graph Networks (TGN).")
    with gr.Row():
        src = gr.Number(label="Source Node ID")
        dst = gr.Number(label="Destination Node ID")
        ts = gr.Number(label="Timestamp")
    mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!")  # checkbox controlling memory reset
    predict_btn = gr.Button("Predict")
    result = gr.Textbox(label="Prediction Output")
    predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)

    gr.Markdown("### Download Log File")
    log_file = gr.File(label="Log File", interactive=False)
    download_btn = gr.Button("Download Logs")
    download_btn.click(fn=download_log, inputs=[], outputs=log_file)

demo.launch()