| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from PIL import Image | 
					
					
						
						| 
							 | 
						from io import BytesIO | 
					
					
						
						| 
							 | 
						from realesrgan import RealESRGANer | 
					
					
						
						| 
							 | 
						from typing import Dict, List, Any | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						from pathlib import Path | 
					
					
						
						| 
							 | 
						from basicsr.archs.rrdbnet_arch import RRDBNet | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						import cv2 | 
					
					
						
						| 
							 | 
						import PIL | 
					
					
						
						| 
							 | 
						import boto3 | 
					
					
						
						| 
							 | 
						import uuid, io | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import base64 | 
					
					
						
						| 
							 | 
						import requests | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class EndpointHandler: | 
					
					
						
						| 
							 | 
						    def __init__(self, path=""): | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.tiling_size = int(os.environ["TILING_SIZE"]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.model = RealESRGANer( | 
					
					
						
						| 
							 | 
						            scale=4,   | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth", | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            model= RRDBNet(num_in_ch=3, | 
					
					
						
						| 
							 | 
						                           num_out_ch=3, | 
					
					
						
						| 
							 | 
						                           num_feat=64, | 
					
					
						
						| 
							 | 
						                           num_block=23, | 
					
					
						
						| 
							 | 
						                           num_grow_ch=32, | 
					
					
						
						| 
							 | 
						                           scale=4 | 
					
					
						
						| 
							 | 
						                           ), | 
					
					
						
						| 
							 | 
						            tile=self.tiling_size, | 
					
					
						
						| 
							 | 
						            tile_pad=0, | 
					
					
						
						| 
							 | 
						            half=True, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.s3 = boto3.client('s3',  | 
					
					
						
						| 
							 | 
						                               aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'], | 
					
					
						
						| 
							 | 
						                               aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'], | 
					
					
						
						| 
							 | 
						                               ) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.bucket_name = os.environ["S3_BUCKET_NAME"] | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						    def __call__(self, data: Any) -> Dict[str, List[float]]: | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            inputs = data.pop("inputs", data) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            outscale = float(inputs.pop("outscale", 3)) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            image = self.download_image_url(inputs['image_url']) | 
					
					
						
						| 
							 | 
						            in_size, in_mode = image.size, image.mode | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}" | 
					
					
						
						| 
							 | 
						            if self.tiling_size == 0: | 
					
					
						
						| 
							 | 
						                assert in_size[0] * in_size[1] <  1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {self.tiling_size*self.tiling_size}" | 
					
					
						
						| 
							 | 
						            assert outscale > 1 and outscale <=10, f"Outscale must be between 1 and 10: {outscale}" | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            print(f"image.size: {in_size}, image.mode: {in_mode}, outscale: {outscale}") | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            opencv_image = np.array(image) | 
					
					
						
						| 
							 | 
						            if in_mode == "RGB": | 
					
					
						
						| 
							 | 
						                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR) | 
					
					
						
						| 
							 | 
						            elif in_mode == "RGBA": | 
					
					
						
						| 
							 | 
						                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2BGRA) | 
					
					
						
						| 
							 | 
						            elif in_mode == "L": | 
					
					
						
						| 
							 | 
						                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_GRAY2RGB) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                raise ValueError(f"Unsupported image mode: {in_mode}") | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            output, _ = self.model.enhance(opencv_image, outscale=outscale) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            print(f"output.shape: {output.shape}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            out_shape = output.shape | 
					
					
						
						| 
							 | 
						            if len(out_shape) == 3: | 
					
					
						
						| 
							 | 
						                if out_shape[2] == 3: | 
					
					
						
						| 
							 | 
						                    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) | 
					
					
						
						| 
							 | 
						                elif out_shape[2] == 4: | 
					
					
						
						| 
							 | 
						                    output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            img_byte_arr = BytesIO() | 
					
					
						
						| 
							 | 
						            output = Image.fromarray(output) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            image_url, key = self.upload_to_s3(output) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            return {"image_url": image_url, | 
					
					
						
						| 
							 | 
						                    "image_key": key, | 
					
					
						
						| 
							 | 
						                    "error": None | 
					
					
						
						| 
							 | 
						                    } | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        except AssertionError as e: | 
					
					
						
						| 
							 | 
						            print(f"AssertionError: {e}") | 
					
					
						
						| 
							 | 
						            return {"out_image": None, "error": str(e)} | 
					
					
						
						| 
							 | 
						        except KeyError as e: | 
					
					
						
						| 
							 | 
						            print(f"KeyError: {e}") | 
					
					
						
						| 
							 | 
						            return {"out_image": None, "error": f"Missing key: {e}"} | 
					
					
						
						| 
							 | 
						        except ValueError as e: | 
					
					
						
						| 
							 | 
						            print(f"ValueError: {e}") | 
					
					
						
						| 
							 | 
						            return {"out_image": None, "error": str(e)} | 
					
					
						
						| 
							 | 
						        except PIL.UnidentifiedImageError as e: | 
					
					
						
						| 
							 | 
						            print(f"PIL.UnidentifiedImageError: {e}") | 
					
					
						
						| 
							 | 
						            return {"out_image": None, "error": "Invalid image format"} | 
					
					
						
						| 
							 | 
						        except Exception as e: | 
					
					
						
						| 
							 | 
						            print(f"Exception: {e}") | 
					
					
						
						| 
							 | 
						            return {"out_image": None, "error": "An unexpected error occurred"} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def upload_to_s3(self, image): | 
					
					
						
						| 
							 | 
						        "Upload the image to s3 and return the url." | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        prefix = str(uuid.uuid4()) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        in_mem_file = io.BytesIO() | 
					
					
						
						| 
							 | 
						        image.save(in_mem_file, 'PNG') | 
					
					
						
						| 
							 | 
						        in_mem_file.seek(0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        key = f"{prefix}.png" | 
					
					
						
						| 
							 | 
						        self.s3.upload_fileobj(in_mem_file, Bucket=self.bucket_name, Key=key) | 
					
					
						
						| 
							 | 
						        image_url = "https://{0}.s3.amazonaws.com/{1}".format(self.bucket_name, key) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        return image_url, key | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    def download_image_url(self, image_url): | 
					
					
						
						| 
							 | 
						        "Download the image from the url and return the image." | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        response = requests.get(image_url) | 
					
					
						
						| 
							 | 
						        image = Image.open(BytesIO(response.content)) | 
					
					
						
						| 
							 | 
						        return image |