# dawo / example.py
# Uploaded by Sheng-Yong Niu — commit 70a5d5c ("Upload 7 files", verified)
import torch
import numpy as np
import pandas as pd
import json
import scanpy as sc
from dawo_wrapper import DAWOWrapper
from dawo import Anndata_to_Tensor
# Title banner: the title line followed by its underline, one per output line.
print(
    "DAWO Model Example: Drug Response Prediction",
    "============================================",
    sep="\n",
)
# Step 1: build the model wrapper from the repository root (the current directory).
print("\n1. Loading the DAWO model...")
model = DAWOWrapper(repo_path="./")
# Step 2: load every drug and cell-line feature table from the data folder.
print("\n2. Loading drug and cell line features...")
# Directory holding all input files used by this example.
data_dir = "./data"

# --- Drug feature components (each table is indexed by the 'drug' column) ---
print(" - Loading drug semantic features...")
drug_semantic = pd.read_csv(f'{data_dir}/semantic_features_combined.csv', index_col='drug')
print(f" Shape: {drug_semantic.shape}")

print(" - Loading drug structure embeddings...")
drug_structure = pd.read_csv(f'{data_dir}/chemberta_cls_embeddings.csv', index_col='drug')
print(f" Shape: {drug_structure.shape}")

print(" - Loading drug summary embeddings...")
# JSON text is UTF-8; say so explicitly so loading does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows).
with open(f'{data_dir}/drug_summaries.json', 'r', encoding='utf-8') as f:
    # Presumably a dict keyed by drug name (later code calls .keys() on it).
    drug_name = json.load(f)
# NOTE(review): later in this script a drug's row in drug_emb is chosen by the
# drug's position among drug_name's keys, which assumes the .npy rows follow
# the JSON key order — confirm against how these files were produced.
drug_emb = np.load(f'{data_dir}/drug_summary_lowd.npy')
print(f" Shape: {drug_emb.shape}")

# --- Cell line features ---
print(" - Loading cell line driver gene mutation profiles...")
# Index rows by cell-line name; set_index also removes 'cell_name' from the columns.
cell_features = pd.read_parquet(
    f'{data_dir}/drivergene_cellline_matrix.parquet'
).set_index('cell_name')
print(f" Shape: {cell_features.shape}")
# Step 3: assemble the model inputs, starting with the drug.
print("\n3. Preparing inputs for prediction:")
# Drug used for the demonstration.
sample_drug = "Dabrafenib"
print(f" - Selected drug: {sample_drug}")

# The drug vector is three pieces laid end to end:
# summary embedding, then molecular-structure embedding, then semantic features.
print(" - Constructing drug feature vector...")
summary_row = drug_emb[list(drug_name.keys()).index(sample_drug)]  # row = drug's position among the JSON keys
structure_row = drug_structure.loc[sample_drug].values
semantic_row = drug_semantic.loc[sample_drug].values
drug_feature = np.concatenate([summary_row, structure_row, semantic_row])
# Cast to float32 and prepend a batch axis of size 1.
drug_features = torch.tensor(drug_feature, dtype=torch.float32)[None, :]
print(f" Combined drug feature shape: {drug_features.shape}")
# Select the MIA PaCa-2 cell line
sample_cell = "MIA PaCa-2"
print(f" - Selected cell line: {sample_cell}")
# Create cell line feature vector (one row of the cell-line table)
print(" - Constructing cell line feature vector...")
cell_feature = cell_features.loc[sample_cell].values
cell_features_tensor = torch.tensor(cell_feature, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
print(f" Original cell feature shape: {cell_features_tensor.shape}")

# Zero-pad the cell features on the right to a fixed width.
# Presumably this is the cell-input size DAWOWrapper expects — confirm against the model config.
CELL_FEATURE_DIM = 113
print(" - Padding cell features to match model dimensions...")
n_cell = cell_features_tensor.shape[1]
if n_cell > CELL_FEATURE_DIM:
    # Padding can only widen; fail with a clear message instead of an opaque tensor-copy error.
    raise ValueError(
        f"Cell line {sample_cell!r} has {n_cell} features, more than the "
        f"{CELL_FEATURE_DIM} the model input is padded to."
    )
padded_features = torch.zeros((1, CELL_FEATURE_DIM), dtype=torch.float32)
padded_features[0, :n_cell] = cell_features_tensor[0]
cell_features_tensor = padded_features
print(f" Padded cell feature shape: {cell_features_tensor.shape}")
# Load gene expression data for the selected drug / cell-line pair.
print("\n - Loading real gene expression data for Dabrafenib on MIA_PaCa-2 cell line...")
adata = sc.read_h5ad(f'{data_dir}/Dabrafenib.MIA_PaCa-2.h5ad')
print(f" AnnData object shape: {adata.shape}")

# Preprocess gene expression down to a fixed number of genes.
print(" - Preprocessing gene expression data...")
# Number of genes kept for the model input (presumably the model's expression width — confirm).
N_GENES = 5000
if 'highly_variable' not in adata.var:
    # No HVG annotation in the file: let scanpy flag the top N_GENES and keep only those.
    print(f" Selecting top {N_GENES} highly variable genes...")
    sc.pp.highly_variable_genes(adata, n_top_genes=N_GENES)
    adata = adata[:, adata.var.highly_variable]
else:
    print(" Using pre-identified highly variable genes...")
    # NOTE(review): this branch does not filter on adata.var['highly_variable'];
    # it keeps the first N_GENES genes in file order. That is only correct if the
    # file is already restricted to HVGs in the order the model was trained on — confirm.
    if adata.shape[1] > N_GENES:
        print(f" Subsetting to {N_GENES} genes...")
        adata = adata[:, :N_GENES]

# Convert to tensor; ensure there is a leading batch axis.
print(" - Converting gene expression to tensor...")
gene_expression = Anndata_to_Tensor(adata)
if gene_expression.ndim == 1:
    gene_expression = gene_expression.unsqueeze(0)  # Add batch dimension
print(f" Gene expression shape: {gene_expression.shape}")
# Step 4: run the model on the assembled inputs.
print("\n4. Running prediction with DAWO model...")
results = model.predict(gene_expression, drug_features, cell_features_tensor)

# Step 5: report output shapes.
print("\n5. Results:")
print(f" - Reconstructed gene expression shape: {results['x_hat'].shape}")
print(f" - Latent representation shape: {results['mu'].shape}")
print(f" - Drug response prediction shape: {results['y_pred'].shape}")
print(f" - Response probabilities shape: {results['probs'].shape}")

# Show the three most probable response classes.
print("\n - Top predicted drug response classes:")
# detach()/cpu() so .numpy() also works if the wrapper returns grad-tracking or GPU tensors.
probs = results['probs'].detach().cpu().numpy()
# View as (n_samples, n_classes). When several samples were scored (e.g. one row per
# cell), rank classes by their mean probability instead of failing on a 2-D array.
probs = probs.reshape(-1, probs.shape[-1])
if probs.shape[0] > 1:
    print(f" (mean over {probs.shape[0]} samples)")
probs = probs.mean(axis=0)
top3_indices = np.argsort(probs)[::-1][:3]  # highest probability first
for idx in top3_indices:
    print(f" Class {idx}: {probs[idx]:.4f} probability")