Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
def device(gpu_id=0):
    """Pick the torch device to run on.

    Args:
        gpu_id: Index used to build the ``cuda:<gpu_id>`` device name when
            CUDA is available. Ignored otherwise.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` if ``torch.cuda.is_available()``
        is true, else ``torch.device("cpu")``.
    """
    name = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    return torch.device(name)
def load_matching_state_dict(model: nn.Module, state_dict):
    """Load into ``model`` only the entries of ``state_dict`` it has keys for.

    Entries whose key does not exist in ``model.state_dict()`` are dropped.
    Keys the model has but ``state_dict`` lacks are left at their current
    values instead of raising.

    Args:
        model: Module to load the values into (modified in place).
        state_dict: Mapping from parameter/buffer name to value.

    Returns:
        The object returned by ``model.load_state_dict``; its
        ``missing_keys`` lists the model keys that were not supplied.

    Raises:
        RuntimeError: Propagated from ``load_state_dict``, e.g. when a
            matching key carries a value of incompatible shape.
    """
    model_dict = model.state_dict()
    filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    # strict=False: with the default strict=True a checkpoint that lacks any
    # model key raises, which defeats the purpose of filtering to the
    # matching subset.
    return model.load_state_dict(filtered_dict, strict=False)
def resize(t: torch.Tensor, size: int) -> torch.Tensor:
    """Shrink the last two axes of a 4-D tensor to ``size`` x ``size`` by block averaging.

    Each of the last two axes is split into ``size`` groups of
    ``extent // size`` consecutive elements, and every group is replaced by
    its mean. The reshape fails unless that split accounts for every element.

    Args:
        t: Tensor with exactly four dimensions.
        size: Target extent of each of the last two axes.

    Returns:
        Tensor whose first two dimensions match ``t`` and whose last two
        are both ``size``.
    """
    batch, channels, height, width = t.shape
    block_h = height // size
    block_w = width // size
    blocks = t.reshape(batch, channels, size, block_h, size, block_w)
    # Axes 3 and 5 index the position inside each block; average them away.
    return blocks.mean(dim=(3, 5))
def make_image(tensor):
    """Convert a 4-D tensor with values in [-1, 1] to a uint8 NumPy array.

    Values are clamped to [-1, 1], mapped linearly onto [0, 255], cast to
    ``torch.uint8`` (the cast drops the fractional part), and dimension 1 is
    moved to the end before copying to the CPU.

    The input tensor is left unmodified.

    Args:
        tensor: 4-D tensor; presumably laid out as
            (batch, channels, height, width) -- confirm against callers.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with the input's dimensions in the
        order (0, 2, 3, 1).
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with `tensor`, so an
        # in-place clamp_ here would overwrite the caller's data.
        .clamp(min=-1, max=1)
        .add(1)
        .div(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to('cpu')
        .numpy()
    )