Spaces: Running on Zero
Update trellis/pipelines/trellis_image_to_3d.py
Browse files
trellis/pipelines/trellis_image_to_3d.py
CHANGED
@@ -232,17 +232,14 @@ class TrellisImageTo3DPipeline(Pipeline):
|
|
232 |
if scale < 1:
|
233 |
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
234 |
|
235 |
-
#
|
236 |
-
|
237 |
|
238 |
-
#
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
if getattr(self, 'rembg_session', None) is None:
|
244 |
-
self.rembg_session = rembg.new_session('u2net')
|
245 |
-
output = rembg.remove(input, session=self.rembg_session)
|
246 |
|
247 |
# Process the output image
|
248 |
output_np = np.array(output)
|
@@ -341,7 +338,7 @@ class TrellisImageTo3DPipeline(Pipeline):
|
|
341 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
342 |
])
|
343 |
|
344 |
-
input_images = transform_image(image).unsqueeze(0).
|
345 |
|
346 |
with torch.no_grad():
|
347 |
preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
|
@@ -793,11 +790,11 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
|
|
793 |
del new_pipeline.VGGT_model.point_head
|
794 |
new_pipeline.VGGT_model.eval()
|
795 |
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
|
802 |
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
803 |
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
|
|
232 |
if scale < 1:
|
233 |
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
234 |
|
235 |
+
# Get mask using BiRefNet
|
236 |
+
mask = self._get_birefnet_mask(input)
|
237 |
|
238 |
+
# Convert input to RGBA and apply mask
|
239 |
+
input_rgba = input.convert('RGBA')
|
240 |
+
input_array = np.array(input_rgba)
|
241 |
+
input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
|
242 |
+
output = Image.fromarray(input_array)
|
|
|
|
|
|
|
243 |
|
244 |
# Process the output image
|
245 |
output_np = np.array(output)
|
|
|
338 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
339 |
])
|
340 |
|
341 |
+
input_images = transform_image(image).unsqueeze(0).to(self.device)
|
342 |
|
343 |
with torch.no_grad():
|
344 |
preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
|
|
|
790 |
del new_pipeline.VGGT_model.point_head
|
791 |
new_pipeline.VGGT_model.eval()
|
792 |
|
793 |
+
new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
|
794 |
+
'ZhengPeng7/BiRefNet',
|
795 |
+
trust_remote_code=True
|
796 |
+
).to(new_pipeline.device)
|
797 |
+
new_pipeline.birefnet_model.eval()
|
798 |
|
799 |
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
800 |
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|