Intro

These are my efforts to train a real-world usable Cascaded Gaze image denoising network. denoise_util.py includes all definitions required to use Cascaded Gaze networks with PyTorch.

You can find inference code for some of these models on my github:

Models

small (cg_denoise_jpg+webp_artifacts_small.safetensors)

  • currently the best model here, whilst smaller than v1.
  • extensively trained on 256 * 256 BGR patches for jpg & webp compression artefact removal only.
  • robust: handles artefacts that were upscaled or downscaled after compression, by a factor in the range 40% - 130%, using bilinear, bicubic or Lanczos resampling.

Loading small

from denoise_util import CascadedGaze
from safetensors.torch import load_file
from safetensors import safe_open
import json

device = "cuda"
fn = "cg_denoise_jpg+webp_artifacts_small.safetensors"

with safe_open(fn, framework="pt") as f:
    metadata = f.metadata()

if not metadata or 'config' not in metadata:
    raise ValueError("No configuration found in model metadata")

config = json.loads(metadata['config'])
#note: you might also want to look at metadata['colorspace'], which has the value 'bgr' here.

model = CascadedGaze(**config)

state_dict = load_file(fn)
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)
model.eval()
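
The colorspace note in the loading code matters at inference time: an image loaded with PIL is RGB, whereas this model expects BGR. A minimal sketch of using the metadata value to reorder channels follows. The helper name to_model_colorspace is hypothetical (it is not part of denoise_util.py), and it assumes the input is an RGB batch shaped [B, 3, H, W].

```python
import torch

def to_model_colorspace(batch: torch.Tensor, colorspace: str) -> torch.Tensor:
    """Reorder the channel axis of an RGB [B, 3, H, W] batch for the model.

    `colorspace` is assumed to be the string stored in metadata['colorspace'].
    Hypothetical helper, not part of denoise_util.py.
    """
    if colorspace.lower() == "bgr":
        return batch.flip(1)  # reverse the channel axis: RGB <-> BGR
    return batch  # already in the order the model expects

rgb = torch.rand(1, 3, 256, 256)       # stand-in for a real RGB batch
bgr = to_model_colorspace(rgb, "bgr")  # what a BGR model would be fed
# The same flip converts a BGR model output back to RGB before saving.
restored = to_model_colorspace(bgr, "bgr")
```

With the loading code above, the second argument would be metadata['colorspace'] rather than a literal string.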

sidd (cg_sidd.safetensors)

The official CascadedGaze model trained on the SIDD benchmark. I have ported the weights and added metadata so that it can be loaded as easily as my small model.

However, my view is that the SIDD dataset is poor, and as a result this model is not useful for any task.

v1

  • an early experiment, not recommended
  • ~132M params, trained on 256 * 256 RGB patches for intermediate jpg & webp compression artefact removal. It was trained on about 700k samples (photographs only) in bf16 precision. Also capable of removing ISO-like noise and gaussian noise.
  • I recommend inputting tensors of shape [B,3,256,256], with float values scaled to 0 - 1.
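
As an illustration of the recommended input layout, here is a rough sketch of converting a uint8 H x W x 3 array (as produced by cv2 or by np.asarray on a PIL image) into a float [1, 3, H, W] tensor in the 0 - 1 range, and back again. The function names image_to_batch and batch_to_image are hypothetical helpers, not part of denoise_util.py. Channel order is left untouched, so the array must already be in the order the chosen model expects.

```python
import numpy as np
import torch

def image_to_batch(image_u8: np.ndarray) -> torch.Tensor:
    """uint8 H x W x 3 array -> float32 tensor [1, 3, H, W] scaled to 0 - 1."""
    chw = np.ascontiguousarray(image_u8.transpose(2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(chw).float().div(255.0).unsqueeze(0)

def batch_to_image(batch: torch.Tensor) -> np.ndarray:
    """float tensor [1, 3, H, W] in 0 - 1 -> uint8 H x W x 3 array."""
    chw = batch.squeeze(0).clamp(0, 1).mul(255.0).round().to(torch.uint8)
    return chw.permute(1, 2, 0).contiguous().numpy()  # CHW -> HWC

# Stand-in for a real 256 * 256 patch.
patch = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
batch = image_to_batch(patch)
roundtrip = batch_to_image(batch)
```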

no metadata included, load like this:

from denoise_util import CascadedGaze
from safetensors.torch import load_file

device = "cuda"

img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3,3,2,2]

model = CascadedGaze(img_channel=img_channel,width=width, middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks, dec_blk_nums=dec_blks,GCE_CONVS_nums=GCE_CONVS_nums)

state_dict = load_file("small.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)
model.eval()

jpg+webp denoising mini

sample: 4x zoom demonstration of this model applied to an image of a toy car found online.

  • an early experiment, not recommended
  • only ~18M parameters, trained on 256 * 256 BGR patches for jpg & webp compression artefact removal only. PSNR loss was used.
  • can handle artefacts that have been up or down scaled.
#Loading as above but with some settings changed:
enc_blks = [2, 2, 3]
middle_blk_num = 6
dec_blks = [2, 2, 2]
GCE_CONVS_nums = [3,3,2]

Usage

  • Using https://github.com/ProGamerGov/blended-tiling to handle converting images of arbitrary sizes into 256*256 tiles then back again.
  • You'll need to make amendments to prevent the batches from being too large for your device.
  • presumes the model was already loaded with code above.
  • loading/saving images omitted, but you could use PIL or cv2, etc. Note the BGR vs RGB channel order of the models.
import torch
from PIL import Image
import torchvision
from blended_tiling import TilingModule

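#load an image however you want, e.g. with PIL ("input.png" is a placeholder path)
pil_image = Image.open("input.png").convert("RGB")

tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1], # you can configure this to taste
    base_size=pil_image.size[::-1], #PIL .size is (W, H) but the tensor is (H, W); with cv2 use .shape[:2]
)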

tensor = torchvision.transforms.functional.to_tensor(pil_image) #also compatible with cv2
tensor = torch.unsqueeze(tensor,0)
tiles = tiling_module.split_into_tiles(tensor)
tiles = tiles.to(device)
with torch.no_grad():
    result = model(tiles).cpu() #you'll likely want to handle re-batching of tiles to fit vram
result = tiling_module.rebuild_with_masks(result).squeeze().clamp(0, 1)

#save an image however you want
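
One possible way to keep batches within device memory, as mentioned in the notes above, is to feed the tiles through the model in fixed-size chunks and concatenate the outputs on the CPU. This is only a sketch: run_in_chunks and chunk_size are hypothetical names, and torch.nn.Identity stands in for the loaded model so the snippet runs on its own.

```python
import torch

def run_in_chunks(model, tiles, device="cpu", chunk_size=8):
    """Run `model` over `tiles` ([N, C, H, W]) at most `chunk_size` tiles at a time."""
    outputs = []
    with torch.no_grad():
        for start in range(0, tiles.shape[0], chunk_size):
            chunk = tiles[start:start + chunk_size].to(device)
            outputs.append(model(chunk).cpu())  # move each result off the device
    return torch.cat(outputs, dim=0)

# Stand-ins for the real model and for the output of split_into_tiles.
dummy_model = torch.nn.Identity()
dummy_tiles = torch.rand(10, 3, 256, 256)
demo_result = run_in_chunks(dummy_model, dummy_tiles, chunk_size=4)
```

In the usage code above, the returned tensor would take the place of model(tiles).cpu() before it is passed to rebuild_with_masks. A chunk_size that suits one GPU may not suit another, so treat the default as a starting point.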