ashu316 commited on
Commit
d4a76d9
·
verified ·
1 Parent(s): 9795a60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -111,9 +111,10 @@ def get_random_negative_node(source_node, destination_node, total_nodes):
111
  return random.choice(candidates)
112
 
113
 
114
- def predict(u_id, i_id, timestamp):
115
  # Before prediction
116
- tgn.memory.__init_memory__() # Re-initialize memory
 
117
 
118
  # Then run prediction
119
  u_array = np.array([int(u_id)]) # List of source nodes
@@ -183,12 +184,13 @@ with gr.Blocks() as demo:
183
  with gr.Row():
184
  src = gr.Number(label="Source Node ID")
185
  dst = gr.Number(label="Destination Node ID")
186
- ts = gr.Number(label="Timestamp")
 
187
 
188
  predict_btn = gr.Button("Predict")
189
  result = gr.Textbox(label="Prediction Output")
190
 
191
- predict_btn.click(fn=predict, inputs=[src, dst, ts], outputs=result)
192
 
193
  gr.Markdown("### Download Log File")
194
  log_file = gr.File(label="Log File", interactive=False)
 
111
  return random.choice(candidates)
112
 
113
 
114
+ def predict(u_id, i_id, timestamp, mem):
115
  # Before prediction
116
+ if mem:
117
+ tgn.memory.__init_memory__() # Re-initialize memory
118
 
119
  # Then run prediction
120
  u_array = np.array([int(u_id)]) # List of source nodes
 
184
  with gr.Row():
185
  src = gr.Number(label="Source Node ID")
186
  dst = gr.Number(label="Destination Node ID")
187
+ ts = gr.Number(label="Timestamp")
188
+ mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!") # 👈 your checkbox
189
 
190
  predict_btn = gr.Button("Predict")
191
  result = gr.Textbox(label="Prediction Output")
192
 
193
+ predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)
194
 
195
  gr.Markdown("### Download Log File")
196
  log_file = gr.File(label="Log File", interactive=False)