Sheng-Yong Niu
commited on
Upload 7 files
Browse files- example.py +30 -9
- requirements.txt +3 -1
example.py
CHANGED
@@ -2,8 +2,10 @@ import torch
|
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
import json
|
|
|
5 |
|
6 |
from dawo_wrapper import DAWOWrapper
|
|
|
7 |
|
8 |
print("DAWO Model Example: Drug Response Prediction")
|
9 |
print("============================================")
|
@@ -15,7 +17,7 @@ model = DAWOWrapper(repo_path="./")
|
|
15 |
# Load data files from the data folder
|
16 |
print("\n2. Loading drug and cell line features...")
|
17 |
|
18 |
-
# Set data directory
|
19 |
data_dir = "./data"
|
20 |
|
21 |
# Drug feature components
|
@@ -43,8 +45,8 @@ print(f" Shape: {cell_features.shape}")
|
|
43 |
# Select sample drug and cell line
|
44 |
print("\n3. Preparing inputs for prediction:")
|
45 |
|
46 |
-
# Select a drug for demonstration
|
47 |
-
sample_drug =
|
48 |
print(f" - Selected drug: {sample_drug}")
|
49 |
|
50 |
# Create complete drug feature vector by concatenating the three embedding types
|
@@ -58,8 +60,8 @@ drug_feature = np.concatenate((
|
|
58 |
drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0) # Add batch dimension
|
59 |
print(f" Combined drug feature shape: {drug_features.shape}")
|
60 |
|
61 |
-
# Select
|
62 |
-
sample_cell =
|
63 |
print(f" - Selected cell line: {sample_cell}")
|
64 |
|
65 |
# Create cell line feature vector
|
@@ -75,12 +77,31 @@ padded_features[0, :cell_features_tensor.shape[1]] = cell_features_tensor
|
|
75 |
cell_features_tensor = padded_features
|
76 |
print(f" Padded cell feature shape: {cell_features_tensor.shape}")
|
77 |
|
78 |
-
#
|
79 |
-
print(" -
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
print(f" Gene expression shape: {gene_expression.shape}")
|
82 |
|
83 |
-
# Run prediction
|
84 |
print("\n4. Running prediction with DAWO model...")
|
85 |
results = model.predict(gene_expression, drug_features, cell_features_tensor)
|
86 |
|
|
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
import json
|
5 |
+
import scanpy as sc
|
6 |
|
7 |
from dawo_wrapper import DAWOWrapper
|
8 |
+
from dawo import Anndata_to_Tensor
|
9 |
|
10 |
print("DAWO Model Example: Drug Response Prediction")
|
11 |
print("============================================")
|
|
|
17 |
# Load data files from the data folder
|
18 |
print("\n2. Loading drug and cell line features...")
|
19 |
|
20 |
+
# Set data directory
|
21 |
data_dir = "./data"
|
22 |
|
23 |
# Drug feature components
|
|
|
45 |
# Select sample drug and cell line
|
46 |
print("\n3. Preparing inputs for prediction:")
|
47 |
|
48 |
+
# Select a drug for demonstration - use Dabrafenib
|
49 |
+
sample_drug = "Dabrafenib"
|
50 |
print(f" - Selected drug: {sample_drug}")
|
51 |
|
52 |
# Create complete drug feature vector by concatenating the three embedding types
|
|
|
60 |
drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0) # Add batch dimension
|
61 |
print(f" Combined drug feature shape: {drug_features.shape}")
|
62 |
|
63 |
+
# Select the MIA PaCa-2 cell line
|
64 |
+
sample_cell = "MIA PaCa-2"
|
65 |
print(f" - Selected cell line: {sample_cell}")
|
66 |
|
67 |
# Create cell line feature vector
|
|
|
77 |
cell_features_tensor = padded_features
|
78 |
print(f" Padded cell feature shape: {cell_features_tensor.shape}")
|
79 |
|
80 |
+
# Load gene expression data
|
81 |
+
print("\n - Loading real gene expression data for Dabrafenib on MIA_PaCa-2 cell line...")
|
82 |
+
adata = sc.read_h5ad(f'{data_dir}/Dabrafenib.MIA_PaCa-2.h5ad')
|
83 |
+
print(f" AnnData object shape: {adata.shape}")
|
84 |
+
|
85 |
+
# Preprocess gene expression
|
86 |
+
print(" - Preprocessing gene expression data...")
|
87 |
+
if 'highly_variable' not in adata.var:
|
88 |
+
print(" Selecting top 5000 highly variable genes...")
|
89 |
+
sc.pp.highly_variable_genes(adata, n_top_genes=5000)
|
90 |
+
adata = adata[:, adata.var.highly_variable]
|
91 |
+
else:
|
92 |
+
print(" Using pre-identified highly variable genes...")
|
93 |
+
if adata.shape[1] > 5000:
|
94 |
+
print(" Subsetting to 5000 genes...")
|
95 |
+
adata = adata[:, 0:5000]
|
96 |
+
|
97 |
+
# Convert to tensor
|
98 |
+
print(" - Converting gene expression to tensor...")
|
99 |
+
gene_expression = Anndata_to_Tensor(adata)
|
100 |
+
if len(gene_expression.shape) == 1:
|
101 |
+
gene_expression = gene_expression.unsqueeze(0) # Add batch dimension
|
102 |
print(f" Gene expression shape: {gene_expression.shape}")
|
103 |
|
104 |
+
# Run prediction
|
105 |
print("\n4. Running prediction with DAWO model...")
|
106 |
results = model.predict(gene_expression, drug_features, cell_features_tensor)
|
107 |
|
requirements.txt
CHANGED
@@ -2,4 +2,6 @@ torch>=1.10.0
|
|
2 |
numpy>=1.20.0
|
3 |
pandas>=1.3.0
|
4 |
scipy>=1.7.0
|
5 |
-
pyarrow>=7.0.0
|
|
|
|
|
|
2 |
numpy>=1.20.0
|
3 |
pandas>=1.3.0
|
4 |
scipy>=1.7.0
|
5 |
+
pyarrow>=7.0.0
|
6 |
+
scanpy>=1.9.0
|
7 |
+
anndata>=0.8.0
|