Spaces:
Sleeping
Sleeping
Update tasks/Model_Loader.py
Browse files- tasks/Model_Loader.py +13 -9
tasks/Model_Loader.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch
|
2 |
|
3 |
class M5(torch.nn.Module):
|
4 |
-
def __init__(self, num_classes=
|
5 |
super(M5, self).__init__()
|
6 |
self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
|
7 |
self.bn1 = torch.nn.BatchNorm1d(32)
|
@@ -26,13 +26,17 @@ class M5(torch.nn.Module):
|
|
26 |
x = self.fc1(x)
|
27 |
return x
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
if __name__ == "__main__":
|
37 |
-
model, device = load_model("
|
38 |
-
print("✅ Model successfully loaded!")
|
|
|
|
1 |
import torch
|
2 |
|
3 |
class M5(torch.nn.Module):
|
4 |
+
def __init__(self, num_classes=2): # Ensure it matches dataset labels (chainsaw/environment)
|
5 |
super(M5, self).__init__()
|
6 |
self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
|
7 |
self.bn1 = torch.nn.BatchNorm1d(32)
|
|
|
26 |
x = self.fc1(x)
|
27 |
return x
|
28 |
|
29 |
+
def load_model(model_path, num_classes=2):
|
30 |
+
"""
|
31 |
+
Load trained M5 model.
|
32 |
+
"""
|
33 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
34 |
+
model = M5(num_classes=num_classes).to(device)
|
35 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
36 |
+
model.eval() # Set model to evaluation mode
|
37 |
+
return model, device
|
38 |
|
39 |
if __name__ == "__main__":
|
40 |
+
model, device = load_model("quantized_teacher_m5_static.pth")
|
41 |
+
print("✅ Model successfully loaded!")
|
42 |
+
|