Christina Theodoris commited on
Commit
bfcada4
·
1 Parent(s): e2ee685

fix gene class dict labeling

Browse files
Files changed (1) hide show
  1. geneformer/classifier_utils.py +10 -3
geneformer/classifier_utils.py CHANGED
@@ -115,13 +115,20 @@ def label_classes(classifier, data, gene_class_dict, nproc):
115
 
116
  class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
117
  id_class_dict = {v: k for k, v in class_id_dict.items()}
 
 
 
 
 
 
 
118
 
119
  def classes_to_ids(example):
120
  if classifier == "cell":
121
  example["label"] = class_id_dict[example["label"]]
122
  elif classifier == "gene":
123
  example["labels"] = label_gene_classes(
124
- example, class_id_dict, gene_class_dict
125
  )
126
  return example
127
 
@@ -129,9 +136,9 @@ def label_classes(classifier, data, gene_class_dict, nproc):
129
  return data, id_class_dict
130
 
131
 
132
- def label_gene_classes(example, class_id_dict, gene_class_dict):
133
  return [
134
- class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
135
  for token_id in example["input_ids"]
136
  ]
137
 
 
115
 
116
  class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
117
  id_class_dict = {v: k for k, v in class_id_dict.items()}
118
+ inverse_gene_class_dict = {}
119
+ # Iterate over each key and list of values in the original dictionary
120
+ for key, value_list in gene_class_dict.items():
121
+ # Iterate over each value in the list
122
+ for value in value_list:
123
+ # Assign the value as a key and the original key as its value in the new dictionary
124
+ inverse_gene_class_dict[value] = key
125
 
126
  def classes_to_ids(example):
127
  if classifier == "cell":
128
  example["label"] = class_id_dict[example["label"]]
129
  elif classifier == "gene":
130
  example["labels"] = label_gene_classes(
131
+ example, class_id_dict, inverse_gene_class_dict
132
  )
133
  return example
134
 
 
136
  return data, id_class_dict
137
 
138
 
139
+ def label_gene_classes(example, class_id_dict, inverse_gene_class_dict):
140
  return [
141
+ class_id_dict.get(inverse_gene_class_dict.get(token_id, -100), -100)
142
  for token_id in example["input_ids"]
143
  ]
144