---
license: apache-2.0
datasets:
- geekyrakshit/LoL-Dataset
pipeline_tag: image-to-image
tags:
- image-enhancement
- computer-vision
- image-to-image
---

# MIRNet low-light image enhancement

[Demo on Hugging Face Spaces](https://huggingface.co/spaces/dblasko/mirnet-low-light-img-enhancement)

A MIRNet-based low-light image enhancer specialized in restoring dark images taken at events (concerts, parties, clubs...).

## Project source code and further documentation

Documentation on pre-training, fine-tuning, the model architecture, and usage, along with all source code used for building the model and running inference, is available in the [GitHub repository of the project](https://github.com/dblasko/low-light-event-img-enhancer/).

This page currently hosts the PyTorch model weights and the model definition. A Hugging Face pipeline will be implemented in the future.

## Using the model

To use the model, your project folder must contain the `model` folder, which you can download from this repository or from [GitHub](https://github.com/dblasko/low-light-event-img-enhancer/).

The following code then downloads the model weights from the Hugging Face Hub and loads them in PyTorch for downstream use of the model:

```python
import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
from model.MIRNet.model import MIRNet

device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="dblasko/mirnet-low-light-img-enhancement", filename="mirnet_finetuned.pth"
)

# Load the model
model = MIRNet().to(device)
model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])

# Use the model, for example for inference on an image
model.eval()
with torch.no_grad():
    img = Image.open("image_path.png").convert("RGB")
    img_tensor = T.Compose(
        [
            T.Resize(400),  # Adjust image resizing depending on hardware
            T.ToTensor(),
            T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ]
    )(img).unsqueeze(0)
    img_tensor = img_tensor.to(device)

    # Crop height and width down to the nearest multiple of 8
    if img_tensor.shape[2] % 8 != 0:
        img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
    if img_tensor.shape[3] % 8 != 0:
        img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

    output = model(img_tensor)
```
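
The cropping step at the end of the snippet trims the height and width so that both are divisible by 8. A minimal sketch of that arithmetic as a standalone helper follows. `cropped_size` is a hypothetical name that is not part of the repository, and the commented usage assumes a tensor laid out as (batch, channels, height, width), as `img_tensor` is above.

```python
from typing import Tuple


def cropped_size(height: int, width: int, multiple: int = 8) -> Tuple[int, int]:
    """Return the largest (height, width) not exceeding the input
    in which both values are divisible by `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Subtracting the remainder rounds each dimension down to a multiple.
    return height - height % multiple, width - width % multiple


# Hypothetical usage with the tensor from the snippet above:
#   h, w = cropped_size(img_tensor.shape[2], img_tensor.shape[3])
#   img_tensor = img_tensor[:, :, :h, :w]

print(cropped_size(400, 601))  # (400, 600)
print(cropped_size(400, 600))  # (400, 600)
```

Slicing with an explicit end index avoids the separate `% 8 != 0` checks, because a dimension that is already a multiple of 8 is left unchanged. A dimension smaller than `multiple` rounds down to 0, so very small inputs should be resized before cropping.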