Prepare data (cell level): pass cell state dict instead of gene class dict

#483
by hchen725 - opened
Files changed (1) hide show
  1. geneformer/classifier.py +9 -3
geneformer/classifier.py CHANGED
@@ -437,14 +437,20 @@ class Classifier:
437
  )
438
  # rename cell state column to "label"
439
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
 
 
 
 
 
440
 
 
441
  # convert classes to numerical labels and save as id_class_dict
442
  # of note, will label all genes in gene_class_dict
443
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
  # at the time of training with Classifier.validate
445
- data, id_class_dict = cu.label_classes(
446
- self.classifier, data, self.gene_class_dict, self.nproc
447
- )
448
 
449
  # save id_class_dict for future reference
450
  id_class_output_path = (
 
437
  )
438
  # rename cell state column to "label"
439
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
440
+
441
+ # convert classes to numerical labels and save as id_class_dict
442
+ data, id_class_dict = cu.label_classes(
443
+ self.classifier, data, self.cell_state_dict, self.nproc
444
+ )
445
 
446
+ elif self.classifier == "gene":
447
  # convert classes to numerical labels and save as id_class_dict
448
  # of note, will label all genes in gene_class_dict
449
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
450
  # at the time of training with Classifier.validate
451
+ data, id_class_dict = cu.label_classes(
452
+ self.classifier, data, self.gene_class_dict, self.nproc
453
+ )
454
 
455
  # save id_class_dict for future reference
456
  id_class_output_path = (