Svane20 committed on
Commit
f28556a
·
1 Parent(s): 16f064e

Updated model to use PyTorch instead of ONNX

Browse files
Files changed (5) hide show
  1. app.py +10 -165
  2. model.py → models.py +0 -0
  3. models/.gitkeep +0 -0
  4. pipeline.py +80 -0
  5. replacements.py +59 -0
app.py CHANGED
@@ -1,183 +1,28 @@
1
  import gradio as gr
2
- import torch
3
- from torchvision.transforms import Compose, Resize, ToTensor, Normalize
4
- import pymatting
5
  import numpy as np
6
  from PIL import Image
7
- from typing import Tuple
8
- import random
9
- from pathlib import Path
10
 
11
- from model import SwinMattingModel
 
12
 
 
13
 
14
- def _load_checkpoint(model, checkpoint_path):
15
- # Load the checkpoint
16
- checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
17
-
18
- # Check if there are any errors when loading the state dictionary
19
- missing_keys, unexpected_keys = model.load_state_dict(checkpoint)
20
- if missing_keys:
21
- print(missing_keys)
22
- raise RuntimeError("Missing keys in checkpoint.")
23
-
24
- if unexpected_keys:
25
- print(unexpected_keys)
26
- raise RuntimeError("Unexpected keys in checkpoint.")
27
-
28
-
29
- def _load_model(checkpoint, device):
30
- model = SwinMattingModel({
31
- "encoder": {
32
- "model_name": "microsoft/swin-small-patch4-window7-224"
33
- },
34
- "decoder": {
35
- "use_attn": True,
36
- "refine_channels": 16
37
- }
38
- })
39
- _load_checkpoint(model, checkpoint)
40
-
41
- model.to(device)
42
- model.eval()
43
-
44
- return model
45
-
46
-
47
- transforms = Compose(
48
- [
49
- Resize(size=(512, 512)),
50
- ToTensor(),
51
- Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
52
- ],
53
- )
54
-
55
- share_repo = False
56
- checkpoint_path = "swin_small_patch4_window7_224_512_v1_latest.pt"
57
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
- model = _load_model(checkpoint_path, device)
59
-
60
- print(f"Using device: {device}")
61
- if device.type == "cuda":
62
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
63
-
64
-
65
- def _get_foreground_estimation(image, alpha):
66
- """
67
- Estimate the foreground using the image and the predicted alpha mask.
68
-
69
- Args:
70
- image (np.ndarray): The input image.
71
- alpha (np.ndarray): The predicted alpha mask.
72
-
73
- Returns:
74
- np.ndarray: The estimated foreground.
75
- """
76
- # Normalize the image to [0, 1] range
77
- normalized_image = np.array(image) / 255.0
78
-
79
- # Invert the alpha mask since the pymatting library expects the sky to be the background
80
- inverted_alpha = 1 - alpha
81
-
82
- return pymatting.estimate_foreground_ml(image=normalized_image, alpha=inverted_alpha)
83
-
84
-
85
- def _sky_replacement(foreground, alpha_mask):
86
- """
87
- Perform sky replacement using the estimated foreground and predicted alpha mask.
88
-
89
- Args:
90
- foreground (np.ndarray): The estimated foreground.
91
- alpha_mask (np.ndarray): The predicted alpha mask.
92
-
93
- Returns:
94
- np.ndarray: The sky-replaced image.
95
- """
96
- new_sky_path = Path(__file__).parent / "assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg"
97
- new_sky_img = Image.open(new_sky_path).convert("RGB")
98
-
99
- # Get the target size from the foreground image
100
- h, w = foreground.shape[:2]
101
-
102
- # Check the size of the sky image
103
- sky_width, sky_height = new_sky_img.size
104
-
105
- # If the sky image is smaller than the target size
106
- if sky_width < w or sky_height < h:
107
- scale = max(w / sky_width, h / sky_height)
108
- new_size = (int(sky_width * scale), int(sky_height * scale))
109
- new_sky_img = new_sky_img.resize(new_size, resample=Image.Resampling.LANCZOS)
110
- sky_width, sky_height = new_sky_img.size
111
-
112
- # Determine the maximum possible top-left coordinates for the crop
113
- max_left = sky_width - w
114
- max_top = sky_height - h
115
-
116
- # Choose random offsets for left and top within the valid range
117
- left = random.randint(a=0, b=max_left) if max_left > 0 else 0
118
- top = random.randint(a=0, b=max_top) if max_top > 0 else 0
119
-
120
- # Crop the sky image to the target size using the random offsets
121
- new_sky_img = new_sky_img.crop((left, top, left + w, top + h))
122
-
123
- new_sky = np.asarray(new_sky_img).astype(np.float32) / 255.0
124
- if foreground.dtype != np.float32:
125
- foreground = foreground.astype(np.float32) / 255.0
126
- if foreground.shape[2] == 4:
127
- foreground = foreground[:, :, :3]
128
-
129
- # Ensure that the alpha mask values are within the range [0, 1]
130
- alpha_mask = np.clip(alpha_mask, a_min=0, a_max=1)
131
-
132
- # Blend the foreground with the new sky using the alpha mask
133
- return (1 - alpha_mask[:, :, None]) * foreground + alpha_mask[:, :, None] * new_sky
134
-
135
-
136
- def _inference(image):
137
- """
138
- Perform inference on the input image using the ONNX model.
139
-
140
- Args:
141
- image (Image): The input image.
142
-
143
- Returns:
144
- np.ndarray: The predicted alpha mask.
145
- """
146
- with torch.inference_mode():
147
- output = model(image)
148
-
149
- # Ensure the output is in valid range [0, 1]
150
- output = output.detach().cpu().numpy()
151
- output = np.clip(output, a_min=0, a_max=1)
152
-
153
- return np.squeeze(output, axis=0).squeeze()
154
 
155
 
156
  def predict(image):
157
- """
158
- Perform sky replacement on the input image.
159
-
160
- Args:
161
- image (Image): The input image.
162
-
163
- Returns:
164
- Tuple[Image, Image]: The predicted alpha mask and the sky-replaced image.
165
- """
166
- image_tensor = transforms(image).unsqueeze(0).to(device)
167
- predicted_alpha = _inference(image_tensor)
168
-
169
- # Downscale the input image to match predicted_alpha
170
  h, w = predicted_alpha.shape
171
- downscaled_image = image.resize((w, h), Image.Resampling.LANCZOS)
172
 
173
  # Estimate foreground and run sky_replacement
174
- foreground = _get_foreground_estimation(downscaled_image, predicted_alpha)
175
- replaced_sky = _sky_replacement(foreground, predicted_alpha)
176
 
177
  # Resize the predicted alpha and replaced sky to original dimensions
178
  predicted_alpha_pil = Image.fromarray((predicted_alpha * 255).astype(np.uint8), mode='L')
179
  predicted_alpha_pil = predicted_alpha_pil.resize((h, w), Image.Resampling.LANCZOS)
180
- replaced_sky_pil = Image.fromarray((replaced_sky * 255).astype(np.uint8)) # mode='RGB' typically
181
  replaced_sky_pil = replaced_sky_pil.resize((h, w), Image.Resampling.LANCZOS)
182
 
183
  return predicted_alpha_pil, replaced_sky_pil
@@ -291,4 +136,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
291
  run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])
292
 
293
  # Launch the interface
294
- demo.launch(share=share_repo, ssr_mode=False)
 
1
  import gradio as gr
 
 
 
2
  import numpy as np
3
  from PIL import Image
 
 
 
4
 
5
+ from pipeline import Pipeline
6
+ from replacements import get_foreground_estimation, sky_replacement
7
 
8
+ SHARE_REPO = False
9
 
10
+ pipeline = Pipeline(model_name="swin_small_patch4_window7_224")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def predict(image):
14
+ # Run inference to get the predicted alpha mask
15
+ predicted_alpha = pipeline.inference(image)
 
 
 
 
 
 
 
 
 
 
 
16
  h, w = predicted_alpha.shape
 
17
 
18
  # Estimate foreground and run sky_replacement
19
+ foreground = get_foreground_estimation(image, predicted_alpha)
20
+ replaced_sky = sky_replacement(foreground, predicted_alpha)
21
 
22
  # Resize the predicted alpha and replaced sky to original dimensions
23
  predicted_alpha_pil = Image.fromarray((predicted_alpha * 255).astype(np.uint8), mode='L')
24
  predicted_alpha_pil = predicted_alpha_pil.resize((h, w), Image.Resampling.LANCZOS)
25
+ replaced_sky_pil = Image.fromarray((replaced_sky * 255).astype(np.uint8))
26
  replaced_sky_pil = replaced_sky_pil.resize((h, w), Image.Resampling.LANCZOS)
27
 
28
  return predicted_alpha_pil, replaced_sky_pil
 
136
  run_button.click(fn=predict, inputs=input_image, outputs=[output_mask, output_sky])
137
 
138
  # Launch the interface
139
+ demo.launch(share=SHARE_REPO, ssr_mode=False)
model.py → models.py RENAMED
File without changes
models/.gitkeep ADDED
File without changes
pipeline.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
3
+ import numpy as np
4
+
5
+ from models import SwinMattingModel
6
+
7
+
8
class Pipeline:
    """
    Inference pipeline that loads a matting model once and predicts alpha masks.

    When no CUDA device is available a TorchScript export is loaded with ``torch.jit.load``;
    otherwise a ``SwinMattingModel`` is built and its weights are read from a state-dict checkpoint.
    """

    def __init__(self, model_name: str):
        """
        Build the preprocessing transforms, pick the device and load the model.

        Args:
            model_name (str): Base name of the checkpoint file inside the ``models/`` directory.
        """
        # Fixed 512x512 input, normalized with the mean/std commonly used for ImageNet-pretrained backbones
        self.transforms = Compose(
            [
                Resize(size=(512, 512)),
                ToTensor(),
                Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The TorchScript export is only used on CPU
        self.is_torch_script = self.device.type == 'cpu'
        self.model = self._load_model(model_name)

        self._log_device_info()

    def inference(self, image):
        """
        Predict the alpha mask for a single image.

        Args:
            image (Image): The input image; the transforms resize it to 512x512.

        Returns:
            np.ndarray: The model output clipped to [0, 1], with the batch axis and any
            remaining singleton axes squeezed away.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            # The model is loaded by __init__; there is no separate public loader to call
            raise RuntimeError("Model is not loaded. Re-create the Pipeline to load it.")

        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)

        # Ensure the output is in the valid range [0, 1]
        output = output.detach().cpu().numpy()
        output = np.clip(output, a_min=0, a_max=1)

        return np.squeeze(output, axis=0).squeeze()

    def _load_pytorch_model(self, checkpoint):
        """Build the SwinMattingModel, load its weights from `checkpoint` and put it in eval mode on the device."""
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model, checkpoint)

        model.to(self.device)
        model.eval()

        return model

    def _load_model(self, model_name):
        """Load the TorchScript export or the regular PyTorch model, depending on `is_torch_script`."""
        checkpoint_path = self._get_model_checkpoint(model_name)

        if not self.is_torch_script:
            # _load_pytorch_model already moves the model to the device and switches to eval mode
            return self._load_pytorch_model(checkpoint_path)

        model = torch.jit.load(checkpoint_path, map_location=self.device)
        model.to(self.device)
        model.eval()

        return model

    def _get_model_checkpoint(self, model_name):
        """Return the relative checkpoint path that matches the selected model format."""
        return f"models/{model_name}_torch_script.pt" if self.is_torch_script else f"models/{model_name}_minimal.pt"

    def _load_checkpoint(self, model, checkpoint_path):
        """
        Load a state dict into `model` and fail on any key mismatch.

        Raises:
            RuntimeError: If the checkpoint has missing or unexpected keys.
        """
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

        # strict=False makes load_state_dict report key mismatches instead of raising by itself,
        # so the explicit checks below are actually reached and the offending keys get printed
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
        if missing_keys:
            print(missing_keys)
            raise RuntimeError("Missing keys in checkpoint.")
        if unexpected_keys:
            print(unexpected_keys)
            raise RuntimeError("Unexpected keys in checkpoint.")

    def _log_device_info(self):
        """Print the name of the CUDA device when one is in use."""
        if self.device.type == 'cuda':
            print(f"Hardware: {torch.cuda.get_device_name(torch.cuda.current_device())}")
replacements.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pymatting
2
+ import numpy as np
3
+ from PIL import Image
4
+ import random
5
+ from pathlib import Path
6
+
7
def get_foreground_estimation(image, alpha):
    """
    Estimate the foreground using the image and the predicted alpha mask.

    Args:
        image (Image): The input image; it is resized to the resolution of `alpha`.
        alpha (np.ndarray): The predicted 2-D alpha mask.

    Returns:
        The foreground estimate returned by `pymatting.estimate_foreground_ml`.
    """
    mask_height, mask_width = alpha.shape

    # PIL takes (width, height), so the mask shape is swapped when downscaling the image
    resized = image.resize((mask_width, mask_height), Image.Resampling.LANCZOS)

    # Normalize the pixel values to the [0, 1] range
    pixels = np.array(resized) / 255.0

    # pymatting expects the sky to be the background, hence the inverted mask
    return pymatting.estimate_foreground_ml(image=pixels, alpha=1 - alpha)
19
+
20
+
21
def sky_replacement(foreground, alpha_mask):
    """
    Perform sky replacement using the estimated foreground and predicted alpha mask.

    The bundled sky image is upscaled if it is smaller than the foreground, then a randomly
    positioned crop of the foreground's size is blended in wherever the alpha mask is set.

    Args:
        foreground (np.ndarray): The estimated foreground (H x W x 3 or H x W x 4). Integer
            arrays are treated as 8-bit and scaled to [0, 1]; float arrays are used as-is.
        alpha_mask (np.ndarray): The predicted alpha mask (H x W); values are clipped to [0, 1].

    Returns:
        np.ndarray: The sky-replaced image as an H x W x 3 float array.
    """
    new_sky_path = Path(__file__).parent / "assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg"
    # convert() returns an independent image, so the file handle can be released right away
    with Image.open(new_sky_path) as sky_file:
        new_sky_img = sky_file.convert("RGB")

    # Get the target size from the foreground image
    h, w = foreground.shape[:2]

    # Check the size of the sky image
    sky_width, sky_height = new_sky_img.size

    # If the sky image is smaller than the target size
    if sky_width < w or sky_height < h:
        scale = max(w / sky_width, h / sky_height)
        # Round up and never go below the target: truncating the float product could leave
        # the sky one pixel short, which would make the crop below run out of bounds
        new_size = (
            max(w, int(np.ceil(sky_width * scale))),
            max(h, int(np.ceil(sky_height * scale))),
        )
        new_sky_img = new_sky_img.resize(new_size, resample=Image.Resampling.LANCZOS)
        sky_width, sky_height = new_sky_img.size

    # Determine the maximum possible top-left coordinates for the crop
    max_left = sky_width - w
    max_top = sky_height - h

    # Choose random offsets for left and top within the valid range
    left = random.randint(a=0, b=max_left) if max_left > 0 else 0
    top = random.randint(a=0, b=max_top) if max_top > 0 else 0

    # Crop the sky image to the target size using the random offsets
    new_sky_img = new_sky_img.crop((left, top, left + w, top + h))

    new_sky = np.asarray(new_sky_img).astype(np.float32) / 255.0
    if np.issubdtype(foreground.dtype, np.integer):
        # Integer pixel data is assumed to be in the 0-255 range
        foreground = foreground.astype(np.float32) / 255.0
    elif foreground.dtype != np.float32:
        # Other float types are only cast; dividing by 255 again would blacken a [0, 1] foreground
        foreground = foreground.astype(np.float32)
    if foreground.shape[2] == 4:
        # Drop the extra fourth channel so the shapes match the RGB sky
        foreground = foreground[:, :, :3]

    # Ensure that the alpha mask values are within the range [0, 1]
    alpha_mask = np.clip(alpha_mask, a_min=0, a_max=1)

    # Blend the foreground with the new sky using the alpha mask
    return (1 - alpha_mask[:, :, None]) * foreground + alpha_mask[:, :, None] * new_sky