Sheng-Yong Niu committed on
Commit b33b5c3 · verified · 1 Parent(s): 1c3639c

Upload 7 files

Files changed (7)
  1. Dawo_model.ipynb +0 -0
  2. README.md +153 -0
  3. config.json +11 -0
  4. dawo.py +122 -0
  5. dawo_wrapper.py +60 -0
  6. example.py +99 -0
  7. requirements.txt +5 -0
Dawo_model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
README.md ADDED
@@ -0,0 +1,153 @@
---
language: code
license: mit
library_name: pytorch
tags:
- variational-autoencoder
- drug-response
- vae
- cancer-drug
- tahoe-deepdive
datasets:
- biomedical
- tahoebio/Tahoe-100M
---

# DAWO: Drug-Aware and Cell-line-Aware Variational Autoencoder

[![tahoe-deepdive](https://img.shields.io/badge/tag-tahoe--deepdive-blue)](https://huggingface.co/datasets/tahoebio/Tahoe-100M)

## Team Name
DAWO

## Members
- Yuhan Hao
- Sheng-Yong Niu
- Jaanak Prashar
- Tiange (Alex) Cui
- Danila Bredikhin
- Mikaela Koutrouli

## Project

### Title
DAWO: Drug-Aware and Cell-line-Aware Variational Autoencoder for Drug Response Prediction

### Overview
DAWO is a specialized Variational Autoencoder (VAE) designed to predict drug responses in cancer cell lines by integrating gene expression data with drug and cell line features. The model leverages multi-modal representation learning to capture complex interactions between drugs and cells, enabling more accurate prediction of drug responses across diverse conditions.

### Motivation
Understanding and predicting how cancer cells respond to different therapeutic compounds is crucial for advancing precision medicine in oncology. Traditional methods often fail to capture the complex relationships between drugs, cell lines, and their molecular profiles. DAWO addresses this challenge by combining a VAE architecture with drug-aware and cell-line-aware components that model these interactions directly.

### Methods
DAWO incorporates a multi-modal architecture with the following key components:

1. **Gene Expression Encoder**: Processes normalized gene expression data from cancer cell lines (input dimension: 5000)
2. **Drug Feature Encoder**: Processes drug features combining:
   - Drug summary embeddings
   - ChemBERTa molecular structure embeddings
   - Semantic feature embeddings
   (Total input dimension: 3122)
3. **Cell Line Feature Encoder**: Processes cell line features focusing on driver gene mutations and other genomic characteristics (input dimension: 113)
4. **Latent Space**: A 50-dimensional latent representation combining drug, cell line, and gene expression information
5. **Decoder**: Reconstructs gene expression profiles from the latent representation
6. **Classifier**: Predicts drug response categories from the latent representation (379 classes)

The model was trained using a combined loss function that balances reconstruction accuracy, latent space regularization, and classification performance.
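For concreteness, here is a minimal sketch of that combined objective. The reconstruction and KL terms follow `loss_function` in `dawo.py` (summed MSE plus β-weighted KL, with β = 0.1 from `config.json`); the cross-entropy classification term and its weight `lambda_cls` are illustrative assumptions, since the exact weighting used during training is not included in this upload.

```python
import torch
import torch.nn.functional as F

def combined_loss(x_hat, x, mu, logvar, y_logits, y_true, beta=0.1, lambda_cls=1.0):
    # Reconstruction term: summed MSE between input and reconstructed expression
    recon = F.mse_loss(x_hat, x, reduction="sum")
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior
    kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
    # Classification term (assumed cross-entropy; lambda_cls is a hypothetical weight)
    cls = F.cross_entropy(y_logits, y_true, reduction="sum")
    return recon + beta * kld + lambda_cls * cls
```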
### Results
DAWO demonstrates strong performance in predicting drug responses across multiple cancer cell lines, with particular strength in:

1. Distinguishing between responsive and non-responsive cell lines for specific drugs
2. Generalizing to drug-cell line combinations not seen during training
3. Capturing meaningful biological signals in the latent space that reflect known drug mechanisms and cellular pathways

### Discussion
Our model provides a framework for drug response prediction that could accelerate drug discovery and repurposing efforts. The integration of multi-modal data (gene expression, drug features, cell line characteristics) enables DAWO to capture complex interaction patterns that simpler models miss.

Limitations include the need for comprehensive feature sets for new drugs and cell lines, and potential biases from the training data distribution. Future work will focus on incorporating additional molecular modalities and expanding the training data to improve generalization across diverse drug classes.

## Model Description
Using a variational autoencoder (VAE) approach, DAWO learns latent representations of gene expression, drug, and cell line data and combines them to predict drug responses and identify potential drug-cell line interactions.

## Model Inputs and Outputs

### Inputs
- **Gene Expression Data**: Normalized gene expression profiles (shape: [batch_size, 5000])
- **Drug Features**: Combined drug embeddings, concatenating:
  - Drug summary embeddings
  - ChemBERTa molecular structure embeddings
  - Semantic feature embeddings
  (Total shape: [batch_size, 3122])
- **Cell Line Features**: Cell line genomic profiles (shape: [batch_size, 113])

### Outputs
- **Reconstructed Gene Expression**: Reconstructed expression profiles (shape: [batch_size, 5000])
- **Latent Representation**: Compressed representation in latent space (shape: [batch_size, 50])
- **Drug Response Predictions**: Unnormalized class scores (logits) over the response classes (shape: [batch_size, 379])
- **Response Probabilities**: Softmax probabilities for each response class (shape: [batch_size, 379])

## How to Use
```python
import torch

from dawo_wrapper import DAWOWrapper

# Initialize the model (repo_path must contain config.json and model.pth)
model = DAWOWrapper(repo_path="path/to/model")

# Prepare inputs; random tensors stand in for real data here
gene_expression = torch.randn(1, 5000)  # normalized expression profile
drug_features = torch.randn(1, 3122)    # combined drug embedding
cell_features = torch.randn(1, 113)     # cell line genomic profile

# Make predictions
results = model.predict(gene_expression, drug_features, cell_features)

# Access outputs
reconstructed_expression = results["x_hat"]
latent_representation = results["mu"]
drug_response_predictions = results["y_pred"]
response_probabilities = results["probs"]
```
## Dataset
This model was developed using the [Tahoe-100M](https://huggingface.co/datasets/tahoebio/Tahoe-100M) dataset as part of the Tahoe-DeepDive Hackathon 2025.

## License
MIT License

Copyright (c) 2023 Team DAWO

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
config.json ADDED
@@ -0,0 +1,11 @@
{
  "input_dim_X": 5000,
  "input_dim_Y": 3122,
  "input_dim_Z": 113,
  "latent_dim": 50,
  "latent_dim_mid": 500,
  "Y_emb": 50,
  "Z_emb": 50,
  "num_classes": 379,
  "beta": 0.1
}
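These fields map directly onto the `DAWO` constructor arguments in `dawo.py`, except `beta`, which weights the KL term in the training loss. A minimal loading sketch, assuming `config.json` sits in the working directory (this mirrors what `dawo_wrapper.py` does):

```python
import json
from dawo import DAWO

with open("config.json") as f:
    cfg = json.load(f)

# "beta" configures the loss, not the model, so it is not passed here
model = DAWO(
    input_dim_X=cfg["input_dim_X"],   # gene expression dimension
    input_dim_Y=cfg["input_dim_Y"],   # combined drug feature dimension
    input_dim_Z=cfg["input_dim_Z"],   # cell line feature dimension
    latent_dim_mid=cfg["latent_dim_mid"],
    latent_dim=cfg["latent_dim"],
    Y_emb=cfg["Y_emb"],
    Z_emb=cfg["Z_emb"],
    num_classes=cfg["num_classes"],
)
```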
dawo.py ADDED
@@ -0,0 +1,122 @@
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset
import torch.nn.functional as F


def Anndata_to_Tensor(adata, label=None, label_continuous=None, batch=None, device='cpu'):
    # Convert a sparse expression matrix to a dense tensor
    if isinstance(adata.X, (sp.csr_matrix, sp.csc_matrix)):
        X_tensor = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
    else:
        X_tensor = torch.tensor(adata.X, dtype=torch.float32).to(device)

    tensors = {'X_tensor': X_tensor}

    if label is not None:
        # Encode categorical labels as integer class indices
        labels_num, _ = pd.factorize(adata.obs[label], sort=True)
        tensors['labels_num'] = torch.tensor(labels_num, dtype=torch.long)

    if label_continuous is not None:
        tensors['label_continuous'] = torch.tensor(adata.obs[label_continuous], dtype=torch.float64)

    if batch is not None:
        # One-hot encode the batch covariate
        batch_one_hot = pd.get_dummies(adata.obs[batch]).to_numpy()
        tensors['batch_one_hot'] = torch.from_numpy(batch_one_hot)

    if len(tensors) == 1 and 'X_tensor' in tensors:
        return tensors['X_tensor']
    else:
        # Return a TensorDataset with the available tensors
        return TensorDataset(*tensors.values())


def loss_function(x_hat, x, mu, logvar, β=0.1):
    # Reconstruction loss (summed MSE, despite the BCE name)
    BCE = nn.functional.mse_loss(
        x_hat, x.view(-1, x_hat.shape[1]), reduction='sum'
    )
    # KL divergence between N(mu, sigma^2) and the standard normal prior
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))

    return BCE + β * KLD


class DAWO(nn.Module):
    def __init__(self, input_dim_X, input_dim_Y, input_dim_Z, latent_dim_mid=500, latent_dim=50, Y_emb=50, Z_emb=50, num_classes=10):
        super(DAWO, self).__init__()

        # Gene expression encoder; outputs mu and logvar (hence latent_dim * 2)
        self.encoder = nn.Sequential(
            nn.BatchNorm1d(input_dim_X),
            nn.Linear(input_dim_X, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, latent_dim * 2),
        )

        # Drug feature encoder
        self.encoder_Y = nn.Sequential(
            nn.BatchNorm1d(input_dim_Y),
            nn.Linear(input_dim_Y, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, Y_emb),
        )

        # Cell line feature encoder
        self.encoder_Z = nn.Sequential(
            nn.BatchNorm1d(input_dim_Z),
            nn.Linear(input_dim_Z, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, Z_emb),
        )

        # Decoder reconstructs expression from the concatenated (x, y, z) embedding
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + Y_emb + Z_emb, latent_dim_mid),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(latent_dim_mid, input_dim_X),
        )

        # Classifier predicts drug response from the (x, z) embedding
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(latent_dim + Z_emb),
            nn.Linear(latent_dim + Z_emb, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        self.input_dim = input_dim_X
        self.input_dim_Y = input_dim_Y
        self.input_dim_Z = input_dim_Z
        self.latent_dim = latent_dim

    def reparameterise(self, mu, logvar):
        # Reparameterisation trick: sample z = mu + sigma * eps during training
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            # Use the posterior mean at inference time
            return mu

    def forward(self, x, y, z):
        # Split the expression encoder output into mu and logvar
        mu_logvar = self.encoder(x.view(-1, self.input_dim)).view(-1, 2, self.latent_dim)
        l_y = self.encoder_Y(y.view(-1, self.input_dim_Y))
        l_z = self.encoder_Z(z.view(-1, self.input_dim_Z))

        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        l_x = self.reparameterise(mu, logvar)

        # Decoder sees all three embeddings; classifier sees expression and cell line only
        l_xyz = torch.cat((l_x, l_y, l_z), dim=1)
        l_xz = torch.cat((l_x, l_z), dim=1)

        x_hat = self.decoder(l_xyz)
        y_pred = self.classifier(l_xz)

        return x_hat, mu, logvar, y_pred
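As a quick sanity check, the following sketch instantiates `DAWO` with the dimensions from `config.json` and pushes random tensors through it; the expected output shapes follow directly from the architecture above:

```python
import torch
from dawo import DAWO

model = DAWO(input_dim_X=5000, input_dim_Y=3122, input_dim_Z=113, num_classes=379)
model.eval()  # reparameterise() returns mu outside training mode

x = torch.randn(2, 5000)  # gene expression
y = torch.randn(2, 3122)  # combined drug features
z = torch.randn(2, 113)   # cell line features

x_hat, mu, logvar, y_pred = model(x, y, z)
print(x_hat.shape, mu.shape, logvar.shape, y_pred.shape)
# expected: torch.Size([2, 5000]) torch.Size([2, 50]) torch.Size([2, 50]) torch.Size([2, 379])
```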
dawo_wrapper.py ADDED
@@ -0,0 +1,60 @@
import os
import json
import torch
import numpy as np
from dawo import DAWO, loss_function, Anndata_to_Tensor


class DAWOWrapper:
    """
    Minimal wrapper around the DAWO model for use with the Hugging Face Hub
    """
    def __init__(self, repo_path):
        """
        Initialize the DAWO model

        Args:
            repo_path: Path to the repository with the model files
        """
        # Load configuration
        config_path = os.path.join(repo_path, "config.json")
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Create the model with the original DAWO class
        self.model = DAWO(
            input_dim_X=config["input_dim_X"],
            input_dim_Y=config["input_dim_Y"],
            input_dim_Z=config["input_dim_Z"],
            latent_dim_mid=config["latent_dim_mid"],
            latent_dim=config["latent_dim"],
            Y_emb=config["Y_emb"],
            Z_emb=config["Z_emb"],
            num_classes=config["num_classes"]
        )

        # Load weights (map to CPU so a GPU-trained checkpoint loads anywhere)
        self.model.load_state_dict(
            torch.load(os.path.join(repo_path, "model.pth"), map_location="cpu")
        )
        self.model.eval()

    def predict(self, x, y, z):
        """
        Make predictions with the DAWO model

        Args:
            x: Gene expression tensor (batch_size, input_dim_X)
            y: Drug feature tensor (batch_size, input_dim_Y)
            z: Cell line feature tensor (batch_size, input_dim_Z)

        Returns:
            Dict with model outputs
        """
        with torch.no_grad():
            x_hat, mu, logvar, y_pred = self.model(x, y, z)

        return {
            "x_hat": x_hat,    # Reconstructed gene expression
            "mu": mu,          # Latent mean
            "logvar": logvar,  # Latent log variance
            "y_pred": y_pred,  # Drug response logits
            "probs": torch.softmax(y_pred, dim=1)  # Drug response probabilities
        }
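A minimal usage sketch, assuming `config.json` and `model.pth` are present in the current directory; note that `predict` also exposes `logvar`, which the README example does not use:

```python
import torch
from dawo_wrapper import DAWOWrapper

wrapper = DAWOWrapper(repo_path="./")
out = wrapper.predict(torch.randn(2, 5000), torch.randn(2, 3122), torch.randn(2, 113))
print(out["probs"].shape)   # torch.Size([2, 379])
print(out["logvar"].shape)  # torch.Size([2, 50])
```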
example.py ADDED
@@ -0,0 +1,99 @@
import torch
import numpy as np
import pandas as pd
import json

from dawo_wrapper import DAWOWrapper

print("DAWO Model Example: Drug Response Prediction")
print("============================================")

# Initialize the model
print("\n1. Loading the DAWO model...")
model = DAWOWrapper(repo_path="./")

# Load data files from the data folder
print("\n2. Loading drug and cell line features...")

# Set data directory (use local data directory)
data_dir = "./data"

# Drug feature components
print("   - Loading drug semantic features...")
drug_semantic = pd.read_csv(f'{data_dir}/semantic_features_combined.csv', index_col='drug')
print(f"     Shape: {drug_semantic.shape}")

print("   - Loading drug structure embeddings...")
drug_structure = pd.read_csv(f'{data_dir}/chemberta_cls_embeddings.csv', index_col='drug')
print(f"     Shape: {drug_structure.shape}")

print("   - Loading drug summary embeddings...")
with open(f'{data_dir}/drug_summaries.json', 'r') as f:
    drug_name = json.load(f)
drug_emb = np.load(f'{data_dir}/drug_summary_lowd.npy')
print(f"     Shape: {drug_emb.shape}")

# Cell line features
print("   - Loading cell line driver gene mutation profiles...")
cell_features = pd.read_parquet(f'{data_dir}/drivergene_cellline_matrix.parquet')
cell_features.index = cell_features['cell_name']
cell_features.drop(columns=['cell_name'], inplace=True)
print(f"     Shape: {cell_features.shape}")

# Select a sample drug and cell line
print("\n3. Preparing inputs for prediction:")

# Select a drug for demonstration
sample_drug = list(drug_name.keys())[0]
print(f"   - Selected drug: {sample_drug}")

# Create the complete drug feature vector by concatenating the three embedding types
print("   - Constructing drug feature vector...")
drug_idx = list(drug_name.keys()).index(sample_drug)
drug_feature = np.concatenate((
    drug_emb[drug_idx],                      # Drug summary embedding
    drug_structure.loc[sample_drug].values,  # Molecular structure embedding
    drug_semantic.loc[sample_drug].values    # Semantic feature embedding
))
drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
print(f"     Combined drug feature shape: {drug_features.shape}")

# Select a cell line for demonstration
sample_cell = cell_features.index[0]
print(f"   - Selected cell line: {sample_cell}")

# Create the cell line feature vector
print("   - Constructing cell line feature vector...")
cell_feature = cell_features.loc[sample_cell].values
cell_features_tensor = torch.tensor(cell_feature, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
print(f"     Original cell feature shape: {cell_features_tensor.shape}")

# Pad cell features to match the expected dimension (113)
print("   - Padding cell features to match model dimensions...")
padded_features = torch.zeros((1, 113), dtype=torch.float32)
padded_features[0, :cell_features_tensor.shape[1]] = cell_features_tensor
cell_features_tensor = padded_features
print(f"     Padded cell feature shape: {cell_features_tensor.shape}")

# Create simulated gene expression data (normally this would be real data)
print("   - Creating sample gene expression data...")
gene_expression = torch.randn(1, 5000)  # 1 sample, 5000 genes
print(f"     Gene expression shape: {gene_expression.shape}")

# Run prediction with the prepared data
print("\n4. Running prediction with the DAWO model...")
results = model.predict(gene_expression, drug_features, cell_features_tensor)

# Print results
print("\n5. Results:")
print(f"   - Reconstructed gene expression shape: {results['x_hat'].shape}")
print(f"   - Latent representation shape: {results['mu'].shape}")
print(f"   - Drug response prediction shape: {results['y_pred'].shape}")
print(f"   - Response probabilities shape: {results['probs'].shape}")

# Show the top predicted classes
print("\n   - Top predicted drug response classes:")
probs = results['probs'].squeeze().numpy()
top3_indices = np.argsort(probs)[-3:][::-1]
for i, idx in enumerate(top3_indices):
    print(f"     Class {idx}: {probs[idx]:.4f} probability")
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch>=1.10.0
numpy>=1.20.0
pandas>=1.3.0
scipy>=1.7.0
pyarrow>=7.0.0