Edit model card
# class-conditional-diffusion-cub-200

A class-conditional diffusion model trained on the CUB-200 (Caltech-UCSD Birds) dataset for generating bird images.

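## Usage

The helper functions below load the model, generate images by reverse diffusion, and display them. They expect `torch`, `torchvision`, `numpy` (as `np`), `matplotlib.pyplot` (as `plt`), `tqdm`, and the `ClassConditionedUnet` class from the training code to be in scope.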
```python

  def load_model(model_path, device):
      """Build a ClassConditionedUnet on ``device`` and load trained weights into it.

      Args:
          model_path: Path to a ``state_dict`` checkpoint saved with ``torch.save``.
          device: Target device, e.g. ``'cuda'`` or ``'cpu'``.

      Returns:
          The model with the trained weights loaded, in evaluation mode.
      """
      # Initialize the same model architecture as during training
      model = ClassConditionedUnet().to(device)

      # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
      # weights_only=True restricts unpickling to tensors and plain containers, which is
      # all a state_dict needs, and avoids running arbitrary code from a downloaded file.
      state_dict = torch.load(model_path, map_location=device, weights_only=True)
      model.load_state_dict(state_dict)

      # Set model to evaluation mode
      model.eval()

      return model
  
  
  def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda',
              num_inference_steps=None, image_size=256):
      """Generate ``num_samples`` images of class ``class_label`` by reverse diffusion.

      Args:
          model: Class-conditioned denoiser called as ``model(x, t, labels)``; its output
              is passed to the scheduler as the predicted noise.
          class_label: Integer class index used to condition every sample.
          noise_scheduler: Scheduler exposing ``timesteps`` and
              ``step(model_output, timestep, sample).prev_sample`` (e.g. diffusers'
              ``DDPMScheduler``).
          num_samples: Number of images generated in one batch.
          device: Device on which the noise, labels and timesteps are created.
          num_inference_steps: If given, ``noise_scheduler.set_timesteps`` is called with
              this value before sampling; otherwise the scheduler's current
              ``timesteps`` are used unchanged.
          image_size: Height and width of the generated (square) images.

      Returns:
          Tensor of shape ``(num_samples, 3, image_size, image_size)`` with values
          clamped to ``[0, 1]``.
      """
      model.eval()  # Ensure the model is in evaluation mode

      if num_inference_steps is not None:
          noise_scheduler.set_timesteps(num_inference_steps)

      # Start from pure Gaussian noise: (batch_size, channels, height, width)
      shape = (num_samples, 3, image_size, image_size)
      noisy_image = torch.randn(shape, device=device)

      # The same class label for every sample in the batch
      class_labels = torch.tensor([class_label] * num_samples, dtype=torch.long, device=device)

      # Iterate the scheduler's own timestep sequence (noisiest step first) so the loop
      # always matches its schedule; a fixed range(49, -1, -1) would cover only the final
      # 50 steps of a 1000-step DDPM schedule and leave the image mostly noise.
      with torch.no_grad():
          for t in tqdm(noise_scheduler.timesteps, desc="Reverse Diffusion Steps"):
              t_batch = torch.full((num_samples,), int(t), dtype=torch.long, device=device)

              # Predict noise with the model (class conditioning here)
              noise_pred = model(noisy_image, t_batch, class_labels)

              # Step with the scheduler (model_output, timestep, sample)
              noisy_image = noise_scheduler.step(noise_pred, t, noisy_image).prev_sample

      # Rescale from [-1, 1] to [0, 1]; clamp because the final sample can overshoot.
      generated_images = ((noisy_image + 1) / 2).clamp(0, 1)

      return generated_images
  
  
  def display_images(images, num_rows=2):
      """Show a batch of images as a single grid with ``num_rows`` rows.

      Args:
          images: Batch of images accepted by ``torchvision.utils.make_grid`` (e.g. the
              tensor returned by ``predict``), with values expected in ``[0, 1]``.
          num_rows: Number of rows in the displayed grid.
      """
      # make_grid's `nrow` is the number of images *per row*, so convert the requested
      # row count into a per-row count (ceiling division, at least one image per row).
      num_images = len(images)
      images_per_row = max(1, -(-num_images // max(1, num_rows)))

      # Create a grid of images
      grid = torchvision.utils.make_grid(images, nrow=images_per_row)
      np_grid = grid.permute(1, 2, 0).detach().cpu().numpy()  # (C, H, W) -> (H, W, C) for matplotlib

      # Plot the images
      plt.figure(figsize=(12, 6))
      plt.imshow(np.clip(np_grid, 0, 1))  # Clip values to ensure valid range
      plt.axis('off')
      plt.show()
```

Example: load the trained model and generate images for one class

```python
# Imports used by the helper functions above and by this example.
# ClassConditionedUnet must also be defined; it comes from the training code.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from diffusers import DDPMScheduler
from tqdm import tqdm

model_path = "model_epoch_0.pth"  # Path to your saved model checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model(model_path, device)

# The noise schedule should match the one used during training.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

class_label = 1  # Example class label, change to your desired class
generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
display_images(generated_images)
```
Downloads last month
0
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Dataset used to train muneebable/class-conditional-diffusion-cub-200