ashu316 commited on
Commit
155336d
·
verified ·
1 Parent(s): 41aae2b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +53 -0
  2. requirements.txt +4 -0
  3. tgn-attn-wikipedia.pth +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from utils.data_processing import get_data
6
+ from modules.tgn import TGN
7
+
8
+ # Load data
9
+ node_features, edge_features, full_data, train_data, val_data, test_data, \
10
+ new_node_val_data, new_node_test_data = get_data('wikipedia')
11
+
12
+ # Initialize model (must match training args)
13
+ tgn = TGN(
14
+ raw_message_dim=edge_features.shape[1],
15
+ memory_dim=172,
16
+ message_dim=100,
17
+ memory_update_at_end=False,
18
+ embedding_module_type='graph_attention',
19
+ message_function='identity',
20
+ memory_updater_type='gru',
21
+ n_heads=2,
22
+ dropout=0.1,
23
+ n_layers=1,
24
+ n_neighbors=10,
25
+ aggregator_type='last',
26
+ edge_features=edge_features,
27
+ node_features=node_features,
28
+ device='cpu',
29
+ use_memory=True,
30
+ )
31
+ tgn.load_state_dict(torch.load("tgn-attn-wikipedia.pth", map_location="cpu"))
32
+ tgn.eval()
33
+
34
+ def predict(u_id, i_id, timestamp):
35
+ u = np.array([int(u_id)])
36
+ i = np.array([int(i_id)])
37
+ ts = np.array([float(timestamp)])
38
+ prob = tgn.predict_edge_probabilities(u, i, ts)
39
+ return f"Predicted interaction probability: {prob[0]:.4f}"
40
+
41
+ demo = gr.Interface(
42
+ fn=predict,
43
+ inputs=[
44
+ gr.Number(label="Source Node ID"),
45
+ gr.Number(label="Destination Node ID"),
46
+ gr.Number(label="Timestamp"),
47
+ ],
48
+ outputs="text",
49
+ title="🧠 TGN Playground (Wikipedia)",
50
+ description="Enter node IDs and timestamp to predict future interaction probability using Temporal Graph Networks (TGN).",
51
+ )
52
+
53
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ pandas
4
+ gradio
tgn-attn-wikipedia.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e16adc3ca08a3217d24a7bda96d52623c15e52e7dc20378a7e1e7ce9f9a2eda
3
+ size 8108482