Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import argparse | |
| import os | |
| from PIL import Image | |
| import torch | |
| from torchvision.transforms import Resize, ToTensor | |
| from diffusers import AutoencoderKL | |
| from pytorch_fid import fid_score | |
| from skimage.metrics import peak_signal_noise_ratio as psnr | |
| import lpips | |
| from tqdm import tqdm | |
| from torchvision import transforms | |
# Module-wide default device: use the GPU when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_images(folder_path):
    """Collect the paths of the image files located directly in ``folder_path``.

    Only regular files whose name ends in ``.png``, ``.jpg`` or ``.jpeg``
    (case-insensitive) are returned; sub-folders are not searched.

    Args:
        folder_path: Directory to scan.

    Returns:
        A list of ``os.path.join(folder_path, name)`` paths, sorted by file
        name. The list is empty when the folder holds no matching files.

    Raises:
        OSError: If ``folder_path`` cannot be listed (propagated from
            ``os.listdir``).
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    images = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # selection stable between runs when the caller truncates the list.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(image_extensions):
            continue
        img_path = os.path.join(folder_path, filename)
        # A directory may carry an image-like name; only real files qualify.
        if os.path.isfile(img_path):
            images.append(img_path)
    return images
def paramiter_count(model):
    """Return the total number of elements over all tensors in the model's state dict.

    Every entry of ``model.state_dict()`` contributes its element count, so
    anything stored there is included in the total.

    Args:
        model: Object exposing ``state_dict()`` that maps names to tensors.

    Returns:
        The summed element count as an ``int``.
    """
    total_elements = sum(torch.numel(tensor) for tensor in model.state_dict().values())
    return int(total_elements)
def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
    """Reconstruct images through a VAE and average the PSNR and LPIPS scores.

    Each image is loaded as RGB, rescaled to [-1, 1], cropped so that height
    and width are multiples of 8, encoded and decoded by ``vae``, and compared
    with its reconstruction. The latent is drawn with
    ``latent_dist.sample()``, so scores can vary slightly between runs.

    Args:
        vae: Autoencoder exposing ``to``, ``encode`` and ``decode`` as used
            below.
        images: Iterable of image file paths. Paths ending in
            ``_reconstructed.png`` are ignored so that saved outputs of an
            earlier run are not evaluated as inputs.
        max_imgs: When positive, only the first ``max_imgs`` paths are used.
        save_output: When true, each reconstruction is written next to its
            source image as ``<name>_reconstructed.png``.

    Returns:
        A tuple ``(avg_rfid, avg_psnr, avg_lpips)``. rFID is not implemented;
        ``avg_rfid`` is always the placeholder ``0``.

    Raises:
        ValueError: If there is nothing to evaluate, or if every image failed
            to load.
    """
    # Drop reconstructions saved by a previous run with save_output enabled.
    images = [img for img in images if not img.endswith("_reconstructed.png")]
    if max_imgs > 0 and len(images) > max_imgs:
        images = images[:max_imgs]
    # Fail early with a clear message, before any model is moved or loaded.
    if not images:
        raise ValueError("No images to evaluate")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    lpips_model = lpips.LPIPS(net='alex').to(device)
    to_tensor = ToTensor()

    psnr_scores = []
    lpips_scores = []

    for img_path in tqdm(images):
        try:
            img = Image.open(img_path).convert('RGB')
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            # The models need values between -1 and 1.
            img_tensor = 2 * img_tensor - 1
            # Crop height and width down to the nearest multiple of 8.
            height = img_tensor.shape[2] // 8 * 8
            width = img_tensor.shape[3] // 8 * 8
            img_tensor = img_tensor[:, :, :height, :width]
        except Exception as e:
            # Best effort: report the unreadable image and move on.
            print(f"Error processing {img_path}: {e}")
            continue

        with torch.no_grad():
            reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample

        # Calculate PSNR. The data range is given explicitly (inputs span
        # [-1, 1], i.e. a range of 2) so every image is scored on the same
        # scale instead of a range guessed from the image content.
        psnr_val = psnr(
            img_tensor.cpu().numpy(),
            reconstructed.cpu().numpy(),
            data_range=2.0,
        )
        psnr_scores.append(psnr_val)

        # Calculate LPIPS
        lpips_val = lpips_model(img_tensor, reconstructed).item()
        lpips_scores.append(lpips_val)

        if save_output:
            filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            folder = os.path.dirname(img_path)
            save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
            # Map back from [-1, 1] to [0, 1] before converting to an image.
            output = ((reconstructed + 1) / 2).clamp(0, 1)
            transforms.ToPILImage()(output[0].cpu()).save(save_path)

    if not psnr_scores:
        raise ValueError("None of the images could be processed")

    avg_rfid = 0  # placeholder: rFID is not computed
    avg_psnr = sum(psnr_scores) / len(psnr_scores)
    avg_lpips = sum(lpips_scores) / len(lpips_scores)
    return avg_rfid, avg_psnr, avg_lpips
def main():
    """Command-line entry point: load a VAE and report its reconstruction scores.

    Parses ``--vae_path``, ``--image_folder``, ``--max_imgs`` and
    ``--save_output``, loads the VAE, prints its parameter count, then prints
    the average PSNR and LPIPS over the images in the folder. Exits with an
    argparse error if the image folder is missing or contains no images.
    """
    parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
    # boolean store true
    parser.add_argument("--save_output", action="store_true", help="Save the output images")
    args = parser.parse_args()

    # Validate the cheap inputs first so a bad folder fails before the
    # (potentially slow) model load.
    if not os.path.isdir(args.image_folder):
        parser.error(f"Image folder not found: {args.image_folder}")
    images = load_images(args.image_folder)
    if not images:
        parser.error(f"No images found in: {args.image_folder}")

    if os.path.isfile(args.vae_path):
        vae = AutoencoderKL.from_single_file(args.vae_path)
    else:
        try:
            vae = AutoencoderKL.from_pretrained(args.vae_path)
        except Exception:
            # Retry with the weights in a "vae" subfolder. A bare except
            # would also swallow KeyboardInterrupt and SystemExit.
            vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.eval()
    vae = vae.to(device)
    print(f"Model has {paramiter_count(vae)} parameters")

    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
    # avg_rfid is a placeholder and is intentionally not printed.
    print(f"Average PSNR: {avg_psnr}")
    print(f"Average LPIPS: {avg_lpips}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
