Sadjad Alikhani committed
Commit d05c994 · verified · 1 Parent(s): fc197f1

Update app.py

Files changed (1): app.py (+47, -1)
app.py CHANGED
@@ -47,11 +47,41 @@ def beam_prediction_task(data_percentage, task_complexity):
 
     return raw_img, embeddings_img
 
+from sklearn.metrics import f1_score
+
+# Function to compute the F1-score based on the confusion matrix
+def compute_f1_score(cm):
+    # Compute precision and recall
+    TP = np.diag(cm)
+    FP = np.sum(cm, axis=0) - TP
+    FN = np.sum(cm, axis=1) - TP
+
+    precision = TP / (TP + FP)
+    recall = TP / (TP + FN)
+
+    # Handle division by zero in precision or recall
+    precision = np.nan_to_num(precision)
+    recall = np.nan_to_num(recall)
+
+    # Compute F1 score
+    f1 = 2 * (precision * recall) / (precision + recall)
+    f1 = np.nan_to_num(f1)  # Replace NaN with 0
+    return np.mean(f1)  # Return the mean F1-score across all classes
+
+# Function to plot and save confusion matrix with F1-score in the title
 def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
+    # Compute the average F1-score
+    avg_f1 = compute_f1_score(cm)
+
+    # Update title to include average F1-score
+    full_title = f"{title} (Avg F1-Score: {avg_f1:.2f})"
+
+    # Plot the confusion matrix
     plt.figure(figsize=(8, 6))
     plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
-    plt.title(title)
+    plt.title(full_title)
     plt.colorbar()
+
     tick_marks = np.arange(len(classes))
     plt.xticks(tick_marks, classes, rotation=45)
     plt.yticks(tick_marks, classes)
@@ -61,6 +91,22 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
     plt.xlabel('Predicted label')
     plt.savefig(save_path)
     plt.close()
+
+
+#def plot_confusion_matrix_beamPred(cm, classes, title, save_path):
+#    plt.figure(figsize=(8, 6))
+#    plt.imshow(cm, interpolation='nearest', cmap='coolwarm')
+#    plt.title(title)
+#    plt.colorbar()
+#    tick_marks = np.arange(len(classes))
+#    plt.xticks(tick_marks, classes, rotation=45)
+#    plt.yticks(tick_marks, classes)
+#
+#    plt.tight_layout()
+#    plt.ylabel('True label')
+#    plt.xlabel('Predicted label')
+#    plt.savefig(save_path)
+#    plt.close()
 
 # Function to compute the average confusion matrix across CSV files in a folder
 #def compute_average_confusion_matrix(folder):
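
As a sanity check of the macro-F1 logic added here: the sketch below reproduces compute_f1_score and cross-checks it against scikit-learn's f1_score (which this commit imports but does not otherwise call). The 3-class confusion matrix and the per-sample label expansion are illustrative assumptions, not part of the commit.

import numpy as np
from sklearn.metrics import f1_score

def compute_f1_score(cm):
    # Per-class counts read straight off the confusion matrix
    TP = np.diag(cm)
    FP = np.sum(cm, axis=0) - TP
    FN = np.sum(cm, axis=1) - TP
    precision = np.nan_to_num(TP / (TP + FP))
    recall = np.nan_to_num(TP / (TP + FN))
    f1 = np.nan_to_num(2 * (precision * recall) / (precision + recall))
    return np.mean(f1)  # macro average across classes

# Hypothetical 3-beam confusion matrix (rows = true class, cols = predicted)
cm = np.array([[8, 1, 1],
               [2, 7, 1],
               [0, 2, 8]])
print(compute_f1_score(cm))  # ≈ 0.7667

# Cross-check: expand the matrix into per-sample labels and ask sklearn
y_true = np.repeat(np.arange(3), cm.sum(axis=1))
y_pred = np.concatenate([np.repeat(np.arange(3), row) for row in cm])
print(f1_score(y_true, y_pred, average='macro'))  # matches: ≈ 0.7667

Working from the confusion matrix directly, as the commit does, avoids keeping per-sample predictions around; the trade-off is that zero-support classes silently contribute an F1 of 0 via np.nan_to_num.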