bartduis commited on
Commit
993f92c
·
verified ·
1 Parent(s): 8a8835f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -30
app.py CHANGED
@@ -31,6 +31,12 @@ moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
31
  dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
32
  dino_model.eval()
33
  dino_model.to(device)
 
 
 
 
 
 
34
 
35
 
36
 
@@ -106,13 +112,6 @@ def prep_for_rayst3r(img,depth_dict,mask):
106
  @GPU(duration = 180)
107
  def rayst3r_to_glb(img,depth_dict,mask,max_total_points=10e6,rotated=False):
108
  prep_for_rayst3r(img,depth_dict,mask)
109
- print('Doneneee')
110
-
111
- print("Loading RaySt3R model")
112
- rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
113
- rayst3r_model = EvalWrapper(rayst3r_checkpoint,device='cpu')
114
- rayst3r_model = rayst3r_model.to(device)
115
- print("Loaded rayst3r_model")
116
 
117
  rayst3r_points = eval_scene(rayst3r_model,os.path.join(outdir, "input"),do_filter_all_masks=True,dino_model=dino_model, device = device).cpu()
118
 
@@ -202,33 +201,10 @@ def process_image(input_img):
202
  shutil.rmtree(outdir)
203
  os.makedirs(outdir)
204
  input_glb = input_to_glb(outdir,input_img,depth_dict,mask,rotated=rotated)
205
- print('Input done')
206
- print('calling Ray')
207
  inference_glb = rayst3r_to_glb(input_img,depth_dict,mask,rotated=rotated)
208
  # print(input_glb)
209
  return input_glb, inference_glb
210
 
211
- # def process_image(input_img):
212
- # # resize the input image
213
- # rotated = False
214
- # #if input_img.shape[0] > input_img.shape[1]:
215
- # #input_img = cv2.rotate(input_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
216
- # #rotated = True
217
- # input_img = cv2.resize(input_img, (640, 480))
218
- # # mask, rgb = mask_rembg(input_img)
219
- # # depth_dict = depth_moge(input_img)
220
-
221
- # # if os.path.exists(outdir):
222
- # # shutil.rmtree(outdir)
223
- # # os.makedirs(outdir)
224
-
225
- # # input_glb = input_to_glb(outdir,input_img,depth_dict,mask,rotated=rotated)
226
-
227
- # # # visualize the input points in 3D in gradio
228
- # # inference_glb = rayst3r_to_glb(input_img,depth_dict,mask,rotated=rotated)
229
-
230
- # return input_img, input_img
231
-
232
  demo = gr.Interface(
233
  process_image,
234
  gr.Image(),
 
31
  dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
32
  dino_model.eval()
33
  dino_model.to(device)
34
+
35
+ print("Loading RaySt3R model")
36
+ rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
37
+ rayst3r_model = EvalWrapper(rayst3r_checkpoint,device='cpu')
38
+ rayst3r_model = rayst3r_model.to(device)
39
+ print("Loaded rayst3r_model")
40
 
41
 
42
 
 
112
  @GPU(duration = 180)
113
  def rayst3r_to_glb(img,depth_dict,mask,max_total_points=10e6,rotated=False):
114
  prep_for_rayst3r(img,depth_dict,mask)
 
 
 
 
 
 
 
115
 
116
  rayst3r_points = eval_scene(rayst3r_model,os.path.join(outdir, "input"),do_filter_all_masks=True,dino_model=dino_model, device = device).cpu()
117
 
 
201
  shutil.rmtree(outdir)
202
  os.makedirs(outdir)
203
  input_glb = input_to_glb(outdir,input_img,depth_dict,mask,rotated=rotated)
 
 
204
  inference_glb = rayst3r_to_glb(input_img,depth_dict,mask,rotated=rotated)
205
  # print(input_glb)
206
  return input_glb, inference_glb
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  demo = gr.Interface(
209
  process_image,
210
  gr.Image(),