seronk commited on
Commit
f95c051
·
verified ·
1 Parent(s): e554c0a

Update tasks/Model_Loader.py

Browse files
Files changed (1) hide show
  1. tasks/Model_Loader.py +13 -9
tasks/Model_Loader.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
  class M5(torch.nn.Module):
4
- def __init__(self, num_classes=10):
5
  super(M5, self).__init__()
6
  self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
7
  self.bn1 = torch.nn.BatchNorm1d(32)
@@ -26,13 +26,17 @@ class M5(torch.nn.Module):
26
  x = self.fc1(x)
27
  return x
28
 
29
- def load_model(model_path, num_classes=2):
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- model = M5(num_classes=num_classes).to(device)
32
- model.load_state_dict(torch.load(model_path, map_location=device))
33
- model.eval() # Set model to evaluation mode
34
- return model, device
 
 
 
35
 
36
  if __name__ == "__main__":
37
- model, device = load_model("m5_audio_classification.pth")
38
- print("✅ Model successfully loaded!")
 
 
1
  import torch
2
 
3
  class M5(torch.nn.Module):
4
+ def __init__(self, num_classes=2): # Ensure it matches dataset labels (chainsaw/environment)
5
  super(M5, self).__init__()
6
  self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
7
  self.bn1 = torch.nn.BatchNorm1d(32)
 
26
  x = self.fc1(x)
27
  return x
28
 
29
+ def load_model(model_path, num_classes=2):
30
+ """
31
+ Load trained M5 model.
32
+ """
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = M5(num_classes=num_classes).to(device)
35
+ model.load_state_dict(torch.load(model_path, map_location=device))
36
+ model.eval() # Set model to evaluation mode
37
+ return model, device
38
 
39
  if __name__ == "__main__":
40
+ model, device = load_model("quantized_teacher_m5_static.pth")
41
+ print("✅ Model successfully loaded!")
42
+