# Finetuning AIDO.Tissue for spatial single-cell downstream tasks

In this file, we introduce how to finetune and evaluate our pre-trained AIDO.Tissue foundation model for downstream tasks. These tasks fall into the following categories:

* **Sequence-level classification tasks**: niche label type prediction
* **Sequence-level regression tasks**: cell density prediction

Note: All the following scripts should be run under `ModelGenerator/`.

## Download data
The related data is deposited at https://huggingface.co/datasets/genbio-ai/tissue-downstream-tasks. Please download the data and place it under `ModelGenerator/downloads` as `cell_density` or `niche_type_classification`. Each sub-directory contains three files, one per split (xx.train.h5ad, xx.val.h5ad, xx.test.h5ad).

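If you prefer to fetch the data programmatically, below is a minimal sketch using `huggingface_hub`; it assumes the dataset repository already contains the `cell_density` and `niche_type_classification` sub-directories, so adjust `local_dir` if your layout differs:

```python
# Sketch only: download the dataset repo into ModelGenerator/downloads.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="genbio-ai/tissue-downstream-tasks",
    repo_type="dataset",
    local_dir="downloads",  # run from ModelGenerator/, as noted above
)
```
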
Each `.h5ad` file should include several `obs` attributes representing the spatial coordinates (e.g., `x`, `y`) and the label information (e.g., `niche_label`). All of these column fields are specified in the `config.yaml` files described below.

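For example, you can quickly verify that a split exposes the expected columns; this is a minimal `anndata` sketch, and the file name and column names are illustrative:

```python
# Sketch only: inspect one split to confirm the spatial and label columns exist.
import anndata as ad

adata = ad.read_h5ad("downloads/niche_type_classification/xx.train.h5ad")  # replace xx with the actual prefix
print(adata)                       # n_obs x n_vars summary plus obs/var fields
print(adata.obs.columns.tolist())  # should include e.g. 'x', 'y', 'niche_label'
```
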
Note: the file `scRNA_genename_and_index.tsv` lists the gene names and their corresponding indices in the `.h5ad` files.

## Sequence-level classification tasks
### niche label type prediction

We fully finetune AIDO.Tissue for niche label type prediction.

#### Finetuning script
```shell
CUDA_VISIBLE_DEVICES=7 nohup mgen fit --config experiments/AIDO.Tissue/niche_type_classfification.yaml > logs/nohup/AIDO.Tissue.niche_type_classfification.yaml.log 2>&1 &
```

Note: in the config, `filter_columns` includes the label column and the spatial coordinate columns. Keep `rename_columns` unchanged; it is used during the run.

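The exact nesting of these keys depends on the shipped config; a small sketch like the following (assuming PyYAML is installed) will locate them so you can adapt the column names to your own `.h5ad` files:

```python
# Sketch only: find filter_columns / rename_columns inside the config without assuming its layout.
import yaml

with open("experiments/AIDO.Tissue/niche_type_classfification.yaml") as f:
    cfg = yaml.safe_load(f)

def find_keys(node, names, prefix=""):
    """Recursively print the path and value of any matching keys."""
    if isinstance(node, dict):
        for key, value in node.items():
            path = f"{prefix}.{key}" if prefix else key
            if key in names:
                print(path, "=", value)
            find_keys(value, names, path)
    elif isinstance(node, list):
        for i, value in enumerate(node):
            find_keys(value, names, f"{prefix}[{i}]")

find_keys(cfg, {"filter_columns", "rename_columns"})
```
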
#### Evaluation script

Once finetuning finishes, several `ckpt` files will be saved under the specified output directory `default_root_dir`. We can then use a `ckpt` file to evaluate on the test dataset.

```shell
CUDA_VISIBLE_DEVICES=6 nohup mgen test --config experiments/AIDO.Tissue/niche_type_classfification.yaml \
    --ckpt_path ckpt_path \
    > ckpt_path.pred.log 2>&1 &
```

Note: `ckpt_path` is the finetuned checkpoint path.

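If you are unsure where the checkpoints ended up, a small sketch like this lists every `.ckpt` under the output directory (replace the path with whatever `default_root_dir` is set to in your config):

```python
# Sketch only: list candidate checkpoints to pass as --ckpt_path.
from pathlib import Path

root = Path("logs/niche_type_classfification")  # hypothetical default_root_dir
for ckpt in sorted(root.rglob("*.ckpt")):
    print(ckpt)
```
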
## Sequence-level regression tasks

### cell density prediction

The config file is `experiments/AIDO.Tissue/cell_density_regression.yaml`; finetuning and evaluation follow the same procedure as the classification task.

## Dump embedding

We can dump embeddings for a `.h5ad` file with the following script:

```shell
CUDA_VISIBLE_DEVICES=3 nohup mgen predict --config experiments/AIDO.Tissue/emb.xenium.yaml > logs/nohup/AIDO.Tissue.emb.xenium.log 2>&1 &
```

The output files will be written under the specified `output_dir`, e.g. `./logs/emb.xenium/lightning_logs/pred_output`. Each batch is saved separately, and a merged file is also generated as `predict_predictions.pt`, which stacks all batches:

```python
>>> import torch
>>> file_all = 'predict_predictions.pt'
>>> d_all = torch.load(file_all, map_location='cpu')
>>> d_all.keys()
dict_keys(['predictions', 'ids'])
>>> len(d_all['predictions'])  # equals the number of samples
586
>>> len(d_all['ids'])  # ids are numeric indices corresponding to rows in the .h5ad file
586
>>> d_all['predictions'].shape  # (B, L, D); L is the max sequence length over all samples
torch.Size([586, 90, 128])
```

We can then take the gene embeddings of each cell and aggregate them into a cell embedding (e.g., by max pooling):

```python
>>> d_all_maxpooling = [d_all['predictions'][i, :, :] for i in range(d_all['predictions'].shape[0])]
>>> d_all_maxpooling = [i[~torch.any(i.isnan(), dim=1)] for i in d_all_maxpooling]
>>> d_all_maxpooling = torch.cat([i.max(dim=0)[0].view(1, -1) for i in d_all_maxpooling])
>>> d_all_maxpooling.shape
torch.Size([586, 128])
```
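
As a follow-up (a sketch only, not part of the original pipeline), the pooled embeddings can be aligned with the cells of the input `.h5ad` via the numeric row indices in `d_all['ids']` and stored in `.obsm` for downstream analysis; the file path and `.obsm` key below are illustrative:

```python
# Sketch only: attach the pooled cell embeddings to the AnnData used for prediction.
import anndata as ad
import torch

adata = ad.read_h5ad("downloads/niche_type_classification/xx.test.h5ad")  # hypothetical: the file used for `mgen predict`

ids = torch.as_tensor(d_all['ids']).long()   # numeric row indices into the .h5ad
order = torch.argsort(ids)                   # restore the original row order
cell_emb = d_all_maxpooling[order].numpy()   # (n_cells, embedding_dim)

assert cell_emb.shape[0] == adata.n_obs      # holds if prediction covered every cell
adata.obsm["X_aido_tissue"] = cell_emb       # hypothetical key name
adata.write_h5ad("xx.test.with_emb.h5ad")
```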