# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import os
import json
import glob
import random
import datetime

import cv2
import torch
import einops
import numpy as np
import torchvision
from PIL import Image

def min_resize(x, m):
    # Resize so that the shorter side of x becomes m, preserving aspect ratio.
    if x.shape[0] < x.shape[1]:
        s0 = m
        s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
        s1 = m
    new_max = max(s1, s0)
    raw_max = max(x.shape[0], x.shape[1])
    if new_max < raw_max:
        interpolation = cv2.INTER_AREA  # downscaling
    else:
        interpolation = cv2.INTER_LANCZOS4  # upscaling
    y = cv2.resize(x, (s1, s0), interpolation=interpolation)
    return y

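# Illustrative usage (not part of the original file; shape is assumed):
#   x = np.zeros((480, 640, 3), dtype=np.uint8)
#   y = min_resize(x, 512)  # shorter side -> 512, result shape (512, 682, 3)
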
def d_resize(x, y):
    # Resize x to the spatial dimensions of y.
    H, W, C = y.shape
    new_min = min(H, W)
    raw_min = min(x.shape[0], x.shape[1])
    if new_min < raw_min:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (W, H), interpolation=interpolation)
    return y

def resize_and_center_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)

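# Illustrative usage (assumed shapes; not from the original repo):
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   out = resize_and_center_crop(img, 512, 512)  # out.shape == (512, 512, 3)
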
def resize_and_center_crop_pytorch(image, target_width, target_height):
    B, C, H, W = image.shape
    if H == target_height and W == target_width:
        return image

    scale_factor = max(target_width / W, target_height / H)
    resized_width = int(round(W * scale_factor))
    resized_height = int(round(H * scale_factor))

    resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)

    top = (resized_height - target_height) // 2
    left = (resized_width - target_width) // 2
    cropped = resized[:, :, top:top + target_height, left:left + target_width]
    return cropped

def resize_without_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)

def just_crop(image, w, h):
    # Center-crop to the aspect ratio of (w, h) without resizing.
    if h == image.shape[0] and w == image.shape[1]:
        return image

    original_height, original_width = image.shape[:2]
    k = min(original_height / h, original_width / w)
    new_width = int(round(w * k))
    new_height = int(round(h * k))
    x_start = (original_width - new_width) // 2
    y_start = (original_height - new_height) // 2
    cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
    return cropped_image

def write_to_json(data, file_path):
    # Write to a temp file first, then os.replace() for an atomic update.
    temp_file_path = file_path + ".tmp"
    with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
        json.dump(data, temp_file, indent=4)
    os.replace(temp_file_path, file_path)
    return


def read_from_json(file_path):
    with open(file_path, 'rt', encoding='utf-8') as file:
        data = json.load(file)
    return data

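# Illustrative usage (hypothetical file name): because of the temp-file-plus-
# os.replace pattern above, a concurrent reader sees either the old or the new
# contents, never a partially written file.
#   write_to_json({'step': 100}, 'training_state.json')
#   state = read_from_json('training_state.json')
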
def get_active_parameters(m):
    return {k: v for k, v in m.named_parameters() if v.requires_grad}


def cast_training_params(m, dtype=torch.float32):
    result = {}
    for n, param in m.named_parameters():
        if param.requires_grad:
            param.data = param.to(dtype)
            result[n] = param
    return result

def separate_lora_AB(parameters, B_patterns=None):
    parameters_normal = {}
    parameters_B = {}

    if B_patterns is None:
        B_patterns = ['.lora_B.', '__zero__']

    for k, v in parameters.items():
        if any(B_pattern in k for B_pattern in B_patterns):
            parameters_B[k] = v
        else:
            parameters_normal[k] = v

    return parameters_normal, parameters_B

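# Illustrative usage (assumes a model with LoRA adapters attached): splitting
# out the lora_B parameters makes it easy to give them their own optimizer
# group. The learning rates below are hypothetical.
#   params = get_active_parameters(model)
#   normal, lora_B = separate_lora_AB(params)
#   optimizer = torch.optim.AdamW([
#       {'params': list(normal.values()), 'lr': 1e-4},
#       {'params': list(lora_B.values()), 'lr': 1e-5},
#   ])
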
def set_attr_recursive(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)
    return

def print_tensor_list_size(tensors):
    total_size = 0
    total_elements = 0

    if isinstance(tensors, dict):
        tensors = tensors.values()

    for tensor in tensors:
        total_size += tensor.nelement() * tensor.element_size()
        total_elements += tensor.nelement()

    total_size_MB = total_size / (1024 ** 2)
    total_elements_B = total_elements / 1e9

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return

def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    # Per-sample mixture: each batch element is taken from a with probability
    # probability_a, otherwise from b (zeros if b is None).
    batch_size = a.size(0)

    if b is None:
        b = torch.zeros_like(a)

    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a

    mask_a = mask_a.to(a.device)
    mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
    result = torch.where(mask_a, a, b)
    return result

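# Illustrative usage (assumed shapes): classifier-free-guidance-style condition
# dropout, replacing roughly 10% of the rows with zeros.
#   cond = torch.randn(8, 77, 768)  # hypothetical text embeddings
#   cond = batch_mixture(cond, probability_a=0.9)
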
def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module


def supress_lower_channels(m, k, alpha=0.01):
    # Scale the first k input channels of m's weight by alpha.
    data = m.weight.data.clone()

    assert int(data.shape[1]) >= k

    data[:, :k] = data[:, :k] * alpha
    m.weight.data = data.contiguous().clone()
    return m

def freeze_module(m):
    # Disable gradients and wrap forward in no_grad; the original forward is
    # kept on _forward_inside_frozen_module in case it is needed later.
    if not hasattr(m, '_forward_inside_frozen_module'):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    m.forward = torch.no_grad()(m.forward)
    return m

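# Illustrative usage (hypothetical module): a frozen text encoder no longer
# builds autograd graphs in its forward pass, which saves memory in training.
#   text_encoder = freeze_module(text_encoder)
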
def get_latest_safetensors(folder_path):
    safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))

    if not safetensors_files:
        raise ValueError('No file to resume!')

    latest_file = max(safetensors_files, key=os.path.getmtime)
    latest_file = os.path.abspath(os.path.realpath(latest_file))
    return latest_file

def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    tags = tags_str.split(', ')
    tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
    prompt = ', '.join(tags)
    return prompt

def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    # n values from a to b; gamma != 1.0 biases the spacing toward one end.
    numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    # Stratified sampling: one uniform draw from each of n equal sub-intervals.
    edges = np.linspace(0, 1, n + 1)
    points = np.random.uniform(edges[:-1], edges[1:])
    numbers = inclusive + (exclusive - inclusive) * points
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()

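# Illustrative usage: four timesteps, one from each quarter of [0, 1000).
#   ts = uniform_random_by_intervals(0, 1000, 4, round_to_int=True)
#   # e.g. [103, 371, 642, 889] -- actual values vary per call
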
def soft_append_bcthw(history, current, overlap=0):
    # Concatenate along the time axis; if overlap > 0, linearly cross-fade the
    # last `overlap` frames of history into the first `overlap` frames of current.
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
    output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)

    return output.to(history)

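# Illustrative usage (assumed shapes): stitching two 16-frame chunks with a
# 4-frame cross-fade yields 16 + 16 - 4 = 28 frames.
#   a = torch.randn(1, 3, 16, 64, 64)
#   b = torch.randn(1, 3, 16, 64, 64)
#   video = soft_append_bcthw(a, b, overlap=4)  # (1, 3, 28, 64, 64)
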
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    b, c, t, h, w = x.shape

    # Pick the largest grid width in [2, 6] that evenly divides the batch size.
    per_row = b
    for p in [6, 5, 4, 3, 2]:
        if b % p == 0:
            per_row = p
            break

    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5  # [-1, 1] -> [0, 255]
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return x

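# Illustrative usage (hypothetical tensor and path): the input is expected to
# be in [-1, 1], as produced elsewhere in this file.
#   video = torch.rand(1, 3, 24, 256, 256) * 2 - 1
#   save_bcthw_as_mp4(video, 'outputs/clip.mp4', fps=24, crf=16)
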
def save_bcthw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c h w -> c h (b w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename

def add_tensors_with_padding(tensor1, tensor2):
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # Match the inputs' dtype/device so this also works for non-CPU tensors.
    padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
    padded_tensor2 = torch.zeros(new_shape, dtype=tensor2.dtype, device=tensor2.device)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    result = padded_tensor1 + padded_tensor2
    return result

def print_free_mem():
    torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    free_mem_mb = free_mem / (1024 ** 2)
    total_mem_mb = total_mem / (1024 ** 2)
    print(f"Free memory: {free_mem_mb:.2f} MB")
    print(f"Total memory: {total_mem_mb:.2f} MB")
    return

def print_gpu_parameters(device, state_dict, log_count=1):
    summary = {"device": device, "keys_count": len(state_dict)}

    logged_params = {}
    for i, (key, tensor) in enumerate(state_dict.items()):
        if i >= log_count:
            break
        logged_params[key] = tensor.flatten()[:3].tolist()

    summary["params"] = logged_params
    print(str(summary))
    return

def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    words = text.split()
    if not words:  # covers '' and whitespace-only text
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line would fall outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)

def blue_mark(x):
    # Sharpen channel 2 of an image in [-1, 1] (unsharp mask) as a visual marker.
    x = x.copy()
    c = x[:, :, 2]
    b = cv2.blur(c, (9, 9))
    x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
    return x


def green_mark(x):
    # Set channels 0 and 2 to -1 (minimum), leaving only channel 1.
    x = x.copy()
    x[:, :, 2] = -1
    x[:, :, 0] = -1
    return x


def frame_mark(x):
    # Draw dark bars along the top/bottom and bright bars along the sides
    # of an image in [-1, 1].
    x = x.copy()
    x[:64] = -1
    x[-64:] = -1
    x[:, :8] = 1
    x[:, -8:] = 1
    return x

def pytorch2numpy(imgs):
    # (C, H, W) tensors in [-1, 1] -> list of (H, W, C) uint8 arrays in [0, 255].
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        results.append(y)
    return results


def numpy2pytorch(imgs):
    # List of (H, W, C) uint8 arrays -> (B, C, H, W) float tensor in [-1, 1].
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h

def duplicate_prefix_to_suffix(x, count, zero_out=False):
    if zero_out:
        return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
    else:
        return torch.cat([x, x[:count]], dim=0)


def weighted_mse(a, b, weight):
    return torch.mean(weight.float() * (a.float() - b.float()) ** 2)

def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    # Map x from [x_min, x_max] to [y_min, y_max], clamped at the endpoints;
    # sigma != 1.0 bends the curve toward one end.
    x = (x - x_min) / (x_max - x_min)
    x = max(0.0, min(x, 1.0))
    x = x ** sigma
    return y_min + x * (y_max - y_min)


def expand_to_dims(x, target_dims):
    # Append trailing singleton dims until x has target_dims dimensions.
    return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))

def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    if tensor is None:
        return None

    first_dim = tensor.shape[0]

    if first_dim == batch_size:
        return tensor

    if batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    repeat_times = batch_size // first_dim
    return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))

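# Illustrative usage (assumed shapes): broadcast a single conditioning tensor
# across a batch of four samples.
#   cond = torch.randn(1, 77, 768)
#   cond4 = repeat_to_batch_size(cond, 4)  # (4, 77, 768)
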
def dim5(x):
    return expand_to_dims(x, 5)


def dim4(x):
    return expand_to_dims(x, 4)


def dim3(x):
    return expand_to_dims(x, 3)

def crop_or_pad_yield_mask(x, length):
    # Zero-pad or crop x along dim 1 to exactly `length`; the returned mask is
    # True wherever the output holds real (non-padded) data.
    B, F, C = x.shape
    device = x.device
    dtype = x.dtype

    if F < length:
        y = torch.zeros((B, length, C), dtype=dtype, device=device)
        mask = torch.zeros((B, length), dtype=torch.bool, device=device)
        y[:, :F, :] = x
        mask[:, :F] = True
        return y, mask

    return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)

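# Illustrative usage (assumed shapes): pad 50 text tokens to a 77-token window.
#   tokens = torch.randn(2, 50, 4096)
#   padded, mask = crop_or_pad_yield_mask(tokens, 77)
#   # padded: (2, 77, 4096); mask: (2, 77), True on the first 50 positions
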
def extend_dim(x, dim, minimal_length, zero_pad=False):
    # Extend x along `dim` to at least `minimal_length`, padding with zeros or
    # by repeating the last slice.
    original_length = int(x.shape[dim])

    if original_length >= minimal_length:
        return x

    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = minimal_length - original_length
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
        last_element = x[idx]
        padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)

    return torch.cat([x, padding], dim=dim)

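# Illustrative usage (assumed shapes): hold the last latent frame to reach a
# minimum temporal length.
#   latents = torch.randn(1, 4, 9, 32, 32)
#   latents = extend_dim(latents, dim=2, minimal_length=16)  # (1, 4, 16, 32, 32)
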
def lazy_positional_encoding(t, repeats=None):
    if not isinstance(t, list):
        t = [t]

    from diffusers.models.embeddings import get_timestep_embedding

    te = torch.tensor(t)
    te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is None:
        return te

    te = te[:, None, :].expand(-1, repeats, -1)
    return te

def state_dict_offset_merge(A, B, C=None):
    # Returns A + B, or A + (B - C) when C is given (apply B's offset from C to A).
    result = {}
    keys = A.keys()

    for key in keys:
        A_value = A[key]
        B_value = B[key].to(A_value)

        if C is None:
            result[key] = A_value + B_value
        else:
            C_value = C[key].to(A_value)
            result[key] = A_value + B_value - C_value

    return result

def state_dict_weighted_merge(state_dicts, weights):
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)

    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized_weights = [w / total_weight for w in weights]

    keys = state_dicts[0].keys()
    result = {}

    for key in keys:
        result[key] = state_dicts[0][key] * normalized_weights[0]
        for i in range(1, len(state_dicts)):
            state_dict_value = state_dicts[i][key].to(result[key])
            result[key] += state_dict_value * normalized_weights[i]

    return result

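# Illustrative usage (hypothetical checkpoints): weights are normalized, so
# [0.7, 0.3] and [7, 3] produce the same merge.
#   merged = state_dict_weighted_merge([sd_a, sd_b], [0.7, 0.3])
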
def group_files_by_folder(all_files):
    grouped_files = {}

    for file in all_files:
        folder_name = os.path.basename(os.path.dirname(file))
        if folder_name not in grouped_files:
            grouped_files[folder_name] = []
        grouped_files[folder_name].append(file)

    list_of_lists = list(grouped_files.values())
    return list_of_lists

def generate_timestamp():
    now = datetime.datetime.now()
    timestamp = now.strftime('%y%m%d_%H%M%S')
    milliseconds = f"{int(now.microsecond / 1000):03d}"
    random_number = random.randint(0, 9999)
    return f"{timestamp}_{milliseconds}_{random_number}"

def write_PIL_image_with_png_info(image, metadata, path):
    from PIL.PngImagePlugin import PngInfo

    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)

    image.save(path, "PNG", pnginfo=png_info)
    return image

def torch_safe_save(content, path):
    # Same atomic-write pattern as write_to_json: write a temp file, then replace.
    torch.save(content, path + '_tmp')
    os.replace(path + '_tmp', path)
    return path

def move_optimizer_to_device(optimizer, device):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

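# Illustrative usage (hypothetical training loop): offload optimizer state
# (e.g. Adam moments) to CPU to free VRAM before running inference.
#   optimizer = torch.optim.AdamW(model.parameters())
#   move_optimizer_to_device(optimizer, 'cpu')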