| """ | |
| AnyText: Multilingual Visual Text Generation And Editing | |
| Paper: https://arxiv.org/abs/2311.03054 | |
| Code: https://github.com/tyxsspa/AnyText | |
| Copyright (c) Alibaba, Inc. and its affiliates. | |
| """ | |
import os
from pathlib import Path
from iopaint.model.utils import set_seed
from safetensors.torch import load_file

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import torch
import re
import numpy as np
import cv2
import einops
from PIL import ImageFont
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
from iopaint.model.anytext.utils import (
    check_channels,
    draw_glyph,
    draw_glyph2,
)

BBOX_MAX_NUM = 8
PLACE_HOLDER = "*"
max_chars = 20

ANYTEXT_CFG = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
)


def check_limits(tensor):
    float16_min = torch.finfo(torch.float16).min
    float16_max = torch.finfo(torch.float16).max
    # Check whether any value in the tensor is below float16's minimum
    # or above float16's maximum.
    is_below_min = (tensor < float16_min).any()
    is_above_max = (tensor > float16_max).any()
    return is_below_min or is_above_max


class AnyTextPipeline:
    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
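        """Build the AnyText model from the bundled ``anytext_sd15.yaml`` config,
        load weights from ``ckpt_path`` (checkpoint or safetensors), prepare the
        glyph font from ``font_path``, and create a DDIM sampler on ``device``."""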
        self.cfg_path = ANYTEXT_CFG
        self.font_path = font_path
        self.use_fp16 = use_fp16
        self.device = device

        self.font = ImageFont.truetype(font_path, size=60)
        self.model = create_model(
            self.cfg_path,
            device=self.device,
            use_fp16=self.use_fp16,
        )
        if self.use_fp16:
            self.model = self.model.half()

        if Path(ckpt_path).suffix == ".safetensors":
            state_dict = load_file(ckpt_path, device="cpu")
        else:
            state_dict = load_state_dict(ckpt_path, location="cpu")
        self.model.load_state_dict(state_dict, strict=False)
        self.model = self.model.eval().to(self.device)
        self.ddim_sampler = DDIMSampler(self.model, device=self.device)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        image: np.ndarray,
        masked_image: np.ndarray,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        height: int,
        width: int,
        seed: int,
        sort_priority: str = "y",
        callback=None,
    ):
| """ | |
| Args: | |
| prompt: | |
| negative_prompt: | |
| image: | |
| masked_image: | |
| num_inference_steps: | |
| strength: | |
| guidance_scale: | |
| height: | |
| width: | |
| seed: | |
| sort_priority: x: left-right, y: top-down | |
| Returns: | |
| result: list of images in numpy.ndarray format | |
| rst_code: 0: normal -1: error 1:warning | |
| rst_info: string of error or warning | |
| """ | |
        set_seed(seed)
        str_warning = ""

        mode = "text-editing"
        revise_pos = False
        img_count = 1
        ddim_steps = num_inference_steps
        w = width
        h = height
        cfg_scale = guidance_scale
        eta = 0.0

        prompt, texts = self.modify_prompt(prompt)
        if prompt is None and texts is None:
            return (
                None,
                -1,
                "You have input a Chinese prompt but the translator is not loaded!",
            )
        n_lines = len(texts)
        if mode in ["text-generation", "gen"]:
            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
        elif mode in ["text-editing", "edit"]:
            if masked_image is None or image is None:
                return (
                    None,
                    -1,
                    "Reference image and position image are needed for text editing!",
                )
            if isinstance(image, str):
                img_path = image
                image = cv2.imread(img_path)
                assert image is not None, f"Can't read ori_image from {img_path}!"
                image = image[..., ::-1]
            elif isinstance(image, torch.Tensor):
                image = image.cpu().numpy()
            else:
                assert isinstance(
                    image, np.ndarray
                ), f"Unknown format of ori_image: {type(image)}"
            edit_image = image.clip(1, 255)  # for mask reason
            edit_image = check_channels(edit_image)
            # edit_image = resize_image(
            #     edit_image, max_length=768
            # )  # make w, h multiples of 64, resize if w or h > max_length
            h, w = edit_image.shape[:2]  # change h, w by input ref_img
        # preprocess pos_imgs (if numpy, make sure it's white pos in black bg)
        if masked_image is None:
            pos_imgs = np.zeros((h, w, 1))
        elif isinstance(masked_image, str):
            mask_path = masked_image
            masked_image = cv2.imread(mask_path)
            assert masked_image is not None, f"Can't read draw_pos image from {mask_path}!"
            masked_image = masked_image[..., ::-1]
            pos_imgs = 255 - masked_image
        elif isinstance(masked_image, torch.Tensor):
            pos_imgs = masked_image.cpu().numpy()
        else:
            assert isinstance(
                masked_image, np.ndarray
            ), f"Unknown format of draw_pos: {type(masked_image)}"
            pos_imgs = 255 - masked_image
        pos_imgs = pos_imgs[..., 0:1]
        pos_imgs = cv2.convertScaleAbs(pos_imgs)
        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
        # separate pos_imgs
        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
        if len(pos_imgs) == 0:
            pos_imgs = [np.zeros((h, w, 1))]
        if len(pos_imgs) < n_lines:
            if n_lines == 1 and texts[0] == " ":
                pass  # text-to-image without text
            else:
                raise RuntimeError(
                    f"{n_lines} text lines to draw from the prompt, but only {len(pos_imgs)} mask area(s) on the image"
                )
        elif len(pos_imgs) > n_lines:
            str_warning = f"Warning: found {len(pos_imgs)} positions, more than the {n_lines} needed by the prompt."
        # get pre_pos, poly_list and the hint needed by AnyText
        pre_pos = []
        poly_list = []
        for input_pos in pos_imgs:
            if input_pos.mean() != 0:
                input_pos = (
                    input_pos[..., np.newaxis]
                    if len(input_pos.shape) == 2
                    else input_pos
                )
                poly, pos_img = self.find_polygon(input_pos)
                pre_pos += [pos_img / 255.0]
                poly_list += [poly]
            else:
                pre_pos += [np.zeros((h, w, 1))]
                poly_list += [None]
        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
        # prepare info dict
        info = {}
        info["glyphs"] = []
        info["gly_line"] = []
        info["positions"] = []
        info["n_lines"] = [len(texts)] * img_count
        gly_pos_imgs = []
        for i in range(len(texts)):
            text = texts[i]
            if len(text) > max_chars:
                str_warning = (
                    f'"{text}" length > max_chars: {max_chars}, will be cut off...'
                )
                text = text[:max_chars]
            gly_scale = 2
            if pre_pos[i].mean() != 0:
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(
                    self.font,
                    text,
                    poly_list[i],
                    scale=gly_scale,
                    width=w,
                    height=h,
                    add_space=False,
                )
                gly_pos_img = cv2.drawContours(
                    glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
                )
                if revise_pos:
                    resize_gly = cv2.resize(
                        glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
                    )
                    new_pos = cv2.morphologyEx(
                        (resize_gly * 255).astype(np.uint8),
                        cv2.MORPH_CLOSE,
                        kernel=np.ones(
                            (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
                            dtype=np.uint8,
                        ),
                        iterations=1,
                    )
                    new_pos = (
                        new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
                    )
                    contours, _ = cv2.findContours(
                        new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
                    )
                    if len(contours) != 1:
                        str_warning = f"Failed to revise position {i} to a bounding rect, position left unchanged..."
                    else:
                        rect = cv2.minAreaRect(contours[0])
                        poly = np.int0(cv2.boxPoints(rect))
                        pre_pos[i] = (
                            cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
                        )
                        gly_pos_img = cv2.drawContours(
                            glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
                        )
                gly_pos_imgs += [gly_pos_img]  # for show
            else:
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                gly_line = np.zeros((80, 512, 1))
                gly_pos_imgs += [
                    np.zeros((h * gly_scale, w * gly_scale, 1))
                ]  # for show
            pos = pre_pos[i]
            info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
            info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
            info["positions"] += [self.arr2tensor(pos, img_count)]
        # get masked_x
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
        masked_img = np.transpose(masked_img, (2, 0, 1))
        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
        if self.use_fp16:
            masked_img = masked_img.half()
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
        if self.use_fp16:
            masked_x = masked_x.half()
        info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)

        hint = self.arr2tensor(np_hint, img_count)
        cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[prompt] * img_count],
                text_info=info,
            )
        )
        un_cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[negative_prompt] * img_count],
                text_info=info,
            )
        )
        shape = (4, h // 8, w // 8)
        self.model.control_scales = [strength] * 13
        samples, intermediates = self.ddim_sampler.sample(
            ddim_steps,
            img_count,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            unconditional_conditioning=un_cond,
            callback=callback,
        )
        if self.use_fp16:
            samples = samples.half()
        x_samples = self.model.decode_first_stage(samples)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )
        results = [x_samples[i] for i in range(img_count)]
        # if (
        #     mode == "edit" and False
        # ):  # replace background in text editing, but not ideal yet
        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
        # if len(gly_pos_imgs) > 0 and show_debug:
        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
        #     results += [np.repeat(glyph_img, 3, axis=2)]
        rst_code = 1 if str_warning else 0
        return results, rst_code, str_warning

    def modify_prompt(self, prompt):
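        """Replace each double-quoted text span in the prompt with the placeholder
        token and return the modified prompt together with the extracted spans."""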
        prompt = prompt.replace("“", '"')
        prompt = prompt.replace("”", '"')
        p = '"(.*?)"'
        strs = re.findall(p, prompt)
        if len(strs) == 0:
            strs = [" "]
        else:
            for s in strs:
                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
        # if self.is_chinese(prompt):
        #     if self.trans_pipe is None:
        #         return None, None
        #     old_prompt = prompt
        #     prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
        #     print(f"Translate: {old_prompt} --> {prompt}")
        return prompt, strs

    # def is_chinese(self, text):
    #     text = checker._clean_text(text)
    #     for char in text:
    #         cp = ord(char)
    #         if checker._is_chinese_char(cp):
    #             return True
    #     return False

    def separate_pos_imgs(self, img, sort_priority, gap=102):
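        """Split a binary position mask into one image per connected component,
        sorted top-down ("y") or left-right ("x") on a coarse grid of size ``gap``."""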
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
        components = []
        for label in range(1, num_labels):
            component = np.zeros_like(img)
            component[labels == label] = 255
            components.append((component, centroids[label]))
        if sort_priority == "y":
            fir, sec = 1, 0  # top-down first
        elif sort_priority == "x":
            fir, sec = 0, 1  # left-right first
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
        sorted_components = [c[0] for c in components]
        return sorted_components

    def find_polygon(self, image, min_rect=False):
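        """Take the largest contour of a binary mask, approximate it as a polygon
        (or as a minimum-area rectangle if ``min_rect``), fill it back into the
        mask, and return both the polygon and the filled mask."""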
        contours, hierarchy = cv2.findContours(
            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area
        if min_rect:
            # get minimum enclosing rectangle
            rect = cv2.minAreaRect(max_contour)
            poly = np.int0(cv2.boxPoints(rect))
        else:
            # get approximate polygon
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
            n, _, xy = poly.shape
            poly = poly.reshape(n, xy)
        cv2.drawContours(image, [poly], -1, 255, -1)
        return poly, image

    def arr2tensor(self, arr, bs):
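        """Convert an HWC numpy array to a CHW float tensor on the target device
        (half precision if enabled) and stack it ``bs`` times along the batch axis."""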
        arr = np.transpose(arr, (2, 0, 1))
        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
        if self.use_fp16:
            _arr = _arr.half()
        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
        return _arr
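

# Minimal usage sketch (illustrative only, not part of the pipeline). The
# checkpoint and font paths below are placeholders that must point at real
# AnyText weights and a TTF/OTF font; the remaining arguments simply mirror
# the __call__ signature above, with the position image drawn black where the
# quoted text should appear.
if __name__ == "__main__":
    ref_img = np.full((512, 512, 3), 127, dtype=np.uint8)  # plain gray reference image
    pos_img = np.full((512, 512, 3), 255, dtype=np.uint8)  # white background
    pos_img[200:312, 100:412] = 0  # black box marks the text position

    pipe = AnyTextPipeline(
        ckpt_path="path/to/anytext_checkpoint.ckpt",  # placeholder path
        font_path="path/to/font.ttf",  # placeholder path
        device="cuda",
        use_fp16=True,
    )
    results, rst_code, rst_info = pipe(
        prompt='photo of a street sign that says "Hello"',
        negative_prompt="low-res, bad glyph, watermark",
        image=ref_img,
        masked_image=pos_img,
        num_inference_steps=20,
        strength=1.0,
        guidance_scale=9.0,
        height=512,
        width=512,
        seed=42,
    )
    print(rst_code, rst_info, len(results))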