gongjing committed on
Commit fd7ce66 · verified · 1 Parent(s): 01a5204

Create README.md

Files changed (1): README.md (+82 -0)
# Finetuning AIDO.Tissue for spatial single cell downstream tasks
In this file, we introduce how to finetune and evaluate our pre-trained AIDO.Tissue foundation models for downstream tasks. These tasks can be classified into the following categories:

* **Sequence-level classification tasks**: niche label type prediction
* **Sequence-level regression tasks**: cell density prediction

Note: All the following scripts should be run under `ModelGenerator/`.

## Download data
The related data is deposited at https://huggingface.co/datasets/genbio-ai/tissue-downstream-tasks. Please download the data and put it under `ModelGenerator/downloads` as `cell_density` or `niche_type_classification`. Under each sub-directory, there are three files denoting the different splits (xx.train.h5ad, xx.val.h5ad, xx.test.h5ad).
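
As a minimal sketch (assuming `huggingface_hub` is installed; you may still need to rearrange files so they match the `downloads/cell_density` and `downloads/niche_type_classification` layout described above), the dataset can also be fetched programmatically:

```python
from huggingface_hub import snapshot_download

# Fetch the downstream-task data into ModelGenerator/downloads
# (run this from the ModelGenerator/ directory).
snapshot_download(
    repo_id="genbio-ai/tissue-downstream-tasks",
    repo_type="dataset",
    local_dir="downloads",
)
```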

For each `.h5ad` file, several `obs` attributes should be included to represent the spatial (coordinate) information (like `x`, `y`) and the label information (like `niche_label`). All the column fields are specified in the corresponding `config.yaml` file.
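
For a quick sanity check, the `obs` attributes can be inspected with `anndata` (a sketch only; the file name is a placeholder and the column names `x`, `y`, `niche_label` may differ in your data):

```python
import anndata as ad

# Load one split and inspect the per-cell (obs) annotations.
adata = ad.read_h5ad("downloads/niche_type_classification/xx.train.h5ad")  # placeholder path
print(adata)                       # overview: n_obs x n_vars, obs/var/obsm keys
print(adata.obs.columns.tolist())  # available obs columns
# Spatial coordinate and label columns (names as referenced in config.yaml):
print(adata.obs[["x", "y", "niche_label"]].head())
```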

Note: the file `scRNA_genename_and_index.tsv` contains the corresponding gene names and their indices in the `.h5ad` files.
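
To look up the gene name / index mapping, a sketch assuming `pandas` (the exact location and column layout of the TSV may differ):

```python
import pandas as pd

# Load the gene-name-to-index mapping shipped with the data.
gene_map = pd.read_csv("downloads/scRNA_genename_and_index.tsv", sep="\t")  # placeholder path
print(gene_map.head())
```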

## Sequence-level classification tasks
### niche label type prediction
We fully finetune AIDO.Tissue for niche label type prediction.

#### Finetuning script
```shell
CUDA_VISIBLE_DEVICES=7 nohup mgen fit --config experiments/AIDO.Tissue/niche_type_classfification.yaml > logs/nohup/AIDO.Tissue.niche_type_classfification.yaml.log 2>&1 &
```

Note: `filter_columns` lists the label column and the spatial coordinate columns; `rename_columns` should be kept unchanged and is used during the run.

#### Evaluation script

Once the run is finished, there will be several `ckpt` files under the specified output directory `default_root_dir`. We can then use a `ckpt` to evaluate on the test dataset.

```shell
CUDA_VISIBLE_DEVICES=6 nohup mgen test --config experiments/AIDO.Tissue/niche_type_classfification.yaml \
    --ckpt_path ckpt_path \
    > ckpt_path.pred.log 2>&1 &
```

Note: `ckpt_path` is the finetuned checkpoint path.

## Sequence-level regression tasks

### cell density prediction

The config file is `experiments/AIDO.Tissue/cell_density_regression.yaml`; the finetuning and evaluation commands are the same as for the classification task above.

## Dump embedding

We can dump embeddings for a `.h5ad` file with the following script:

```shell
CUDA_VISIBLE_DEVICES=3 nohup mgen predict --config experiments/AIDO.Tissue/emb.xenium.yaml > logs/nohup/AIDO.Tissue.emb.xenium.log 2>&1 &
```

The output files will be under the specified `output_dir`, e.g. `./logs/emb.xenium/lightning_logs/pred_output`. Each batch is saved separately, and a merged file is also generated as `predict_predictions.pt`, which stacks all batches:

```python
>>> import torch
>>> file_all = 'predict_predictions.pt'
>>> d_all = torch.load(file_all, map_location='cpu')
>>> d_all.keys()
dict_keys(['predictions', 'ids'])
>>> len(d_all['predictions'])  # equals the number of samples
586
>>> len(d_all['ids'])  # ids are numeric indices corresponding to rows of the .h5ad file
586
>>> d_all['predictions'].shape  # (B, L, D); L is the max sequence length over all samples
torch.Size([586, 90, 128])
```

We can retrieve all the gene embeddings and aggregate them into cell embeddings (e.g. by max pooling):

```python
>>> # Split predictions into per-sample (L, D) matrices
>>> d_all_maxpooling = [d_all['predictions'][i,:,:] for i in range(d_all['predictions'].shape[0])]
>>> # Drop rows that contain NaN (padding up to the common length L)
>>> d_all_maxpooling = [i[~torch.any(i.isnan(), dim=1)] for i in d_all_maxpooling]
>>> # Max-pool over the gene dimension to get one embedding per cell
>>> d_all_maxpooling = torch.cat([i.max(dim=0)[0].view(1,-1) for i in d_all_maxpooling])
>>> d_all_maxpooling.shape
torch.Size([586, 128])
```
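
The `ids` can be used to map the pooled cell embeddings back onto the original `.h5ad` file. A rough sketch, continuing the session above (the input path and the `obsm` key name are placeholders of ours, and we assume each id is the row index of the corresponding cell):

```python
import anndata as ad
import numpy as np

# Continuing from the snippets above: d_all['ids'] holds the numeric index of each
# prediction in the input .h5ad file, d_all_maxpooling the pooled cell embeddings.
adata = ad.read_h5ad("your_input.h5ad")  # placeholder: the file embeddings were dumped for
ids = np.asarray(d_all['ids'])
emb = d_all_maxpooling.numpy()

# Fill an (n_obs, D) matrix, leaving NaN for any cell without a prediction.
X_emb = np.full((adata.n_obs, emb.shape[1]), np.nan, dtype=np.float32)
X_emb[ids] = emb
adata.obsm["X_aido_tissue"] = X_emb  # hypothetical key name
adata.write_h5ad("your_input.with_embeddings.h5ad")
```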