# real-esrgan/handler.py
import base64
import os
from io import BytesIO
from typing import Any, Dict

import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer

# Optional CUDA memory tuning, left disabled as in the original handler:
# torch.cuda.empty_cache()
# torch.cuda.set_per_process_memory_fraction(0.99)
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,garbage_collection_threshold:0.7"
class EndpointHandler:
    def __init__(self, path=""):
        # Real-ESRGAN x4plus upscaler: RRDBNet backbone with native 4x scale,
        # tiled inference (tile=1000, tile_pad=10) to bound GPU memory,
        # and fp16 inference (half=True).
        self.model = RealESRGANer(
            scale=4,
            model_path="/repository/weights/Real-ESRGAN-x4plus.pth",
            # dni_weight=dni_weight,
            model=RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
            tile=1000,
            tile_pad=10,
            # pre_pad=args.pre_pad,
            half=True,
            # gpu_id=args.gpu_id
        )
    def __call__(self, data: Any) -> Dict[str, str]:
        inputs = data.pop("inputs", data)
        outscale = 3

        # Decode the base64-encoded input image into a PIL image
        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))

        # Convert the PIL image (RGB) to a NumPy array, then to BGR,
        # since OpenCV and Real-ESRGAN expect BGR channel order
        opencv_image = np.array(image)
        opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)

        # Upscale; outscale controls the final resize factor of the output
        output, _ = self.model.enhance(opencv_image, outscale=outscale)

        # Convert back from BGR to RGB and re-encode the result as a base64 PNG
        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        output = Image.fromarray(output)
        img_byte_arr = BytesIO()
        output.save(img_byte_arr, format='PNG')
        img_str = base64.b64encode(img_byte_arr.getvalue())

        return {"out_image": img_str.decode()}