import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from mel_module import Mel


class Generator:
    """Samples a mel spectrogram with classifier-free guidance and renders it to audio."""

    def __init__(self, config, unet, scheduler, vae, embedding, progress_callback=None):
        self.config = config
        self.unet = unet
        self.scheduler = scheduler
        self.vae = vae
        self.embedding = embedding
        self.progress_callback = progress_callback

    def tensor_to_mel(self, tensor):
        # Invert the mean=[0.5], std=[0.5] normalization applied during
        # training, then rescale to 0-255 for the Mel spectrogram.
        denormalize = transforms.Normalize(
            mean=[-m / s for m, s in zip([0.5], [0.5])],
            std=[1 / s for s in [0.5]],
        )
        dn_tensor = denormalize(tensor.detach().cpu())
        spectrogram = np.array(dn_tensor.squeeze()) * 255
        return Mel(spectrogram=spectrogram)

    def generate(self):
        # Build the unconditional embedding for classifier-free guidance:
        # encode an all-zeros image and stack the VAE's (mu, log_var) pair.
        with torch.no_grad():
            uncond_image = torch.zeros(
                (1, 1, self.config.image_size, self.config.image_size),
                device=self.config.device,
            )
            mu, log_var = self.vae.encode(uncond_image)
            uncond_latent = torch.cat((mu, log_var), dim=1).unsqueeze(0)
        print("uncond", uncond_latent.shape)

        # Batch the unconditional and conditional embeddings so the UNet
        # predicts both branches in a single forward pass.
        embeddings = torch.cat([uncond_latent, self.embedding])

        generator = torch.Generator(device=self.config.device)
        noise = torch.randn(
            (1, 1, self.config.image_size, self.config.image_size),
            generator=generator,
            device=self.config.device,
        )

        # Fall back to a plain tqdm bar when no progress callback was given;
        # calling self.progress_callback.tqdm unconditionally would crash
        # with the default progress_callback=None.
        if self.progress_callback is not None:
            timesteps = self.progress_callback.tqdm(self.scheduler.timesteps)
        else:
            timesteps = tqdm(self.scheduler.timesteps)

        for t in timesteps:
            # Duplicate the latent so one UNet pass yields both the
            # unconditional and conditional noise predictions.
            image_model_input = torch.cat([noise] * 2)
            image_model_input = self.scheduler.scale_model_input(image_model_input, timestep=t)

            with torch.no_grad():
                noise_pred = self.unet(image_model_input, t, encoder_hidden_states=embeddings).sample

            # Classifier-free guidance: move the prediction away from the
            # unconditional branch by guidance_scale.
            noise_pred_uncond, noise_pred_img = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.config.guidance_scale * (noise_pred_img - noise_pred_uncond)

            # One reverse-diffusion step toward t = 0.
            noise = self.scheduler.step(noise_pred, t, noise).prev_sample

        image_tensor = noise.squeeze(1)  # [1, 512, 512] when image_size == 512
        mel = self.tensor_to_mel(image_tensor)
        mel.save_audio()
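

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, commented out, not part of the original module):
# one plausible way to wire up Generator, assuming a diffusers-style
# UNet2DConditionModel and DDPMScheduler. The config fields mirror what
# generate() reads above; the checkpoint path, step count, and guidance value
# are assumptions, and `vae` / `embedding` must come from this project's own
# code (the VAE here is expected to return a (mu, log_var) pair from encode()).
# ---------------------------------------------------------------------------
# from types import SimpleNamespace
# from diffusers import UNet2DConditionModel, DDPMScheduler
#
# config = SimpleNamespace(
#     image_size=512,
#     device="cuda" if torch.cuda.is_available() else "cpu",
#     guidance_scale=7.5,
# )
# unet = UNet2DConditionModel.from_pretrained("path/to/unet").to(config.device)
# scheduler = DDPMScheduler()
# scheduler.set_timesteps(1000)
# vae = ...        # project-specific VAE; encode() must return (mu, log_var)
# embedding = ...  # conditioning latent, shape-compatible with uncond_latent
#
# Generator(config, unet, scheduler, vae, embedding).generate()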