zhemai28 commited on
Commit
9c2a2dc
·
1 Parent(s): 828e042
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -72,6 +72,9 @@ def format_prompt_points(points):
72
  prompt_points.append([point[0], point[1]])
73
  label = 1 if point[2] == 1.0 else 0
74
  point_labels.append(label)
 
 
 
75
  return prompt_points, point_labels, prompt_boxes
76
 
77
  def segment_with_points(
@@ -85,9 +88,9 @@ def segment_with_points(
85
  print(prompt_points, point_labels, prompt_boxes)
86
  # segment
87
  inputs = processor(image,
88
- input_boxes = [prompt_boxes],
89
- input_points=[[prompt_points]],
90
- input_labels=[point_labels],
91
  return_tensors="pt").to(device)
92
  with torch.no_grad():
93
  outputs = model(**inputs)
@@ -140,4 +143,4 @@ with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
140
  clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
141
 
142
  demo.queue()
143
- demo.launch(share=True)
 
72
  prompt_points.append([point[0], point[1]])
73
  label = 1 if point[2] == 1.0 else 0
74
  point_labels.append(label)
75
+ prompt_points = [[prompt_points]] if len(prompt_points) > 0 else None
76
+ point_labels = [point_labels] if len(point_labels) > 0 else None
77
+ prompt_boxes = [prompt_boxes] if len(prompt_boxes) > 0 else None
78
  return prompt_points, point_labels, prompt_boxes
79
 
80
  def segment_with_points(
 
88
  print(prompt_points, point_labels, prompt_boxes)
89
  # segment
90
  inputs = processor(image,
91
+ input_boxes = prompt_boxes,
92
+ input_points=prompt_points,
93
+ input_labels=point_labels,
94
  return_tensors="pt").to(device)
95
  with torch.no_grad():
96
  outputs = model(**inputs)
 
143
  clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
144
 
145
  demo.queue()
146
+ demo.launch()