Stable-X commited on
Commit
654565f
·
verified ·
1 Parent(s): 2ada28d

Update trellis/pipelines/trellis_image_to_3d.py

Browse files
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -232,17 +232,14 @@ class TrellisImageTo3DPipeline(Pipeline):
232
  if scale < 1:
233
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
234
 
235
- # # Get mask using BiRefNet
236
- # mask = self._get_birefnet_mask(input)
237
 
238
- # # Convert input to RGBA and apply mask
239
- # input_rgba = input.convert('RGBA')
240
- # input_array = np.array(input_rgba)
241
- # input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
242
- # output = Image.fromarray(input_array)
243
- if getattr(self, 'rembg_session', None) is None:
244
- self.rembg_session = rembg.new_session('u2net')
245
- output = rembg.remove(input, session=self.rembg_session)
246
 
247
  # Process the output image
248
  output_np = np.array(output)
@@ -341,7 +338,7 @@ class TrellisImageTo3DPipeline(Pipeline):
341
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
342
  ])
343
 
344
- input_images = transform_image(image).unsqueeze(0).cpu()
345
 
346
  with torch.no_grad():
347
  preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
@@ -793,11 +790,11 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
793
  del new_pipeline.VGGT_model.point_head
794
  new_pipeline.VGGT_model.eval()
795
 
796
- # new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
797
- # 'ZhengPeng7/BiRefNet',
798
- # trust_remote_code=True
799
- # ).cpu()
800
- # new_pipeline.birefnet_model.eval()
801
 
802
  new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
803
  new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
 
232
  if scale < 1:
233
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
234
 
235
+ # Get mask using BiRefNet
236
+ mask = self._get_birefnet_mask(input)
237
 
238
+ # Convert input to RGBA and apply mask
239
+ input_rgba = input.convert('RGBA')
240
+ input_array = np.array(input_rgba)
241
+ input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
242
+ output = Image.fromarray(input_array)
 
 
 
243
 
244
  # Process the output image
245
  output_np = np.array(output)
 
338
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
339
  ])
340
 
341
+ input_images = transform_image(image).unsqueeze(0).to(self.device)
342
 
343
  with torch.no_grad():
344
  preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
 
790
  del new_pipeline.VGGT_model.point_head
791
  new_pipeline.VGGT_model.eval()
792
 
793
+ new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
794
+ 'ZhengPeng7/BiRefNet',
795
+ trust_remote_code=True
796
+ ).to(new_pipeline.device)
797
+ new_pipeline.birefnet_model.eval()
798
 
799
  new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
800
  new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']