kfoughali committed
Commit 850d736 · verified · 1 Parent(s): 35dde85

Update utils/metrics.py

Files changed (1):
  utils/metrics.py +90 -74
utils/metrics.py CHANGED
@@ -4,73 +4,84 @@ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
 import numpy as np
 
 class GraphMetrics:
-    """Production-ready evaluation metrics"""
+    """Production-ready evaluation metrics - device safe"""
 
     @staticmethod
     def accuracy(pred, target):
         """Classification accuracy"""
-        pred_labels = pred.argmax(dim=1)
+        if pred.dim() > 1:
+            pred_labels = pred.argmax(dim=1)
+        else:
+            pred_labels = pred
         return (pred_labels == target).float().mean().item()
 
     @staticmethod
     def f1_score_macro(pred, target):
         """Macro F1 score"""
-        pred_labels = pred.argmax(dim=1).cpu().numpy()
-        target_labels = target.cpu().numpy()
-        return f1_score(target_labels, pred_labels, average='macro')
+        try:
+            if pred.dim() > 1:
+                pred_labels = pred.argmax(dim=1)
+            else:
+                pred_labels = pred
+            pred_labels = pred_labels.cpu().numpy()
+            target_labels = target.cpu().numpy()
+            return f1_score(target_labels, pred_labels, average='macro', zero_division=0)
+        except:
+            return 0.0
 
     @staticmethod
     def f1_score_micro(pred, target):
         """Micro F1 score"""
-        pred_labels = pred.argmax(dim=1).cpu().numpy()
-        target_labels = target.cpu().numpy()
-        return f1_score(target_labels, pred_labels, average='micro')
-
-    @staticmethod
-    def roc_auc(pred, target, num_classes):
-        """ROC AUC for multi-class"""
-        if num_classes == 2:
-            # Binary classification
-            pred_probs = F.softmax(pred, dim=1)[:, 1].cpu().numpy()
-            target_labels = target.cpu().numpy()
-            return roc_auc_score(target_labels, pred_probs)
-        else:
-            # Multi-class
-            pred_probs = F.softmax(pred, dim=1).cpu().numpy()
-            target_onehot = F.one_hot(target, num_classes).cpu().numpy()
-            return roc_auc_score(target_onehot, pred_probs, multi_class='ovr', average='macro')
+        try:
+            if pred.dim() > 1:
+                pred_labels = pred.argmax(dim=1)
+            else:
+                pred_labels = pred
+            pred_labels = pred_labels.cpu().numpy()
+            target_labels = target.cpu().numpy()
+            return f1_score(target_labels, pred_labels, average='micro', zero_division=0)
+        except:
+            return 0.0
 
     @staticmethod
     def evaluate_node_classification(model, data, mask, device):
         """Comprehensive node classification evaluation"""
         model.eval()
 
-        with torch.no_grad():
-            data = data.to(device)
-            h = model(data.x, data.edge_index)
-
-            # Assuming a classification head exists
-            if hasattr(model, 'classifier'):
-                pred = model.classifier(h)
-            else:
-                # If no classifier, return embeddings
-                return {'embeddings': h[mask].cpu()}
-
-            pred_masked = pred[mask]
-            target_masked = data.y[mask]
-
-        metrics = {
-            'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
-            'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
-            'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
-        }
-
-        # Add ROC AUC if binary/multi-class
-        try:
-            num_classes = pred.size(1)
-            metrics['roc_auc'] = GraphMetrics.roc_auc(pred_masked, target_masked, num_classes)
-        except:
-            pass
+        try:
+            with torch.no_grad():
+                # Ensure data is on correct device
+                data = data.to(device)
+                model = model.to(device)
+
+                h = model(data.x, data.edge_index)
+
+                # Get predictions
+                if hasattr(model, 'classifier') and model.classifier is not None:
+                    pred = model.classifier(h)
+                else:
+                    # Initialize classifier if needed
+                    num_classes = len(torch.unique(data.y))
+                    model._init_classifier(num_classes, device)
+                    pred = model.classifier(h)
+
+                pred_masked = pred[mask]
+                target_masked = data.y[mask]
+
+                metrics = {
+                    'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
+                    'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
+                    'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
+                }
+
+        except Exception as e:
+            print(f"Evaluation error: {e}")
+            metrics = {
+                'accuracy': 0.0,
+                'f1_macro': 0.0,
+                'f1_micro': 0.0,
+                'error': str(e)
+            }
 
         return metrics
 
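The rewritten `evaluate_node_classification` assumes a model interface the diff never shows: a `classifier` attribute plus an `_init_classifier(num_classes, device)` helper that lazily builds the classification head. A minimal usage sketch, assuming PyTorch Geometric, a toy `TinyGCN` model, and the Planetoid/Cora dataset (none of which are part of this commit):

```python
# Hypothetical caller -- TinyGCN and the dataset are illustrative assumptions;
# only GraphMetrics comes from this repo's utils/metrics.py.
import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

from utils.metrics import GraphMetrics  # assumed repo layout

class TinyGCN(nn.Module):
    """Minimal model exposing the interface the evaluator expects."""

    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = None  # built lazily, as the evaluator allows

    def _init_classifier(self, num_classes, device):
        # Called by the evaluator when no classification head exists yet.
        self.classifier = nn.Linear(self.conv2.out_channels, num_classes).to(device)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        return self.conv2(h, edge_index)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = Planetoid(root='/tmp/Cora', name='Cora')[0]
model = TinyGCN(in_dim=data.num_node_features)

# The evaluator moves both data and model to `device`, builds the head if
# missing, and returns accuracy plus macro/micro F1 over the masked nodes.
print(GraphMetrics.evaluate_node_classification(model, data, data.test_mask, device))
```

Note the design trade-off: the new `except` branch returns zeroed metrics instead of raising, so callers should check the result for an `'error'` key rather than relying on exceptions. The second hunk, below, hardens the graph-level path the same way.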
@@ -82,35 +93,40 @@ class GraphMetrics:
         all_preds = []
         all_targets = []
 
-        with torch.no_grad():
-            for batch in dataloader:
-                batch = batch.to(device)
-                h = model(batch.x, batch.edge_index, batch.batch)
-
-                # Graph-level prediction
-                graph_h = model.get_graph_embedding(h, batch.batch)
-
-                if hasattr(model, 'classifier'):
-                    pred = model.classifier(graph_h)
-                    all_preds.append(pred)
-                    all_targets.append(batch.y)
-
-        if all_preds:
-            all_preds = torch.cat(all_preds, dim=0)
-            all_targets = torch.cat(all_targets, dim=0)
-
-            metrics = {
-                'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
-                'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
-                'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
-            }
-
-            try:
-                num_classes = all_preds.size(1)
-                metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets, num_classes)
-            except:
-                pass
-
-            return metrics
-
-        return {'error': 'No predictions generated'}
+        try:
+            with torch.no_grad():
+                for batch in dataloader:
+                    batch = batch.to(device)
+                    h = model(batch.x, batch.edge_index, batch.batch)
+
+                    # Graph-level prediction
+                    graph_h = model.get_graph_embedding(h, batch.batch)
+
+                    if hasattr(model, 'classifier') and model.classifier is not None:
+                        pred = model.classifier(graph_h)
+                    else:
+                        # Initialize classifier
+                        num_classes = len(torch.unique(batch.y))
+                        model._init_classifier(num_classes, device)
+                        pred = model.classifier(graph_h)
+
+                    all_preds.append(pred.cpu())
+                    all_targets.append(batch.y.cpu())
+
+            if all_preds:
+                all_preds = torch.cat(all_preds, dim=0)
+                all_targets = torch.cat(all_targets, dim=0)
+
+                metrics = {
+                    'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
+                    'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
+                    'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
+                }
+            else:
+                metrics = {'error': 'No predictions generated'}
+
+        except Exception as e:
+            print(f"Graph classification evaluation error: {e}")
+            metrics = {'error': str(e)}
+
+        return metrics
 
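The graph-level hunk additionally relies on a `get_graph_embedding(h, batch)` pooling hook, and its enclosing `def` line falls outside the hunk, so the method name `evaluate_graph_classification` in the sketch below is an assumption, as are the `TinyGraphGCN` model and the TUDataset/MUTAG choice:

```python
# Hypothetical graph-level caller -- model, dataset, and the evaluator's
# exact name/signature (not visible in this hunk) are assumptions.
import torch
import torch.nn as nn
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from utils.metrics import GraphMetrics  # assumed repo layout

class TinyGraphGCN(nn.Module):
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.conv = GCNConv(in_dim, hidden_dim)
        self.classifier = None  # lazily built via _init_classifier

    def _init_classifier(self, num_classes, device):
        self.classifier = nn.Linear(self.conv.out_channels, num_classes).to(device)

    def forward(self, x, edge_index, batch=None):
        return self.conv(x, edge_index).relu()

    def get_graph_embedding(self, h, batch):
        # Pool node embeddings into one vector per graph, as the evaluator expects.
        return global_mean_pool(h, batch)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=False)
model = TinyGraphGCN(in_dim=dataset.num_node_features).to(device)
model.eval()

print(GraphMetrics.evaluate_graph_classification(model, loader, device))
```

Two behavioral changes are worth flagging: predictions are now accumulated on CPU (`pred.cpu()`), so evaluation no longer holds the whole prediction set in GPU memory, and ROC AUC is gone from the module entirely, so downstream code reading `metrics['roc_auc']` will now raise a `KeyError`.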