Spaces:
Runtime error
Runtime error
liuyizhang
committed on
Commit
·
2a71ebd
1
Parent(s):
5c28041
add time cost by step (ms)
Browse files
- app.py +40 -12
- kosmos_utils.py +1 -1
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -519,24 +519,42 @@ def relate_anything(input_image, k):
|
|
| 519 |
mask_source_draw = "draw a mask on input image"
|
| 520 |
mask_source_segment = "type what to detect below"
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 523 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
if (task_type == 'Kosmos-2'):
|
| 525 |
global kosmos_model, kosmos_processor
|
| 526 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
|
| 527 |
-
|
|
|
|
| 528 |
|
| 529 |
if (task_type == 'relate anything'):
|
| 530 |
output_images = relate_anything(input_image['image'], num_relation)
|
| 531 |
-
|
|
|
|
| 532 |
|
| 533 |
text_prompt = text_prompt.strip()
|
| 534 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
| 535 |
if text_prompt == '':
|
| 536 |
-
return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), None, None, None
|
| 537 |
|
| 538 |
if input_image is None:
|
| 539 |
-
return [], gr.Gallery.update(label='Please upload a image!ππππ'), None, None, None
|
| 540 |
|
| 541 |
file_temp = int(time.time())
|
| 542 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
|
@@ -552,10 +570,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 552 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
| 553 |
input_img = input_image['image']
|
| 554 |
output_images.append(input_image['image'])
|
|
|
|
| 555 |
else:
|
| 556 |
image_pil, image = load_image(input_image.convert("RGB"))
|
| 557 |
input_img = input_image
|
| 558 |
output_images.append(input_image)
|
|
|
|
| 559 |
|
| 560 |
size = image_pil.size
|
| 561 |
|
|
@@ -576,7 +596,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 576 |
)
|
| 577 |
if boxes_filt.size(0) == 0:
|
| 578 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
| 579 |
-
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ'), None, None, None
|
| 580 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
| 581 |
|
| 582 |
pred_dict = {
|
|
@@ -587,6 +607,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 587 |
|
| 588 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
| 589 |
output_images.append(image_with_box)
|
|
|
|
| 590 |
|
| 591 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
| 592 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
|
@@ -622,12 +643,13 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 622 |
plt.savefig(image_path, bbox_inches="tight")
|
| 623 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
| 624 |
os.remove(image_path)
|
| 625 |
-
output_images.append(segment_image_result)
|
|
|
|
| 626 |
|
| 627 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
| 628 |
if task_type == 'detection' or task_type == 'segment':
|
| 629 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 630 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
| 631 |
elif task_type == 'inpainting' or task_type == 'remove':
|
| 632 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
| 633 |
task_type = 'remove'
|
|
@@ -644,6 +666,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 644 |
mask = masks[0][0].cpu().numpy()
|
| 645 |
mask_pil = Image.fromarray(mask)
|
| 646 |
output_images.append(mask_pil.convert("RGB"))
|
|
|
|
| 647 |
|
| 648 |
if task_type == 'inpainting':
|
| 649 |
# inpainting pipeline
|
|
@@ -682,21 +705,24 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 682 |
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
| 683 |
mask_imgs.append(mask_pil_exp)
|
| 684 |
mask_pil = mix_masks(mask_imgs)
|
| 685 |
-
output_images.append(mask_pil.convert("RGB"))
|
|
|
|
| 686 |
|
| 687 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
| 688 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
| 689 |
# output_images.append(image_inpainting)
|
|
|
|
| 690 |
|
| 691 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
| 692 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
| 693 |
output_images.append(image_inpainting)
|
|
|
|
| 694 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 695 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
| 696 |
else:
|
| 697 |
logger.info(f"task_type:{task_type} error!")
|
| 698 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
| 699 |
-
return output_images, gr.Gallery.update(label='result images'), None, None, None
|
| 700 |
|
| 701 |
def change_radio_display(task_type, mask_source_radio):
|
| 702 |
text_prompt_visible = True
|
|
@@ -828,7 +854,9 @@ if __name__ == "__main__":
|
|
| 828 |
|
| 829 |
with gr.Column():
|
| 830 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|
| 831 |
-
).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
|
|
|
|
|
|
| 832 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
| 833 |
kosmos_text_output = gr.HighlightedText(
|
| 834 |
label="Generated Description",
|
|
@@ -860,7 +888,7 @@ if __name__ == "__main__":
|
|
| 860 |
run_button.click(fn=run_anything_task, inputs=[
|
| 861 |
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 862 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
| 863 |
-
outputs=[image_gallery, image_gallery, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
| 864 |
|
| 865 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
| 866 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
|
|
|
| 519 |
mask_source_draw = "draw a mask on input image"
|
| 520 |
mask_source_segment = "type what to detect below"
|
| 521 |
|
| 522 |
+
def get_time_cost(run_task_time, time_cost_str):
    """Record one timing checkpoint and extend the per-step time-cost trace.

    The trace is a string of the form ``start-->12-->340-->...`` where each
    number is the wall-clock time in milliseconds elapsed since the previous
    checkpoint.

    Args:
        run_task_time: Timestamp of the previous checkpoint, in integer
            milliseconds as produced by this function, or ``0`` when no
            checkpoint has been taken yet.
        time_cost_str: The trace accumulated so far (``''`` for none).

    Returns:
        A ``(run_task_time, time_cost_str)`` tuple holding the timestamp of
        this checkpoint (integer milliseconds) and the updated trace. Pass
        both values back in on the next call.
    """
    now_time = int(time.time() * 1000)
    if run_task_time == 0:
        # First checkpoint: there is nothing to measure against, so the
        # trace is (re)started and any previous content is discarded.
        time_cost_str = 'start'
    else:
        # Separate entries with an arrow, but never lead with one.
        if time_cost_str != '':
            time_cost_str += '-->'
        time_cost_str += f'{now_time - run_task_time}'
    return now_time, time_cost_str
|
| 532 |
+
|
| 533 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 534 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
| 535 |
+
|
| 536 |
+
run_task_time = 0
|
| 537 |
+
time_cost_str = ''
|
| 538 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 539 |
+
|
| 540 |
if (task_type == 'Kosmos-2'):
|
| 541 |
global kosmos_model, kosmos_processor
|
| 542 |
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
|
| 543 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 544 |
+
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
| 545 |
|
| 546 |
if (task_type == 'relate anything'):
|
| 547 |
output_images = relate_anything(input_image['image'], num_relation)
|
| 548 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 549 |
+
return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 550 |
|
| 551 |
text_prompt = text_prompt.strip()
|
| 552 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
| 553 |
if text_prompt == '':
|
| 554 |
+
return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 555 |
|
| 556 |
if input_image is None:
|
| 557 |
+
return [], gr.Gallery.update(label='Please upload a image!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 558 |
|
| 559 |
file_temp = int(time.time())
|
| 560 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
|
|
|
|
| 570 |
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
| 571 |
input_img = input_image['image']
|
| 572 |
output_images.append(input_image['image'])
|
| 573 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 574 |
else:
|
| 575 |
image_pil, image = load_image(input_image.convert("RGB"))
|
| 576 |
input_img = input_image
|
| 577 |
output_images.append(input_image)
|
| 578 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 579 |
|
| 580 |
size = image_pil.size
|
| 581 |
|
|
|
|
| 596 |
)
|
| 597 |
if boxes_filt.size(0) == 0:
|
| 598 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
| 599 |
+
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 600 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
| 601 |
|
| 602 |
pred_dict = {
|
|
|
|
| 607 |
|
| 608 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
| 609 |
output_images.append(image_with_box)
|
| 610 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 611 |
|
| 612 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
| 613 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
|
|
|
| 643 |
plt.savefig(image_path, bbox_inches="tight")
|
| 644 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
| 645 |
os.remove(image_path)
|
| 646 |
+
output_images.append(segment_image_result)
|
| 647 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 648 |
|
| 649 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
| 650 |
if task_type == 'detection' or task_type == 'segment':
|
| 651 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 652 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 653 |
elif task_type == 'inpainting' or task_type == 'remove':
|
| 654 |
if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
|
| 655 |
task_type = 'remove'
|
|
|
|
| 666 |
mask = masks[0][0].cpu().numpy()
|
| 667 |
mask_pil = Image.fromarray(mask)
|
| 668 |
output_images.append(mask_pil.convert("RGB"))
|
| 669 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 670 |
|
| 671 |
if task_type == 'inpainting':
|
| 672 |
# inpainting pipeline
|
|
|
|
| 705 |
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
| 706 |
mask_imgs.append(mask_pil_exp)
|
| 707 |
mask_pil = mix_masks(mask_imgs)
|
| 708 |
+
output_images.append(mask_pil.convert("RGB"))
|
| 709 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 710 |
|
| 711 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
| 712 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
| 713 |
# output_images.append(image_inpainting)
|
| 714 |
+
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 715 |
|
| 716 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
| 717 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
| 718 |
output_images.append(image_inpainting)
|
| 719 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 720 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 721 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 722 |
else:
|
| 723 |
logger.info(f"task_type:{task_type} error!")
|
| 724 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
| 725 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
|
| 726 |
|
| 727 |
def change_radio_display(task_type, mask_source_radio):
|
| 728 |
text_prompt_visible = True
|
|
|
|
| 854 |
|
| 855 |
with gr.Column():
|
| 856 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|
| 857 |
+
).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
| 858 |
+
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
| 859 |
+
|
| 860 |
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
| 861 |
kosmos_text_output = gr.HighlightedText(
|
| 862 |
label="Generated Description",
|
|
|
|
| 888 |
run_button.click(fn=run_anything_task, inputs=[
|
| 889 |
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 890 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
|
| 891 |
+
outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
|
| 892 |
|
| 893 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
|
| 894 |
outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
kosmos_utils.py
CHANGED
|
@@ -230,4 +230,4 @@ def kosmos_generate_predictions(image_input, text_input, kosmos_model, kosmos_pr
|
|
| 230 |
if end < len(processed_text):
|
| 231 |
colored_text.append((processed_text[end:len(processed_text)], None))
|
| 232 |
|
| 233 |
-
return annotated_image, colored_text, str(filtered_entities)
|
|
|
|
| 230 |
if end < len(processed_text):
|
| 231 |
colored_text.append((processed_text[end:len(processed_text)], None))
|
| 232 |
|
| 233 |
+
return annotated_image, colored_text, str(filtered_entities)
|
requirements.txt
CHANGED
|
@@ -17,7 +17,7 @@ termcolor
|
|
| 17 |
timm
|
| 18 |
torch
|
| 19 |
torchvision
|
| 20 |
-
transformers
|
| 21 |
yapf
|
| 22 |
numba
|
| 23 |
scipy
|
|
|
|
| 17 |
timm
|
| 18 |
torch
|
| 19 |
torchvision
|
| 20 |
+
transformers==4.27.4
|
| 21 |
yapf
|
| 22 |
numba
|
| 23 |
scipy
|