Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -111,9 +111,10 @@ def get_random_negative_node(source_node, destination_node, total_nodes):
|
|
111 |
return random.choice(candidates)
|
112 |
|
113 |
|
114 |
-
def predict(u_id, i_id, timestamp):
|
115 |
# Before prediction
|
116 |
-
|
|
|
117 |
|
118 |
# Then run prediction
|
119 |
u_array = np.array([int(u_id)]) # List of source nodes
|
@@ -183,12 +184,13 @@ with gr.Blocks() as demo:
|
|
183 |
with gr.Row():
|
184 |
src = gr.Number(label="Source Node ID")
|
185 |
dst = gr.Number(label="Destination Node ID")
|
186 |
-
ts
|
|
|
187 |
|
188 |
predict_btn = gr.Button("Predict")
|
189 |
result = gr.Textbox(label="Prediction Output")
|
190 |
|
191 |
-
predict_btn.click(fn=predict, inputs=[src, dst, ts], outputs=result)
|
192 |
|
193 |
gr.Markdown("### Download Log File")
|
194 |
log_file = gr.File(label="Log File", interactive=False)
|
|
|
111 |
return random.choice(candidates)
|
112 |
|
113 |
|
114 |
+
def predict(u_id, i_id, timestamp, mem):
|
115 |
# Before prediction
|
116 |
+
if mem:
|
117 |
+
tgn.memory.__init_memory__() # Re-initialize memory
|
118 |
|
119 |
# Then run prediction
|
120 |
u_array = np.array([int(u_id)]) # List of source nodes
|
|
|
184 |
with gr.Row():
|
185 |
src = gr.Number(label="Source Node ID")
|
186 |
dst = gr.Number(label="Destination Node ID")
|
187 |
+
ts = gr.Number(label="Timestamp")
|
188 |
+
mem = gr.Checkbox(label="Reinitialize memory, don't update on prediction!") # 👈 your checkbox
|
189 |
|
190 |
predict_btn = gr.Button("Predict")
|
191 |
result = gr.Textbox(label="Prediction Output")
|
192 |
|
193 |
+
predict_btn.click(fn=predict, inputs=[src, dst, ts, mem], outputs=result)
|
194 |
|
195 |
gr.Markdown("### Download Log File")
|
196 |
log_file = gr.File(label="Log File", interactive=False)
|