File size: 1,817 Bytes
6ab9bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce4e7b6
6ab9bc6
 
 
 
 
 
402507b
6ab9bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c992bc
6ab9bc6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_

import torchvision.transforms as transforms
from torchvision import models

import gradio as gr
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from model import dehazeformer_t

# Pick the device used for inference: CUDA first, then Apple MPS, else CPU.
# `torch.backends.mps` is missing on older PyTorch builds, so probe for the
# attribute instead of assuming it exists (avoids an AttributeError at import).
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Instantiate dehazeformer_t and move it to the selected device.
t_model_load = dehazeformer_t().to(device)

# The checkpoint is deserialized onto the CPU; load_state_dict then copies the
# values into the model's existing parameters on `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
best_model_weights = torch.load('best_t_model_weights.pth', map_location=torch.device('cpu'))
t_model_load.load_state_dict(best_model_weights)

def pred_one_image(inp):
  """Run the loaded model on one image and return the rendered result.

  Args:
    inp: A PIL image. It is resized to 256x256 and converted to RGB before
      being fed to the model.

  Returns:
    A fully loaded PIL image of the model output, drawn with matplotlib on a
    10x10 inch figure at 300 dpi with the axes hidden.
  """
  # Local import keeps this function self-contained; `io` is stdlib.
  import io

  # Scale 8-bit pixel values to [0, 1].
  pixels = np.array(inp.resize((256, 256)).convert("RGB")) / 255
  # HWC -> CHW, then add a leading batch axis and move to the model's device.
  chw = np.moveaxis(pixels, -1, 0)
  batch = torch.tensor(chw).float().unsqueeze(0).to(device)

  with torch.no_grad():
    t_model_load.eval()
    output = t_model_load(batch)
  print(output.shape)

  # First batch item, CHW -> HWC, on the CPU. imshow clips float RGB data to
  # [0, 1] anyway; clamping explicitly gives the same picture without the
  # out-of-range warning.
  rendered = output[0].cpu().permute((1, 2, 0)).clamp(0, 1)

  # Render into an in-memory buffer rather than a fixed file on disk, so that
  # overlapping requests cannot overwrite each other's result.
  png_buffer = io.BytesIO()
  fig = plt.figure(figsize=(10, 10))
  try:
    ax = fig.add_subplot(111)
    ax.imshow(rendered.numpy())
    ax.axis("off")
    fig.savefig(png_buffer, format='png', dpi=300)
  finally:
    # Figures created through pyplot stay registered until closed; without
    # this, every call would leak one figure.
    plt.close(fig)

  png_buffer.seek(0)
  out_img = Image.open(png_buffer)
  # Image.open is lazy; force decoding now so the result is self-contained.
  out_img.load()
  return out_img

# Gradio front end: one PIL image in, one PIL image out, handled by
# pred_one_image, with a single example file offered in the UI.
_image_in = gr.Image(type="pil")
_image_out = gr.Image(type="pil")
demo = gr.Interface(
    fn=pred_one_image,
    inputs=_image_in,
    outputs=_image_out,
    examples=['noisy_10961455225_0786d3edd2_c.jpg'],
)

# Start the app with Gradio's debug flag enabled.
demo.launch(debug=True)