veb-101 commited on
Commit
510fd79
·
1 Parent(s): b6db8ed

test run -4

Browse files
Files changed (2) hide show
  1. app.py +24 -32
  2. requirements.txt +2 -1
app.py CHANGED
@@ -44,14 +44,6 @@ def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
44
 
45
 
46
  if __name__ == "__main__":
47
- # Create a mapping of class ID to RGB value.
48
- id2color = {
49
- 0: (0, 0, 0), # background pixel
50
- 1: (0, 0, 255), # Stomach
51
- 2: (0, 255, 0), # Small bowel
52
- 3: (255, 0, 0), # large bowel
53
- }
54
-
55
  class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
56
 
57
  DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@@ -72,29 +64,29 @@ if __name__ == "__main__":
72
  ]
73
  )
74
 
75
- # images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
76
- # examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
77
- # demo = gr.Interface(
78
- # fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
79
- # inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
80
- # outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
81
- # examples=examples,
82
- # cache_examples=False,
83
- # allow_flagging="never",
84
- # title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
85
- # )
86
-
87
- with gr.Blocks(title="Medical Image Segmentation") as demo:
88
- gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
89
- with gr.Row():
90
- img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
91
- img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
92
-
93
- section_btn = gr.Button("Generate Predictions")
94
- section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
95
-
96
- images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
97
- examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
98
- gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
99
 
100
  demo.launch()
 
44
 
45
 
46
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
47
  class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
48
 
49
  DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
64
  ]
65
  )
66
 
67
+ images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
68
+ examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
69
+ demo = gr.Interface(
70
+ fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
71
+ inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
72
+ outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
73
+ examples=examples,
74
+ cache_examples=False,
75
+ allow_flagging="never",
76
+ title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
77
+ )
78
+
79
+ # with gr.Blocks(title="Medical Image Segmentation") as demo:
80
+ # gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
81
+ # with gr.Row():
82
+ # img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
83
+ # img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
84
+
85
+ # section_btn = gr.Button("Generate Predictions")
86
+ # section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
87
+
88
+ # images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
89
+ # examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
90
+ # gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
91
 
92
  demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  --find-links https://download.pytorch.org/whl/torch_stable.html
2
  torch==2.0.0+cpu
3
  torchvision==0.15.0
4
- transformers==4.30.2
 
 
1
  --find-links https://download.pytorch.org/whl/torch_stable.html
2
  torch==2.0.0+cpu
3
  torchvision==0.15.0
4
+ transformers==4.30.2
5
+ gradio