Miquel Farré commited on
Commit
cdf8663
·
1 Parent(s): d610685
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  from torchvision import transforms
7
  import os
8
  import uuid
 
9
 
10
  torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
@@ -25,17 +26,25 @@ transform_image = transforms.Compose(
25
  def fn(image):
26
  im = load_img(image, output_type="pil")
27
  im = im.convert("RGB")
 
28
  processed_image = process(im)
29
 
30
  # Generate a unique filename for the processed image
31
  unique_id = str(uuid.uuid4())[:8]
32
  output_path = f"output_{unique_id}.jpg"
33
 
34
- # Convert RGBA to RGB before saving as JPEG
35
- rgb_image = processed_image.convert("RGB")
36
- rgb_image.save(output_path, format="JPEG")
37
-
38
- return processed_image, output_path
 
 
 
 
 
 
 
39
 
40
  @spaces.GPU
41
  def process(image):
@@ -57,15 +66,22 @@ def process(image):
57
  def process_file(f):
58
  im = load_img(f, output_type="pil")
59
  im = im.convert("RGB")
 
60
  transparent = process(im)
61
 
62
  # Save as JPEG instead of PNG
63
  unique_id = str(uuid.uuid4())[:8]
64
  output_path = f"output_{unique_id}.jpg"
65
 
66
- # Convert RGBA to RGB before saving as JPEG
67
- rgb_image = transparent.convert("RGB")
68
- rgb_image.save(output_path, format="JPEG")
 
 
 
 
 
 
69
 
70
  return output_path
71
 
 
6
  from torchvision import transforms
7
  import os
8
  import uuid
9
+ from PIL import Image
10
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
 
26
  def fn(image):
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
+ # Get the segmented image (RGBA)
30
  processed_image = process(im)
31
 
32
  # Generate a unique filename for the processed image
33
  unique_id = str(uuid.uuid4())[:8]
34
  output_path = f"output_{unique_id}.jpg"
35
 
36
+ # Create a white background and properly composite with the RGBA image
37
+ white_bg = Image.new("RGB", processed_image.size, (255, 255, 255))
38
+ if processed_image.mode == 'RGBA':
39
+ # Use the alpha channel as a mask for compositing
40
+ white_bg.paste(processed_image, mask=processed_image.split()[3]) # The 4th channel is alpha
41
+ white_bg.save(output_path, format="JPEG")
42
+ # Return the composited image for display to match what's being downloaded
43
+ return white_bg, output_path
44
+ else:
45
+ rgb_image = processed_image.convert("RGB")
46
+ rgb_image.save(output_path, format="JPEG")
47
+ return rgb_image, output_path
48
 
49
  @spaces.GPU
50
  def process(image):
 
66
  def process_file(f):
67
  im = load_img(f, output_type="pil")
68
  im = im.convert("RGB")
69
+ # Get the segmented image (RGBA)
70
  transparent = process(im)
71
 
72
  # Save as JPEG instead of PNG
73
  unique_id = str(uuid.uuid4())[:8]
74
  output_path = f"output_{unique_id}.jpg"
75
 
76
+ # Create a white background and properly composite with the RGBA image
77
+ white_bg = Image.new("RGB", transparent.size, (255, 255, 255))
78
+ if transparent.mode == 'RGBA':
79
+ # Use the alpha channel as a mask for compositing
80
+ white_bg.paste(transparent, mask=transparent.split()[3]) # The 4th channel is alpha
81
+ white_bg.save(output_path, format="JPEG")
82
+ else:
83
+ rgb_image = transparent.convert("RGB")
84
+ rgb_image.save(output_path, format="JPEG")
85
 
86
  return output_path
87