# Spaces: Running
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from utils.data_processing import get_data | |
from utils.utils import NeighborFinder | |
from modules.tgn import TGN | |
import logging | |
import random | |
import os | |
# Route INFO-and-above log records to debug.log in the working directory.
logging.basicConfig(level=logging.INFO, filename="debug.log")
def download_log():
    """Return the path of the debug log, or None when it does not exist yet.

    Returns:
        The relative path "debug.log" if that file is present in the current
        working directory, otherwise None.
    """
    log_path = "debug.log"
    if not os.path.exists(log_path):
        return None
    return log_path
# Load the 'wikipedia' dataset: node/edge feature matrices, the full
# interaction set, and its train/val/test (plus new-node) splits.
(
    node_features,
    edge_features,
    full_data,
    train_data,
    val_data,
    test_data,
    new_node_val_data,
    new_node_test_data,
) = get_data('wikipedia')

# Record a snapshot of the dataset in the log for debugging. vars() raises
# TypeError on objects without a __dict__, so fall back to a few known fields.
try:
    logging.info(f"Full data dict: {vars(full_data)}")
except TypeError:
    logging.info("vars(full_data) failed. Trying manual extraction...")
    logging.info(f"Sources: {full_data.sources[:10]}")
    logging.info(f"Timestamps: {full_data.timestamps[:10]}")

# Extract edge information
edge_sources = full_data.sources
edge_destinations = full_data.destinations
edge_timestamps = full_data.timestamps
edge_idxs = full_data.edge_idxs

# Construct the adjacency list: one entry per node ID, each holding
# (neighbor, edge index, timestamp) tuples.
all_nodes = list(full_data.sources) + list(full_data.destinations)
n_nodes = max(all_nodes) + 1  # node IDs are used directly as list indices
adj_list = [[] for _ in range(n_nodes)]
for edge_src, edge_dst, edge_ts, edge_id in zip(
    edge_sources, edge_destinations, edge_timestamps, edge_idxs
):
    # Every interaction is stored in both directions (graph treated as undirected).
    adj_list[edge_src].append((edge_dst, edge_id, edge_ts))
    adj_list[edge_dst].append((edge_src, edge_id, edge_ts))

# Initialize the neighbor finder
nf = NeighborFinder(adj_list=adj_list, uniform=False)

# Initialize model (hyper-parameters must match the ones used for training,
# otherwise the checkpoint loaded below will not fit).
tgn = TGN(
    neighbor_finder=nf,
    node_features=node_features,
    edge_features=edge_features,
    device='cpu',
    n_layers=1,
    n_heads=2,
    dropout=0.1,
    use_memory=True,
    message_dimension=100,
    memory_dimension=172,
    memory_update_at_start=True,  # Should match training
    embedding_module_type='graph_attention',
    message_function='identity',
    aggregator_type='last',
    memory_updater_type='gru',
    n_neighbors=10,
    use_destination_embedding_in_message=False,
    use_source_embedding_in_message=False,
    dyrep=False,
)

# Initialize memory buffers with the correct dimensions
new_num_nodes = 9228  # Current number of nodes
memory_dim = 172  # Must equal memory_dimension passed to TGN above

# Manually reinitialize the memory buffers to the correct size.
# They are state, not trainable weights (the model is only used in eval mode
# here), so they are created with requires_grad=False.
# NOTE(review): presumably TGN writes to these tensors in place, which autograd
# forbids on leaf tensors that require grad - confirm against modules/tgn.
tgn.memory.memory = nn.Parameter(
    torch.randn(new_num_nodes, memory_dim), requires_grad=False
)
tgn.memory.last_update = nn.Parameter(
    torch.zeros(new_num_nodes), requires_grad=False
)
tgn.memory_updater.memory.memory = nn.Parameter(
    torch.randn(new_num_nodes, memory_dim), requires_grad=False
)
tgn.memory_updater.memory.last_update = nn.Parameter(
    torch.zeros(new_num_nodes), requires_grad=False
)

# Load the state_dict from the pretrained model.
# SECURITY: torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load("tgn-attn-wikipedia.pth", map_location="cpu")

# Drop the checkpoint's memory tensors so they cannot conflict with the
# freshly sized buffers created above.
_MEMORY_STATE_KEYS = (
    "memory.memory",
    "memory.last_update",
    "memory_updater.memory.memory",
    "memory_updater.memory.last_update",
)
for _memory_key in _MEMORY_STATE_KEYS:
    state_dict.pop(_memory_key, None)

# strict=False tolerates the memory entries removed above, but it would also
# hide any other mismatch, so report everything that was not expected.
load_result = tgn.load_state_dict(state_dict, strict=False)
_unexpected_missing = [
    name for name in load_result.missing_keys if name not in _MEMORY_STATE_KEYS
]
if _unexpected_missing or load_result.unexpected_keys:
    logging.warning(
        "Checkpoint mismatch - missing keys: %s; unexpected keys: %s",
        _unexpected_missing,
        list(load_result.unexpected_keys),
    )
tgn.eval()
def predict_old(u_id, i_id, timestamp):
    """Legacy entry point kept for backward compatibility; delegates to `predict`.

    The previous body called ``tgn.compute_edge_probabilities`` with three
    arguments and formatted a single return value, whereas `predict` calls the
    same method with five arguments (sources, destinations, negatives,
    timestamps, edge indices) and unpacks a pair of results. Delegating keeps
    the two entry points consistent.

    Args:
        u_id: Source node ID.
        i_id: Destination node ID.
        timestamp: Time of the hypothetical interaction.

    Returns:
        The message produced by `predict`, starting with
        "Predicted interaction probability: ".
    """
    # mem=False: do not re-initialize the model memory before predicting.
    return predict(u_id, i_id, timestamp, False)
def get_random_negative_node(source_node, destination_node, total_nodes):
    """Pick a random node ID that is neither the source nor the destination.

    Candidates are the integers ``0 .. total_nodes - 1``. A source or
    destination outside that range is simply ignored instead of raising,
    because it cannot collide with any candidate.

    Args:
        source_node: Node ID to exclude (source of the positive edge).
        destination_node: Node ID to exclude (destination of the positive edge).
        total_nodes: Size of the candidate ID range.

    Returns:
        A uniformly chosen integer from the remaining candidates.

    Raises:
        ValueError: If no candidate is left after the exclusions.
    """
    excluded = {source_node, destination_node}
    candidates = [node for node in range(total_nodes) if node not in excluded]
    if not candidates:
        raise ValueError(
            "No negative node available: every node in "
            f"range({total_nodes}) is excluded."
        )
    return random.choice(candidates)
def _lookup_edge_idx(source, destination, timestamp):
    """Return the dataset edge index of a recorded interaction, or 0 if none matches."""
    matches = np.flatnonzero(
        (np.asarray(full_data.sources) == source)
        & (np.asarray(full_data.destinations) == destination)
        & (np.asarray(full_data.timestamps) == timestamp)
    )
    if matches.size == 0:
        return 0
    # Use the dataset's own edge index (the values the adjacency list is built
    # from), not the row position of the match.
    return int(np.asarray(full_data.edge_idxs)[matches[0]])


def predict(u_id, i_id, timestamp, mem):
    """Predict the interaction probability for one (source, destination, time) triple.

    Args:
        u_id: Source node ID (numeric; converted with ``int``).
        i_id: Destination node ID (numeric; converted with ``int``).
        timestamp: Time of the hypothetical interaction (truncated with ``int``
            before it is passed to the model).
        mem: If truthy, re-initialize the model memory before predicting.

    Returns:
        A human-readable string: either the predicted probability formatted to
        four decimals, or a validation message when an input is missing or a
        node ID falls outside ``0 .. n_nodes - 1``.
    """
    if u_id is None or i_id is None or timestamp is None:
        return "Please enter a source node ID, a destination node ID and a timestamp."

    source = int(u_id)
    destination = int(i_id)
    # The adjacency list has exactly n_nodes entries, so IDs outside this range
    # cannot be looked up.
    if not (0 <= source < n_nodes and 0 <= destination < n_nodes):
        return f"Node IDs must be integers between 0 and {n_nodes - 1}."

    if mem:
        tgn.memory.__init_memory__()  # Re-initialize memory

    # The model is called with batches; wrap the single query in length-1 arrays.
    u_array = np.array([source])
    i_array = np.array([destination])
    ts_array = np.array([int(timestamp)])

    # Edge index of the matching recorded interaction; 0 when the pair was
    # never observed at this timestamp.
    edge_idx_array = np.array([_lookup_edge_idx(source, destination, timestamp)])
    logging.info(f"Edge index array: {edge_idx_array}")

    # One random negative destination. Sample over n_nodes (max ID + 1) so the
    # range always contains every valid node ID, including the largest one.
    negative_nodes_array = np.array(
        [get_random_negative_node(source, destination, n_nodes)]
    )

    # Inference only: no autograd graph is needed.
    # NOTE(review): presumably the model updates its memory in place during
    # this call, which is only safe with grad disabled - confirm.
    with torch.no_grad():
        try:
            positive_probs, _negative_probs = tgn.compute_edge_probabilities(
                u_array,
                i_array,
                negative_nodes_array,
                ts_array,
                edge_idx_array,
            )
        except RuntimeError:
            # Best-effort retry kept from the original code, but no longer silent.
            # NOTE(review): this passes a zero feature matrix where edge indices
            # are normally passed; confirm that TGN supports this input.
            logging.exception(
                "compute_edge_probabilities failed; retrying with zero edge features"
            )
            dummy_edge_features = np.zeros(
                (1, edge_features.shape[1]), dtype=np.float32
            )
            positive_probs, _negative_probs = tgn.compute_edge_probabilities(
                u_array,
                i_array,
                negative_nodes_array,
                ts_array,
                dummy_edge_features,
            )

    # Flatten so the formatting works for any output shape, then keep at most
    # the five highest probabilities (a single value for a single query).
    flat_probs = positive_probs.reshape(-1)
    topk = torch.topk(flat_probs, k=min(5, flat_probs.numel()))
    formatted_result = "\n".join(f"{value:.4f}" for value in topk.values.tolist())
    return f"Predicted interaction probability: {formatted_result}"
# Build the Gradio UI: a prediction form plus a log-download section.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 TGN Playground (Wikipedia)")
    gr.Markdown("Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).")

    with gr.Row():
        # precision=0 makes the node-ID inputs deliver integers instead of floats.
        src = gr.Number(label="Source Node ID", precision=0)
        dst = gr.Number(label="Destination Node ID", precision=0)
        ts = gr.Number(label="Timestamp")
        # When ticked, predict() re-initializes the model memory first.
        mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!")

    predict_btn = gr.Button("Predict")
    result = gr.Textbox(label="Prediction Output")
    predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)

    gr.Markdown("### Download Log File")
    log_file = gr.File(label="Log File", interactive=False)
    download_btn = gr.Button("Download Logs")
    download_btn.click(fn=download_log, inputs=[], outputs=log_file)

demo.launch()