Christina Theodoris
committed on
Commit
·
bfcada4
1
Parent(s):
e2ee685
fix gene class dict labeling
Browse files
geneformer/classifier_utils.py
CHANGED
@@ -115,13 +115,20 @@ def label_classes(classifier, data, gene_class_dict, nproc):
|
|
115 |
|
116 |
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
117 |
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
def classes_to_ids(example):
|
120 |
if classifier == "cell":
|
121 |
example["label"] = class_id_dict[example["label"]]
|
122 |
elif classifier == "gene":
|
123 |
example["labels"] = label_gene_classes(
|
124 |
-
example, class_id_dict,
|
125 |
)
|
126 |
return example
|
127 |
|
@@ -129,9 +136,9 @@ def label_classes(classifier, data, gene_class_dict, nproc):
|
|
129 |
return data, id_class_dict
|
130 |
|
131 |
|
132 |
-
def label_gene_classes(example, class_id_dict,
|
133 |
return [
|
134 |
-
class_id_dict.get(
|
135 |
for token_id in example["input_ids"]
|
136 |
]
|
137 |
|
|
|
115 |
|
116 |
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
|
117 |
id_class_dict = {v: k for k, v in class_id_dict.items()}
|
118 |
+
inverse_gene_class_dict = {}
|
119 |
+
# Iterate over each key and list of values in the original dictionary
|
120 |
+
for key, value_list in gene_class_dict.items():
|
121 |
+
# Iterate over each value in the list
|
122 |
+
for value in value_list:
|
123 |
+
# Assign the value as a key and the original key as its value in the new dictionary
|
124 |
+
inverse_gene_class_dict[value] = key
|
125 |
|
126 |
def classes_to_ids(example):
|
127 |
if classifier == "cell":
|
128 |
example["label"] = class_id_dict[example["label"]]
|
129 |
elif classifier == "gene":
|
130 |
example["labels"] = label_gene_classes(
|
131 |
+
example, class_id_dict, inverse_gene_class_dict
|
132 |
)
|
133 |
return example
|
134 |
|
|
|
136 |
return data, id_class_dict
|
137 |
|
138 |
|
139 |
+
def label_gene_classes(example, class_id_dict, inverse_gene_class_dict):
|
140 |
return [
|
141 |
+
class_id_dict.get(inverse_gene_class_dict.get(token_id, -100), -100)
|
142 |
for token_id in example["input_ids"]
|
143 |
]
|
144 |
|