Spaces:
Runtime error
Runtime error
| import os | |
| from oss_utils import ossService | |
| import requests | |
| import random | |
| import json | |
| import time | |
| from diffusers.utils import load_image | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
# OSS config: every value is read from the environment and falls back to an
# empty string when the variable is not set.
BUCKET = os.getenv('BUCKET', '')
ENDPOINT = os.getenv('ENDPOINT', '')
PREFIX = os.getenv('OSS_DIR_PREFIX', '')
AK = os.getenv('AK', '')
SK = os.getenv('SK', '')

# Remote service: id sent in request headers, submit endpoint, query endpoint.
DASHONE_SERVICE_ID = os.getenv('SERVICE_ID', '')
URL = os.getenv('URL', '')
GET_URL = os.getenv('GET_URL', '')

# Local scratch directory for temporary input/output images; created at import.
OUTPUT_PATH = './output'
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Shared OSS client used for uploads and downloads.
oss_service = ossService(AK, SK, ENDPOINT, BUCKET, PREFIX)
def async_request_and_query(data, request_id, request_timeout=30,
                            max_polls=120, poll_interval=1):
    """Submit an asynchronous task and poll until it reaches a final state.

    ``data`` is POSTed as JSON to ``URL``. The ``task_id`` returned in the
    response header is then sent to ``GET_URL`` repeatedly until the service
    reports ``SUCCESS``, ``FAILED`` or ``ERROR``, or the poll limit is hit.

    Args:
        data: JSON-serialisable request body.
        request_id: Identifier placed in the header of every polling request.
        request_timeout: Seconds to wait on each HTTP call. Without a timeout
            ``requests`` blocks indefinitely on a stalled connection.
        max_polls: Number of wait-and-retry cycles allowed before giving up.
        poll_interval: Seconds to sleep between two polls.

    Returns:
        The value of ``payload.output.res`` from the successful poll response
        (presumably a signed OSS URL of the result -- confirm with the
        service contract).

    Raises:
        requests.HTTPError: If either endpoint answers with an error status.
        requests.Timeout: If an HTTP call exceeds ``request_timeout``.
        ValueError: If the task ends in ``FAILED``/``ERROR`` or is still not
            finished once ``max_polls`` is exhausted.
    """
    headers = {'Content-Type': 'application/json'}

    # 1. Send the asynchronous request.
    print('Start sending request')
    response = requests.post(URL, headers=headers, data=json.dumps(data),
                             timeout=request_timeout)
    # raise_for_status() is a no-op for successful responses.
    response.raise_for_status()
    response_json = json.loads(response.content.decode("utf-8"))
    print('Finish sending request')

    # 2. Poll for the result.
    task_id = response_json['header']['task_id']
    get_data = {
        "header": {
            "request_id": request_id,
            "service_id": DASHONE_SERVICE_ID,
            "task_id": f"{task_id}",
        }
    }
    # The polling body never changes, so serialise it once.
    get_body = json.dumps(get_data)

    print('Start querying results')
    polls_done = 0
    while True:
        poll_response = requests.post(GET_URL, headers=headers, data=get_body,
                                      timeout=request_timeout)
        poll_response.raise_for_status()
        poll_json = json.loads(poll_response.content.decode("utf-8"))
        task_status = poll_json['header']['task_status']

        if task_status == 'SUCCESS':
            sign_oss_path = poll_json['payload']['output']['res']
            print('Task succeeded!')
            return sign_oss_path
        if task_status in ('FAILED', 'ERROR'):
            raise ValueError(f'Task Failed (status: {task_status})')
        if polls_done >= max_polls:
            raise ValueError(
                f'Task Failed (still {task_status} after {polls_done} polls)')

        time.sleep(poll_interval)
        polls_done += 1
# def add_transparent_watermark(pil_image, watermark_text, position, opacity, font_path, font_size):
#     # Load the font
#     font = ImageFont.truetype(font_path, font_size)
#     # Create a semi-transparent watermark layer
#     watermark_layer = Image.new("RGBA", pil_image.size)
#     draw = ImageDraw.Draw(watermark_layer)
#     # Text colour and opacity
#     text_color = (255, 255, 255, opacity)  # white text
#     outline_color = (0, 0, 0, opacity)  # black outline
#     # Measure the text
#     text_width = draw.textlength(watermark_text, font=font)
#     text_height = text_width // 5
#     # Compute the watermark position
#     img_width, img_height = pil_image.size
#     x = img_width - text_width - position[0]
#     y = img_height - text_height - position[1]
#     outline_range = 1  # outline thickness
#     for adj in range(-outline_range, outline_range+1):
#         for ord in range(-outline_range, outline_range+1):
#             if adj != 0 or ord != 0:  # skip the centre position, which is the actual text
#                 draw.text((x+adj, y+ord), watermark_text, font=font, fill=outline_color)
#     # Draw the text onto the watermark layer
#     draw.text((x, y), watermark_text, font=font, fill=text_color)
#     # Composite the watermark layer onto the original image
#     pil_image_with_watermark = Image.alpha_composite(pil_image.convert("RGBA"), watermark_layer)
#     # Return the watermarked image
#     return pil_image_with_watermark
def _extract_alpha(image):
    """Return the alpha band of ``image``, or a fully opaque band if it has none."""
    # Taking ``split()[-1]`` blindly would return the blue (or only) band for
    # images without alpha, so check for real transparency first.
    if 'A' in image.getbands() or 'transparency' in image.info:
        return image.convert('RGBA').split()[-1]
    return Image.new('L', image.size, 255)


def _remove_if_exists(path):
    """Delete ``path`` when it is an existing regular file."""
    if os.path.isfile(path):
        os.remove(path)


def inference(input_image, upscale):
    """Upscale an image through the remote service and restore its alpha.

    The image is flattened to RGB, uploaded to OSS, processed by the service
    via ``async_request_and_query``, downloaded again, and finally merged
    with the (resized) alpha channel of the original input.

    Args:
        input_image: PIL image to process.
        upscale: Value forwarded to the service as ``parameters.upscale``.

    Returns:
        A ``(images, message)`` tuple. ``images`` is a one-element list
        holding the RGBA result, or the image loaded from ``./error.png``
        when the request, download or decoding fails. ``message`` is a
        user-facing status string.
    """
    # process alpha channel
    alpha_channel = _extract_alpha(input_image)
    input_image = input_image.convert('RGB')

    # NOTE(review): the file names and the OSS key are fixed, so concurrent
    # calls can overwrite each other's files -- confirm the app is served
    # with a single worker or derive the names from the request id.
    local_save_path = os.path.join(OUTPUT_PATH, 'tmp.png')
    local_output_save_path = os.path.join(OUTPUT_PATH, 'out.png')

    # generate image url
    oss_key = os.path.join(PREFIX, 'tmp.png')
    input_image.save(local_save_path)
    try:
        _, image_url = oss_service.uploadOssFile(oss_key, local_save_path)
    finally:
        # rm local file, even when the upload raised
        _remove_if_exists(local_save_path)

    request_id = "".join(random.sample("0123456789abcdefghijklmnopqrstuvwxyz", 10))
    data = {
        'header': {
            'request_id': request_id,
            'service_id': DASHONE_SERVICE_ID,
        },
        'payload': {
            'input': {'image_url': image_url},
            'parameters': {'upscale': upscale, 'platform': 'modelscope'},
        },
    }

    try:
        output_url = async_request_and_query(data=data, request_id=request_id)
        download_status = oss_service.downloadFile(output_url, local_output_save_path)
        if not download_status:
            raise ValueError('Download output image failed')
        output_image = load_image(local_output_save_path)
    except Exception as exc:
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate; log the cause so failures can be diagnosed.
        print(f'Inference failed: {exc!r}')
        output_image = load_image('./error.png')
        return [output_image], 'The input image format or resolution does not meet the requirements. Please change the image or resize it and try again.'
    finally:
        _remove_if_exists(local_output_save_path)

    # add alpha channel
    output_alpha_channel = alpha_channel.resize(output_image.size, resample=Image.LANCZOS)
    # merge
    output_image_alpha = Image.merge("RGBA", (*output_image.split(), output_alpha_channel))

    # Disabled: thumbnail overlay of the input plus a text watermark.
    # out_width, out_height = output_image.size
    # new_width = out_width // 4
    # new_height = out_height // 4
    # resized_image = np.array(input_image.resize((new_width, new_height)))
    # np_output_image = np.array(output_image)
    # np_output_image[0:new_height, 0:new_width] = resized_image.copy()
    # new_output_image = Image.fromarray(np_output_image)
    # new_output_image = add_transparent_watermark(
    #     pil_image=new_output_image,
    #     watermark_text="追影-放大镜",
    #     position=(5, 5),
    #     opacity=200,
    #     font_path="AlibabaPuHuiTi-3-45-Light.ttf",  # e.g. "Arial", "Helvetica", "Times New Roman"
    #     font_size= out_width // 30
    # )

    org_width, org_height = input_image.size
    if max(org_width, org_height) > 1920 or min(org_width, org_height) > 1080:
        msg = 'The input image size has exceeded the optimal range. You can consider scaling the input image for better generation effect'
    else:
        msg = 'Task succeeded'
    return [output_image_alpha], msg