Spaces:
Running
Running
Commit
·
d65e913
1
Parent(s):
7cdd792
Fix mapping preview
Browse files- text_classification.py +11 -6
text_classification.py
CHANGED
|
@@ -91,22 +91,27 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
|
|
| 91 |
return column_mapping, prediction_result, None
|
| 92 |
|
| 93 |
if isinstance(column_mapping["label"], dict):
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
elif None in id2label_mapping.values():
|
| 97 |
column_mapping["label"] = {
|
| 98 |
i: None for i in id2label.keys()
|
| 99 |
}
|
| 100 |
return column_mapping, prediction_result, None
|
| 101 |
|
|
|
|
|
|
|
|
|
|
| 102 |
id2label_df = pd.DataFrame({
|
| 103 |
-
"ID":
|
| 104 |
-
"
|
| 105 |
-
"
|
| 106 |
})
|
| 107 |
if "label" not in column_mapping.keys():
|
|
|
|
| 108 |
column_mapping["label"] = {
|
| 109 |
-
i: id2label_mapping[
|
| 110 |
}
|
| 111 |
|
| 112 |
return column_mapping, prediction_result, id2label_df
|
|
|
|
| 91 |
return column_mapping, prediction_result, None
|
| 92 |
|
| 93 |
if isinstance(column_mapping["label"], dict):
|
| 94 |
+
# Use the column mapping passed by user
|
| 95 |
+
for i, model_label in column_mapping["label"].items():
|
| 96 |
+
id2label_mapping[model_label] = dataset_labels[int(i)]
|
| 97 |
elif None in id2label_mapping.values():
|
| 98 |
column_mapping["label"] = {
|
| 99 |
i: None for i in id2label.keys()
|
| 100 |
}
|
| 101 |
return column_mapping, prediction_result, None
|
| 102 |
|
| 103 |
+
id2label_mapping = {
|
| 104 |
+
v: k for k, v in id2label_mapping.items()
|
| 105 |
+
}
|
| 106 |
id2label_df = pd.DataFrame({
|
| 107 |
+
"ID": list(range(len(dataset_labels))),
|
| 108 |
+
"Labels": dataset_labels,
|
| 109 |
+
"Labels in original model": [f"{id2label_mapping[label]}({label2id[id2label_mapping[label]]})" for label in dataset_labels],
|
| 110 |
})
|
| 111 |
if "label" not in column_mapping.keys():
|
| 112 |
+
# Column mapping should contain original model labels
|
| 113 |
column_mapping["label"] = {
|
| 114 |
+
str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
|
| 115 |
}
|
| 116 |
|
| 117 |
return column_mapping, prediction_result, id2label_df
|