import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset
import torch.nn.functional as F
def Anndata_to_Tensor(adata, label=None, label_continuous=None, batch=None, device='cpu'):
    """Convert an AnnData object into tensors: the expression matrix plus an
    optional categorical label, continuous label, and one-hot batch covariate."""
    # Densify a sparse expression matrix before building the tensor.
    if sp.issparse(adata.X):
        X_tensor = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
    else:
        X_tensor = torch.tensor(adata.X, dtype=torch.float32).to(device)
    tensors = {'X_tensor': X_tensor}
    if label is not None:
        # Encode the categorical label column as integer class indices.
        labels_num, _ = pd.factorize(adata.obs[label], sort=True)
        tensors['labels_num'] = torch.tensor(labels_num, dtype=torch.long)
    if label_continuous is not None:
        tensors['label_continuous'] = torch.tensor(
            adata.obs[label_continuous].to_numpy(), dtype=torch.float64
        )
    if batch is not None:
        # One-hot encode the batch column; cast to float32 so the result can be
        # fed to the model directly (pd.get_dummies may return bool/uint8).
        batch_one_hot = pd.get_dummies(adata.obs[batch]).to_numpy().astype(np.float32)
        tensors['batch_one_hot'] = torch.from_numpy(batch_one_hot)
    if len(tensors) == 1:
        # Only the expression matrix was requested.
        return tensors['X_tensor']
    # Bundle all requested tensors into a TensorDataset.
    return TensorDataset(*tensors.values())
def loss_function(x_hat, x, mu, logvar, β=0.1):
    # Reconstruction term: mean squared error, summed over the batch.
    recon = F.mse_loss(x_hat, x.view(-1, x_hat.shape[1]), reduction='sum')
    # KL divergence between q(z|x) = N(mu, exp(logvar)) and the standard normal prior.
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    return recon + β * KLD
class DAWO(nn.Module):
    """Variational autoencoder with auxiliary covariate encoders (Y, Z), a decoder
    conditioned on all three embeddings, and a classifier head on the (X, Z) embedding."""

    def __init__(self, input_dim_X, input_dim_Y, input_dim_Z, latent_dim_mid=500,
                 latent_dim=50, Y_emb=50, Z_emb=50, num_classes=10):
        super(DAWO, self).__init__()
        # Encoder for the expression matrix X; outputs mu and logvar
        # concatenated along the feature dimension.
        self.encoder = nn.Sequential(
            nn.BatchNorm1d(input_dim_X),
            nn.Linear(input_dim_X, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, latent_dim * 2),
        )
        # Deterministic encoder for the Y covariate.
        self.encoder_Y = nn.Sequential(
            nn.BatchNorm1d(input_dim_Y),
            nn.Linear(input_dim_Y, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, Y_emb),
        )
        # Deterministic encoder for the Z covariate.
        self.encoder_Z = nn.Sequential(
            nn.BatchNorm1d(input_dim_Z),
            nn.Linear(input_dim_Z, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, Z_emb),
        )
        # Decoder reconstructs X from the concatenated (X, Y, Z) embeddings.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + Y_emb + Z_emb, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, input_dim_X),
        )
        # Classifier head operates on the (X, Z) embeddings only.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(latent_dim + Z_emb),
            nn.Linear(latent_dim + Z_emb, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )
        self.input_dim = input_dim_X
        self.input_dim_Y = input_dim_Y
        self.input_dim_Z = input_dim_Z
        self.latent_dim = latent_dim

    def reparameterise(self, mu, logvar):
        # Standard VAE reparameterisation trick; at evaluation time the
        # posterior mean is used directly.
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, y, z):
        # Split the encoder output into mu and logvar.
        mu_logvar = self.encoder(x.view(-1, self.input_dim)).view(-1, 2, self.latent_dim)
        l_y = self.encoder_Y(y.view(-1, self.input_dim_Y))
        l_z = self.encoder_Z(z.view(-1, self.input_dim_Z))
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        l_x = self.reparameterise(mu, logvar)
        # The decoder sees the X, Y and Z embeddings; the classifier sees only X and Z.
        l_xyz = torch.cat((l_x, l_y, l_z), dim=1)
        l_xz = torch.cat((l_x, l_z), dim=1)
        x_hat = self.decoder(l_xyz)
        y_pred = self.classifier(l_xz)
        return x_hat, mu, logvar, y_pred
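

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wires the pieces
# above together on synthetic data. The AnnData construction, column names
# ('cell_type', 'batch'), and the single optimisation step are illustrative
# assumptions, not the authors' training procedure.
if __name__ == '__main__':
    import anndata as ad

    n_cells, n_genes = 64, 200
    adata = ad.AnnData(
        X=np.random.rand(n_cells, n_genes).astype(np.float32),
        obs=pd.DataFrame({
            'cell_type': np.random.choice(['A', 'B', 'C'], n_cells),
            'batch': np.random.choice(['b1', 'b2'], n_cells),
        }),
    )

    # Build tensors: expression matrix, integer labels, one-hot batch covariate.
    dataset = Anndata_to_Tensor(adata, label='cell_type', batch='batch')
    x, labels, batch_one_hot = dataset.tensors

    # Here the one-hot batch covariate is reused for both Y and Z inputs,
    # purely to demonstrate the forward signature.
    model = DAWO(
        input_dim_X=n_genes,
        input_dim_Y=batch_one_hot.shape[1],
        input_dim_Z=batch_one_hot.shape[1],
        num_classes=3,
    )
    optimiser = optim.Adam(model.parameters(), lr=1e-3)

    # One illustrative step: reconstruction + KL (via loss_function) plus
    # a cross-entropy term for the classifier head.
    model.train()
    optimiser.zero_grad()
    x_hat, mu, logvar, y_pred = model(x, batch_one_hot.float(), batch_one_hot.float())
    loss = loss_function(x_hat, x, mu, logvar) + F.cross_entropy(y_pred, labels)
    loss.backward()
    optimiser.step()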