Sheng-Yong Niu commited on
Commit
70a5d5c
·
verified ·
1 Parent(s): b33b5c3

Upload 7 files

Browse files
Files changed (2) hide show
  1. example.py +30 -9
  2. requirements.txt +3 -1
example.py CHANGED
@@ -2,8 +2,10 @@ import torch
2
  import numpy as np
3
  import pandas as pd
4
  import json
 
5
 
6
  from dawo_wrapper import DAWOWrapper
 
7
 
8
  print("DAWO Model Example: Drug Response Prediction")
9
  print("============================================")
@@ -15,7 +17,7 @@ model = DAWOWrapper(repo_path="./")
15
  # Load data files from the data folder
16
  print("\n2. Loading drug and cell line features...")
17
 
18
- # Set data directory (use local data directory)
19
  data_dir = "./data"
20
 
21
  # Drug feature components
@@ -43,8 +45,8 @@ print(f" Shape: {cell_features.shape}")
43
  # Select sample drug and cell line
44
  print("\n3. Preparing inputs for prediction:")
45
 
46
- # Select a drug for demonstration
47
- sample_drug = list(drug_name.keys())[0]
48
  print(f" - Selected drug: {sample_drug}")
49
 
50
  # Create complete drug feature vector by concatenating the three embedding types
@@ -58,8 +60,8 @@ drug_feature = np.concatenate((
58
  drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0) # Add batch dimension
59
  print(f" Combined drug feature shape: {drug_features.shape}")
60
 
61
- # Select a cell line for demonstration
62
- sample_cell = cell_features.index[0]
63
  print(f" - Selected cell line: {sample_cell}")
64
 
65
  # Create cell line feature vector
@@ -75,12 +77,31 @@ padded_features[0, :cell_features_tensor.shape[1]] = cell_features_tensor
75
  cell_features_tensor = padded_features
76
  print(f" Padded cell feature shape: {cell_features_tensor.shape}")
77
 
78
- # Create simulated gene expression data (normally this would be real data)
79
- print(" - Creating sample gene expression data...")
80
- gene_expression = torch.randn(1, 5000) # 1 sample, 5000 genes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  print(f" Gene expression shape: {gene_expression.shape}")
82
 
83
- # Run prediction with prepared data
84
  print("\n4. Running prediction with DAWO model...")
85
  results = model.predict(gene_expression, drug_features, cell_features_tensor)
86
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import json
5
+ import scanpy as sc
6
 
7
  from dawo_wrapper import DAWOWrapper
8
+ from dawo import Anndata_to_Tensor
9
 
10
  print("DAWO Model Example: Drug Response Prediction")
11
  print("============================================")
 
17
  # Load data files from the data folder
18
  print("\n2. Loading drug and cell line features...")
19
 
20
+ # Set data directory
21
  data_dir = "./data"
22
 
23
  # Drug feature components
 
45
  # Select sample drug and cell line
46
  print("\n3. Preparing inputs for prediction:")
47
 
48
+ # Select a drug for demonstration - use Dabrafenib
49
+ sample_drug = "Dabrafenib"
50
  print(f" - Selected drug: {sample_drug}")
51
 
52
  # Create complete drug feature vector by concatenating the three embedding types
 
60
  drug_features = torch.tensor(drug_feature, dtype=torch.float32).unsqueeze(0) # Add batch dimension
61
  print(f" Combined drug feature shape: {drug_features.shape}")
62
 
63
+ # Select the MIA PaCa-2 cell line
64
+ sample_cell = "MIA PaCa-2"
65
  print(f" - Selected cell line: {sample_cell}")
66
 
67
  # Create cell line feature vector
 
77
  cell_features_tensor = padded_features
78
  print(f" Padded cell feature shape: {cell_features_tensor.shape}")
79
 
80
+ # Load gene expression data
81
+ print("\n - Loading real gene expression data for Dabrafenib on MIA_PaCa-2 cell line...")
82
+ adata = sc.read_h5ad(f'{data_dir}/Dabrafenib.MIA_PaCa-2.h5ad')
83
+ print(f" AnnData object shape: {adata.shape}")
84
+
85
+ # Preprocess gene expression
86
+ print(" - Preprocessing gene expression data...")
87
+ if 'highly_variable' not in adata.var:
88
+ print(" Selecting top 5000 highly variable genes...")
89
+ sc.pp.highly_variable_genes(adata, n_top_genes=5000)
90
+ adata = adata[:, adata.var.highly_variable]
91
+ else:
92
+ print(" Using pre-identified highly variable genes...")
93
+ if adata.shape[1] > 5000:
94
+ print(" Subsetting to 5000 genes...")
95
+ adata = adata[:, 0:5000]
96
+
97
+ # Convert to tensor
98
+ print(" - Converting gene expression to tensor...")
99
+ gene_expression = Anndata_to_Tensor(adata)
100
+ if len(gene_expression.shape) == 1:
101
+ gene_expression = gene_expression.unsqueeze(0) # Add batch dimension
102
  print(f" Gene expression shape: {gene_expression.shape}")
103
 
104
+ # Run prediction
105
  print("\n4. Running prediction with DAWO model...")
106
  results = model.predict(gene_expression, drug_features, cell_features_tensor)
107
 
requirements.txt CHANGED
@@ -2,4 +2,6 @@ torch>=1.10.0
2
  numpy>=1.20.0
3
  pandas>=1.3.0
4
  scipy>=1.7.0
5
- pyarrow>=7.0.0
 
 
 
2
  numpy>=1.20.0
3
  pandas>=1.3.0
4
  scipy>=1.7.0
5
+ pyarrow>=7.0.0
6
+ scanpy>=1.9.0
7
+ anndata>=0.8.0