Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -129,12 +129,13 @@ def format_bpe_display(bpe):
|
|
| 129 |
# return current_vis[x], format_bpe_display(current_bpe[x])
|
| 130 |
# else:
|
| 131 |
# return None, "索引超出范围"
|
| 132 |
-
|
| 133 |
-
def update_slider_index(x):
|
| 134 |
-
if 0 <= x < len(
|
| 135 |
-
return
|
| 136 |
else:
|
| 137 |
-
return None, "索引超出范围"
|
|
|
|
| 138 |
|
| 139 |
# Gradio界面
|
| 140 |
with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
@@ -188,26 +189,38 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
| 188 |
# print("current_vis",len(current_vis))
|
| 189 |
# print("current_bpe",len(current_bpe))
|
| 190 |
# return image, vis, bpe_text, slider_max_val
|
|
|
|
| 191 |
@spaces.GPU
|
| 192 |
def on_run_clicked(model_type, image, text):
|
| 193 |
-
global current_vis, current_bpe, current_index
|
| 194 |
-
current_index = 0 # Reset index when new image is processed
|
| 195 |
model, tokenizer, transform, device = load_model(model_type)
|
|
|
|
| 196 |
image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
run_btn.click(
|
| 204 |
on_run_clicked,
|
| 205 |
inputs=[model_type, image_input, text_input],
|
| 206 |
-
outputs=[orig_img, heatmap, bpe_display, index_slider]
|
| 207 |
).then(
|
| 208 |
-
lambda
|
| 209 |
inputs=index_slider,
|
| 210 |
-
outputs=[prev_btn, index_slider, next_btn, bpe_display]
|
| 211 |
)
|
| 212 |
|
| 213 |
prev_btn.click(
|
|
@@ -219,12 +232,7 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
| 219 |
lambda: (*update_index(1), current_index),
|
| 220 |
outputs=[heatmap, bpe_display, index_slider]
|
| 221 |
)
|
| 222 |
-
|
| 223 |
-
# index_slider.change(
|
| 224 |
-
# lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=x<len(current_vis else (None,"Invaild")
|
| 225 |
-
# inputs=index_slider,
|
| 226 |
-
# outputs=[heatmap, bpe_display]
|
| 227 |
-
# )
|
| 228 |
|
| 229 |
# index_slider.change(
|
| 230 |
# update_slider_index,
|
|
@@ -232,9 +240,10 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
| 232 |
# outputs=[heatmap, bpe_display]
|
| 233 |
# )
|
| 234 |
index_slider.change(
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
| 238 |
|
| 239 |
|
| 240 |
if __name__ == "__main__":
|
|
|
|
| 129 |
# return current_vis[x], format_bpe_display(current_bpe[x])
|
| 130 |
# else:
|
| 131 |
# return None, "索引超出范围"
|
| 132 |
+
# 状态更新函数,利用传递的状态(vis, bpe)
|
| 133 |
+
def update_slider_index(x, vis, bpe):
|
| 134 |
+
if 0 <= x < len(vis):
|
| 135 |
+
return vis[x], format_bpe_display(bpe[x]), vis, bpe
|
| 136 |
else:
|
| 137 |
+
return None, "索引超出范围", vis, bpe
|
| 138 |
+
|
| 139 |
|
| 140 |
# Gradio界面
|
| 141 |
with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
|
|
| 189 |
# print("current_vis",len(current_vis))
|
| 190 |
# print("current_bpe",len(current_bpe))
|
| 191 |
# return image, vis, bpe_text, slider_max_val
|
| 192 |
+
|
| 193 |
@spaces.GPU
|
| 194 |
def on_run_clicked(model_type, image, text):
|
|
|
|
|
|
|
| 195 |
model, tokenizer, transform, device = load_model(model_type)
|
| 196 |
+
current_index = 0 # Reset index when new image is processed
|
| 197 |
image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
|
| 198 |
+
slider_max_val = len(bpe) - 1
|
| 199 |
+
bpe_text = format_bpe_display(bpe[current_index])
|
| 200 |
+
# 将处理结果传递给后续步骤
|
| 201 |
+
return image, vis[current_index], bpe_text, slider_max_val, vis, bpe
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
|
| 206 |
+
# run_btn.click(
|
| 207 |
+
# on_run_clicked,
|
| 208 |
+
# inputs=[model_type, image_input, text_input],
|
| 209 |
+
# outputs=[orig_img, heatmap, bpe_display, index_slider],
|
| 210 |
+
# ).then(
|
| 211 |
+
# lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
|
| 212 |
+
# inputs=index_slider,
|
| 213 |
+
# outputs=[prev_btn, index_slider, next_btn, bpe_display],
|
| 214 |
+
# )
|
| 215 |
+
# Gradio 按钮点击后的处理
|
| 216 |
run_btn.click(
|
| 217 |
on_run_clicked,
|
| 218 |
inputs=[model_type, image_input, text_input],
|
| 219 |
+
outputs=[orig_img, heatmap, bpe_display, index_slider, 'state', 'state']
|
| 220 |
).then(
|
| 221 |
+
lambda outputs: (gr.update(visible=True), gr.update(visible=True, maximum=outputs[3], value=0), gr.update(visible=True), gr.update(visible=True), outputs[4], outputs[5]),
|
| 222 |
inputs=index_slider,
|
| 223 |
+
outputs=[prev_btn, index_slider, next_btn, bpe_display, 'state', 'state']
|
| 224 |
)
|
| 225 |
|
| 226 |
prev_btn.click(
|
|
|
|
| 232 |
lambda: (*update_index(1), current_index),
|
| 233 |
outputs=[heatmap, bpe_display, index_slider]
|
| 234 |
)
|
| 235 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
# index_slider.change(
|
| 238 |
# update_slider_index,
|
|
|
|
| 240 |
# outputs=[heatmap, bpe_display]
|
| 241 |
# )
|
| 242 |
index_slider.change(
|
| 243 |
+
update_slider_index,
|
| 244 |
+
inputs=[index_slider, 'state', 'state'],
|
| 245 |
+
outputs=[heatmap, bpe_display, 'state', 'state']
|
| 246 |
+
)
|
| 247 |
|
| 248 |
|
| 249 |
if __name__ == "__main__":
|