|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import json |
|
import scanpy as sc |
|
|
|
from dawo_wrapper import DAWOWrapper |
|
from dawo import Anndata_to_Tensor |
|
|
|
# Header printed once at the start of the example run.
for banner_line in (
    "DAWO Model Example: Drug Response Prediction",
    "============================================",
):
    print(banner_line)

# Step 1: build the model wrapper, pointing it at the current directory as repo root.
print("\n1. Loading the DAWO model...")
model = DAWOWrapper(repo_path="./")
|
|
|
|
|
# Step 2: load every drug- and cell-line-level feature table from data_dir.
print("\n2. Loading drug and cell line features...")

# Directory holding all input files used by this example.
data_dir = "./data"

# Drug semantic features, one row per drug, indexed by the 'drug' column.
print(" - Loading drug semantic features...")
drug_semantic = pd.read_csv(f'{data_dir}/semantic_features_combined.csv', index_col='drug')
print(f" Shape: {drug_semantic.shape}")

# Drug structure embeddings, one row per drug, indexed by the 'drug' column.
print(" - Loading drug structure embeddings...")
drug_structure = pd.read_csv(f'{data_dir}/chemberta_cls_embeddings.csv', index_col='drug')
print(f" Shape: {drug_structure.shape}")

print(" - Loading drug summary embeddings...")
# JSON is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's default locale encoding.
with open(f'{data_dir}/drug_summaries.json', 'r', encoding='utf-8') as f:
    # Only the key order of this mapping is used later, to locate a drug's row in drug_emb.
    drug_name = json.load(f)
# NOTE(review): rows of drug_emb are looked up by a drug's position in
# drug_name's key order — assumes both files were written in the same order.
drug_emb = np.load(f'{data_dir}/drug_summary_lowd.npy')
print(f" Shape: {drug_emb.shape}")

print(" - Loading cell line driver gene mutation profiles...")
# Index rows by cell line name. 'cell_name' is removed from the columns
# and becomes the (named) index.
cell_features = pd.read_parquet(
    f'{data_dir}/drivergene_cellline_matrix.parquet'
).set_index('cell_name')
print(f" Shape: {cell_features.shape}")
|
|
|
|
|
# Step 3: build the batched (batch size 1) drug and cell-line input tensors.
print("\n3. Preparing inputs for prediction:")

sample_drug = "Dabrafenib"
print(f" - Selected drug: {sample_drug}")

print(" - Constructing drug feature vector...")
# NOTE(review): assumes drug_emb rows follow the key order of the
# drug_summaries.json mapping — confirm the two files were generated together.
drug_idx = list(drug_name.keys()).index(sample_drug)
# Concatenate summary embedding, structure embedding and semantic features into one vector.
drug_feature = np.concatenate((
    drug_emb[drug_idx],
    drug_structure.loc[sample_drug].values,
    drug_semantic.loc[sample_drug].values,
))
drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0)
print(f" Combined drug feature shape: {drug_features.shape}")

sample_cell = "MIA PaCa-2"
print(f" - Selected cell line: {sample_cell}")

print(" - Constructing cell line feature vector...")
cell_feature = cell_features.loc[sample_cell].values
cell_features_tensor = torch.tensor(cell_feature, dtype=torch.float32).unsqueeze(0)
print(f" Original cell feature shape: {cell_features_tensor.shape}")

# Width of the cell-line input the model is fed; shorter vectors are zero-padded on the right.
CELL_FEATURE_DIM = 113

print(" - Padding cell features to match model dimensions...")
n_cell_features = cell_features_tensor.shape[1]
if n_cell_features > CELL_FEATURE_DIM:
    # Fail with a clear message instead of an opaque tensor-shape mismatch on assignment.
    raise ValueError(
        f"Cell line '{sample_cell}' has {n_cell_features} features, but the model "
        f"accepts at most {CELL_FEATURE_DIM}; check the driver-gene matrix."
    )
padded_features = torch.zeros((1, CELL_FEATURE_DIM), dtype=torch.float32)
padded_features[0, :n_cell_features] = cell_features_tensor[0]
cell_features_tensor = padded_features
print(f" Padded cell feature shape: {cell_features_tensor.shape}")
|
|
|
|
|
# Load the single-cell expression data for this drug / cell-line pair.
print("\n - Loading real gene expression data for Dabrafenib on MIA_PaCa-2 cell line...")
adata = sc.read_h5ad(f'{data_dir}/Dabrafenib.MIA_PaCa-2.h5ad')
print(f" AnnData object shape: {adata.shape}")

# Number of genes kept as model input (also the n_top_genes used below).
N_GENES = 5000

print(" - Preprocessing gene expression data...")
if 'highly_variable' not in adata.var:
    print(" Selecting top 5000 highly variable genes...")
    sc.pp.highly_variable_genes(adata, n_top_genes=N_GENES)
    adata = adata[:, adata.var.highly_variable]
else:
    print(" Using pre-identified highly variable genes...")
    # Apply the stored HVG flags before truncating. Slicing the first N_GENES
    # columns alone would select genes by file order, not by variability.
    # NOTE(review): assumes the stored flags mark at least as many genes as
    # the model's input width — confirm against the model configuration.
    adata = adata[:, adata.var['highly_variable'].to_numpy(dtype=bool)]
    if adata.shape[1] > N_GENES:
        print(" Subsetting to 5000 genes...")
        adata = adata[:, :N_GENES]

print(" - Converting gene expression to tensor...")
gene_expression = Anndata_to_Tensor(adata)
# Add a leading batch dimension when a single 1-D profile comes back.
if gene_expression.ndim == 1:
    gene_expression = gene_expression.unsqueeze(0)
print(f" Gene expression shape: {gene_expression.shape}")
|
|
|
|
|
# Step 4: run the model on the prepared gene-expression, drug and cell-line inputs.
print("\n4. Running prediction with DAWO model...")
results = model.predict(gene_expression, drug_features, cell_features_tensor)

# Step 5: report output shapes, then the top-3 response classes.
print("\n5. Results:")
print(f" - Reconstructed gene expression shape: {results['x_hat'].shape}")
print(f" - Latent representation shape: {results['mu'].shape}")
print(f" - Drug response prediction shape: {results['y_pred'].shape}")
print(f" - Response probabilities shape: {results['probs'].shape}")

print("\n - Top predicted drug response classes:")
# detach()/cpu() make the NumPy conversion safe even if the returned tensor
# tracks gradients or lives on an accelerator. atleast_1d keeps argsort and
# indexing valid if squeeze() collapses a single-class output to a scalar.
probs = np.atleast_1d(results['probs'].detach().cpu().squeeze().numpy())
# argsort is ascending: take the last three indices and reverse for descending order.
top3_indices = np.argsort(probs)[-3:][::-1]
for idx in top3_indices:
    print(f" Class {idx}: {probs[idx]:.4f} probability")