Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -176,7 +176,7 @@ transform = transforms.Compose([
|
|
| 176 |
transforms.ToTensor(),
|
| 177 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 178 |
])
|
| 179 |
-
weight_dtype = torch.
|
| 180 |
|
| 181 |
# line model
|
| 182 |
line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
|
|
@@ -201,7 +201,7 @@ global MultiResNetModel
|
|
| 201 |
global cur_style
|
| 202 |
|
| 203 |
cur_style = 'line + shadow'
|
| 204 |
-
weight_dtype = torch.
|
| 205 |
|
| 206 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 207 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
|
@@ -313,7 +313,7 @@ print('loaded pipeline')
|
|
| 313 |
|
| 314 |
@spaces.GPU
|
| 315 |
def change_ckpt(style):
|
| 316 |
-
weight_dtype = torch.
|
| 317 |
|
| 318 |
if style == 'line':
|
| 319 |
MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
|
|
|
|
| 176 |
transforms.ToTensor(),
|
| 177 |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 178 |
])
|
| 179 |
+
weight_dtype = torch.float16
|
| 180 |
|
| 181 |
# line model
|
| 182 |
line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
|
|
|
|
| 201 |
global cur_style
|
| 202 |
|
| 203 |
cur_style = 'line + shadow'
|
| 204 |
+
weight_dtype = torch.float16
|
| 205 |
|
| 206 |
block_out_channels = [128, 128, 256, 512, 512]
|
| 207 |
MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
|
|
|
|
| 313 |
|
| 314 |
@spaces.GPU
|
| 315 |
def change_ckpt(style):
|
| 316 |
+
weight_dtype = torch.float16
|
| 317 |
|
| 318 |
if style == 'line':
|
| 319 |
MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
|