gajeshladhar committed on
Commit
1e6fe0a
·
1 Parent(s): 75fecef

core-dino src added

Browse files
Files changed (5) hide show
  1. src/backbone.py +79 -0
  2. src/data.py +158 -0
  3. src/loss.py +144 -0
  4. src/train.py +112 -0
  5. src/utils.py +77 -0
src/backbone.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🦴 core-dino | YOLO Backbone Wrapper for Feature Extraction πŸ”
3
+
4
+ Wraps a YOLO model to extract intermediate feature maps for DINO-style
5
+ self-supervised training. Optionally applies an MLP projection head.
6
+
7
+ Author: Gajesh Ladhar
8
+ 🔗 LinkedIn: https://www.linkedin.com/in/gajeshladhar/
9
+ 🤗 Hugging Face: https://huggingface.co/gajeshladhar
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from ultralytics import YOLO
15
+
16
+
17
class YOLOBackBone(nn.Module):
    """
    Feature extractor built from the first layers of a YOLO model.

    The YOLO layer graph is replayed manually: each layer's input sources
    are read from the model's YAML definition (``backbone`` + ``head``), so
    skip connections that feed multi-input layers are honoured.

    Args:
        model_path (str): Path to YOLO weights (.pt)
        stop_at (int): Number of leading layers to keep (layers ``[0, stop_at)``)
        use_mlp (bool): Whether the output is routed through ``mlp_head``
        mlp_dim (int): Forwarded to ``init_mlp`` (currently unused by it)
    """

    def __init__(self, model_path='yolo11x.pt', stop_at=23, use_mlp=True, mlp_dim=512):
        super().__init__()
        detection_model = YOLO(model_path).model.train()
        kept_layers = detection_model.model[:stop_at]
        self.layers = nn.ModuleList(kept_layers)
        # One entry per layer; element 0 of each entry is its "from" spec.
        yaml_cfg = detection_model.yaml
        self.layer_defs = yaml_cfg["backbone"] + yaml_cfg["head"]

        self.use_mlp = use_mlp
        if use_mlp:
            last_layer = self.layers[-1]
            self.init_mlp(self._get_out_channels(last_layer), mlp_dim)

        # Make every parameter trainable.
        self.requires_grad_(True)

    def _get_out_channels(self, layer):
        """Return the channel count of the final feature map (fixed at 768)."""
        return 768

    def init_mlp(self, in_channels, out_channels):
        """
        Install the projection head.

        The head is currently a pass-through (``nn.Identity``); both
        arguments are accepted but not used.
        """
        self.mlp_head = nn.Identity()

    def apply_mlp(self, x):
        """Run ``x`` through the projection head when ``use_mlp`` is set."""
        if self.use_mlp:
            return self.mlp_head(x)
        return x

    def forward(self, x):
        """
        Forward pass through the selected YOLO layers and the optional head.

        Args:
            x (Tensor): Input image tensor (B, C, H, W)

        Returns:
            Tensor: Output of the last kept layer (after ``apply_mlp``)
        """
        history = []  # output of every layer so far, indexed by layer id
        for idx, layer in enumerate(self.layers):
            sources = self.layer_defs[idx][0]
            if isinstance(sources, int):
                sources = [sources]
            # -1 means "previous output"; any other value indexes `history`.
            feeds = [x if src == -1 else history[src] for src in sources]
            x = layer(feeds if len(feeds) > 1 else feeds[0])
            history.append(x)
        return self.apply_mlp(x)

    def count_params(self):
        """Return ``(total, trainable)`` parameter counts."""
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return total, trainable
79
+
src/data.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """📦 core-dino | Data Loader for Self-Supervised DINO Training on Core-Five 🚀
2
+
3
+ This module defines the `DinoDataset` which streams multi-resolution
4
+ satellite patches from the Core-Five dataset, preparing teacher-student
5
+ views for resolution-agnostic self-supervised learning.
6
+ """
7
+
8
+ import os
9
+ import io
10
+ import time
11
+ import torch
12
+ import random
13
+ import requests
14
+ import numpy as np
15
+ import geopandas as gpd
16
+ import h5py
17
+ import xarray as xr
18
+ from torch import nn
19
+ from torch.utils.data import Dataset
20
+ import albumentations as A
21
+ import fsspec
22
+
23
+ from utils import (
24
+ shared_store, process_pool, write_last_updated,
25
+ AddPoissonNoise, AddSaltPepperNoise
26
+ )
27
+
28
class DinoDataset(Dataset):
    """
    🧠 DinoDataset — resolution-agnostic loader for Core-Five 🌍

    Streams random crops of HR satellite images from Hugging Face and
    creates clean (teacher) and augmented, down-scaled (student) views
    using Albumentations & torch. Batches are produced by background
    workers and written into a shared list (``self.store``).

    ---
    Author: Gajesh Ladhar
    LinkedIn: https://www.linkedin.com/in/gajeshladhar/
    Hugging Face: https://huggingface.co/gajeshladhar
    """

    # Root URL under which the metadata file and all patch files live.
    BASE_URL = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/"
    # Student views are resized to a multiple-of-32 side of at least 320 px.
    MIN_STUDENT_SIZE = 32 * 10
    SIZE_STEP = 32
    # Number of fetch loops submitted to the shared process pool.
    NUM_WORKERS = 6
    # Seconds before a stalled download is abandoned, so a hung connection
    # cannot block a worker forever.
    REQUEST_TIMEOUT = 60

    def __init__(self, imgsz, batch_size=1, queue_size=50):
        """
        Load the remote Core-Five metadata and start async patch fetching.

        Args:
            imgsz (int): Patch size (must be >= 320)
            batch_size (int): Number of patches per batch
            queue_size (int): Max number of batches kept in the shared store

        Raises:
            ValueError: If ``imgsz`` is smaller than 320.
        """
        if imgsz < 320:
            raise ValueError("❗️imgsz must be ≥ 320 for stable patch extraction — got {}".format(imgsz))
        self.imgsz = imgsz
        metadata_url = self.BASE_URL + "metadata.parquet"
        # Context manager so the remote file handle is closed after reading.
        with fsspec.open(metadata_url) as handle:
            self.df_metadata = gpd.read_parquet(handle)
        self.batch_size = batch_size
        self.queue_size = queue_size
        self.store = shared_store

        for _ in range(self.NUM_WORKERS):
            process_pool.submit(self.fetch_and_store)

    @staticmethod
    def _student_augmentations():
        """Build the degradation pipeline applied to student views."""
        return A.Compose([
            A.GaussNoise(std_range=(0.01, 0.1), p=0.3),
            AddPoissonNoise(p=0.3),
            AddSaltPepperNoise(amount=0.02, p=0.3),
            A.MultiplicativeNoise(multiplier=(0.9, 1.1), elementwise=True, p=0.3),
            A.MotionBlur(blur_limit=(3, 11), p=0.3),
            A.GaussianBlur(blur_limit=(3, 11), p=0.3),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=0.1),
            A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.3),
            A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=0.3),
            A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, val_shift_limit=30, p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2),
            A.CoarseDropout(num_holes_range=(1, 4), hole_height_range=(0.05, 0.2),
                            hole_width_range=(0.05, 0.2), fill='random_uniform', p=0.1),
        ])

    @staticmethod
    def _pick_student_size(full_size):
        """
        Choose the student resolution for a batch whose patches are
        ``full_size`` pixels wide.

        Sizes are drawn from ``[MIN_STUDENT_SIZE, full_size)`` in steps of
        ``SIZE_STEP``. When that range is empty (patch no larger than the
        minimum) the patch's own size is used instead of raising.
        """
        candidates = np.arange(DinoDataset.MIN_STUDENT_SIZE, full_size, DinoDataset.SIZE_STEP)
        if candidates.size == 0:
            return int(full_size)
        return int(np.random.choice(candidates))

    @staticmethod
    def transform(batch):
        """
        Build the student/teacher views for one batch.

        The student view is resized to a randomly chosen resolution and
        passed through the augmentation pipeline; the teacher view is the
        untouched patch.

        Args:
            batch (list): Patches as (band, y, x) arrays, all the same shape.

        Returns:
            Dict with 'student' and 'teacher' uint8 tensors
        """
        augment = DinoDataset._student_augmentations()
        size = DinoDataset._pick_student_size(batch[0].shape[-1])
        resize = nn.Upsample(size=size, mode='bilinear')

        student, teacher = [], []
        for img in batch:
            resized = resize(torch.tensor(img[np.newaxis, :]))[0].numpy().astype("uint8")
            # Albumentations works on (H, W, C); convert there and back.
            augmented = augment(image=resized.transpose(1, 2, 0))['image'].transpose(2, 0, 1)
            student.append(torch.tensor(augmented))
            teacher.append(torch.tensor(img))

        return {
            "student": torch.stack(student).to(torch.uint8),
            "teacher": torch.stack(teacher).to(torch.uint8)
        }

    def _fetch_patch(self):
        """Download one random Core-Five file and return a random square crop."""
        rel_path = self.df_metadata.sample(n=1).path.iloc[0]
        # Plain concatenation: os.path.join would use "\\" on Windows.
        url = self.BASE_URL + str(rel_path)
        response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"},
                                timeout=self.REQUEST_TIMEOUT)
        # Fail loudly on HTTP errors instead of handing an error page to h5py.
        response.raise_for_status()
        with h5py.File(io.BytesIO(response.content), "r") as f:
            x = f["hr/x"][:]
            y = f["hr/y"][:]
            data = f["/hr/data"][:]
        bands = list(range(data.shape[0]))

        ds = xr.DataArray(data, dims=['band', 'y', 'x'], coords=[bands, y, x]).astype("uint8")

        half = self.imgsz // 2
        # randint's upper bound is exclusive; +1 keeps the last valid centre
        # reachable, so an image exactly `imgsz` wide can still be cropped.
        yid = np.random.randint(half, len(ds.y) - half + 1)
        xid = np.random.randint(half, len(ds.x) - half + 1)
        ds = ds.isel(y=range(yid - half, yid + half),
                     x=range(xid - half, xid + half)).compute()
        return ds.data

    def _push(self, result):
        """Append `result` to the store, or overwrite a random slot once full."""
        if len(self.store) >= self.queue_size:
            # Exclusive upper bound: every slot 0..queue_size-1 is eligible.
            index = np.random.randint(0, self.queue_size)
            self.store[index] = result
        else:
            self.store.append(result)

    def fetch_and_store(self):
        """
        Worker loop: continuously sample random crops from Core-Five,
        build views via `transform`, and update the shared store.

        Runs until interrupted. Any other error is printed and the loop
        continues with the next batch (best-effort streaming).
        """
        # Re-seed so every worker draws a different random sequence.
        np.random.seed(int.from_bytes(os.urandom(4), 'little'))
        while True:
            try:
                batch = [self._fetch_patch() for _ in range(self.batch_size)]
                self._push(DinoDataset.transform(batch))

                # Occasionally refresh the "last updated" marker file.
                if np.random.random() < 0.20:
                    write_last_updated()
            except KeyboardInterrupt:
                break
            except Exception as e:
                print("ERROR:", e)
                continue
150
+
151
+
152
+
153
+
154
if __name__ == "__main__":
    # Manual smoke test: start the background workers, then report the
    # number of batches in the shared store every five seconds.
    dataset = DinoDataset(imgsz=1696, batch_size=3, queue_size=1000)
    while True:
        filled = len(dataset.store)
        print(filled)
        time.sleep(5)
src/loss.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🎯 core-dino | DINO-style Loss Functions 💥
3
+
4
+ Defines the cross-view contrastive loss used in DINO setups,
5
+ including temperature scaling, centering, and teacher-student divergence.
6
+
7
+ Includes:
8
+ - DinoSpatialLoss: Temp-scaled CE loss with center momentum 🌀
9
+ - DinoSinkhornSpatialLoss: Sinkhorn-based balanced assignment loss ⚖️
10
+
11
+ Author: Gajesh Ladhar
12
+ 🔗 LinkedIn: https://www.linkedin.com/in/gajeshladhar/
13
+ 🤗 Hugging Face: https://huggingface.co/gajeshladhar
14
+ """
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+
20
+
21
class DinoSpatialLoss(nn.Module):
    """
    DINO loss using temperature-scaled cross-entropy over spatial tokens.

    - Aligns teacher & student spatial features (B, C, H, W)
    - Keeps an EMA "center" of the raw teacher outputs, which is subtracted
      from the teacher logits before the softmax

    Args:
        teacher_temp (float): Temperature for teacher softmax
        student_temp (float): Temperature for student softmax
        center_momentum (float): EMA factor for center update
    """

    def __init__(self, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Placeholder; resized to (1, C) on the first forward pass.
        self.register_buffer("center", torch.zeros(1, 1))

    def forward(self, student_feat, teacher_feat):
        """
        Compute loss over (B, C, H, W) features.

        Args:
            student_feat (Tensor): Student output, shape (B, C, Hs, Ws)
            teacher_feat (Tensor): Teacher output, shape (B, C, Ht, Wt)

        Returns:
            Tensor: Scalar DINO loss

        Raises:
            ValueError: If student and teacher channel counts differ.
        """
        channels = teacher_feat.shape[1]
        if student_feat.shape[1] != channels:
            raise ValueError(
                "student and teacher must have the same channel count, got "
                f"{student_feat.shape[1]} and {channels}"
            )

        # (Re)allocate the center whenever its width does not match the
        # features. Comparing against C, not the literal 1, keeps C == 1
        # features from resetting the center on every call.
        if self.center.shape[1] != channels:
            self.center = torch.zeros(
                1, channels, device=teacher_feat.device, dtype=teacher_feat.dtype
            )

        # Resize student to teacher resolution
        student_resized = F.interpolate(
            student_feat, size=teacher_feat.shape[2:], mode='bilinear', align_corners=False
        )

        # Flatten spatial dims: (B, C, H, W) -> (B*H*W, C)
        student_flat = student_resized.permute(0, 2, 3, 1).reshape(-1, channels)
        # The teacher is a fixed target: no gradient flows through it.
        teacher_flat = teacher_feat.detach().permute(0, 2, 3, 1).reshape(-1, channels)

        student_log_probs = F.log_softmax(student_flat / self.student_temp, dim=-1)
        teacher_probs = F.softmax((teacher_flat - self.center) / self.teacher_temp, dim=-1)

        # Cross-entropy between teacher distribution and student prediction
        loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()

        # EMA of the raw teacher outputs. The center is subtracted from the
        # logits above, so it must live in logit space; averaging the softmax
        # probabilities instead makes the centering a near no-op.
        with torch.no_grad():
            batch_center = teacher_flat.mean(dim=0, keepdim=True)
            self.center = (
                self.center * self.center_momentum
                + batch_center * (1 - self.center_momentum)
            )

        return loss
79
+
80
+
81
+
82
class SinkhornKnopp(nn.Module):
    """
    Sinkhorn-Knopp normalization for balanced assignments.

    Turns an (N, K) score matrix into soft assignments by alternately
    normalizing columns and rows. Each iteration ends on the row step, so
    every row of the result is a probability distribution (sums to ~1) and
    can be used directly as a cross-entropy target.

    Args:
        num_iters (int): Number of normalization iterations
        eps (float): Stabilizer to avoid div-by-zero
    """

    def __init__(self, num_iters: int = 3, eps: float = 1e-6):
        super().__init__()
        self.num_iters = num_iters
        self.eps = eps

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits (Tensor): Scores of shape (N, K), one row per sample.

        Returns:
            Tensor: Non-negative (N, K) matrix. With ``num_iters >= 1`` each
            row sums to ~1; with ``num_iters == 0`` only the global sum is
            normalized.
        """
        # Subtract the per-row max so exp() cannot overflow.
        logits = logits - logits.max(dim=1, keepdim=True)[0]
        Q = torch.exp(logits)
        Q = Q / (Q.sum() + self.eps)

        for _ in range(self.num_iters):
            # Column step: spread mass evenly over the K columns.
            Q = Q / (Q.sum(dim=0, keepdim=True) + self.eps)
            # Row step last: each sample's assignment is a distribution.
            Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)

        return Q
105
+
106
class DinoSinkhornSpatialLoss(nn.Module):
    """
    DINO loss whose teacher targets come from Sinkhorn normalization
    instead of a centered, temperature-scaled softmax.

    Args:
        student_temp (float): Temperature for student softmax
        sinkhorn_iters (int): Iterations for Sinkhorn normalization
    """

    def __init__(self, student_temp=0.1, sinkhorn_iters=3):
        super().__init__()
        self.student_temp = student_temp
        self.sinkhorn = SinkhornKnopp(sinkhorn_iters)

    @staticmethod
    def _to_tokens(feat, channels):
        """Reshape a (B, C, H, W) map into a (B*H*W, C) token matrix."""
        channels_last = feat.permute(0, 2, 3, 1)
        return channels_last.reshape(-1, channels)

    def forward(self, student_feat, teacher_feat):
        """
        Args:
            student_feat (Tensor): Student output, shape (B, C, Hs, Ws)
            teacher_feat (Tensor): Teacher output, shape (B, C, Ht, Wt)

        Returns:
            Tensor: Scalar cross-entropy between the Sinkhorn-normalized
            teacher targets and the student log-probabilities.
        """
        # Bring the student onto the teacher's spatial grid.
        target_hw = teacher_feat.shape[2:]
        student_on_grid = F.interpolate(
            student_feat, size=target_hw, mode='bilinear', align_corners=False
        )

        channels = student_on_grid.shape[1]
        student_tokens = self._to_tokens(student_on_grid, channels)
        teacher_tokens = self._to_tokens(teacher_feat, channels)

        # Teacher targets: Sinkhorn only (no temperature, no center); detached.
        targets = self.sinkhorn(teacher_tokens).detach()

        # Student predictions: temperature-scaled log-softmax.
        log_preds = F.log_softmax(student_tokens / self.student_temp, dim=-1)

        per_token = (targets * log_preds).sum(dim=-1)
        return -per_token.mean()
src/train.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ›°οΈ core-dino | Training Script for Resolution-Agnostic SSL on Satellite Imagery
3
+
4
+ Trains DINO with a YOLO backbone using multi-resolution Core-Five patches.
5
+
6
+ πŸ‘¨β€πŸ’» Author: Gajesh Ladhar
7
+ πŸ”— LinkedIn: https://www.linkedin.com/in/gajeshladhar/
8
+ πŸ€— Hugging Face: https://huggingface.co/gajeshladhar
9
+ """
10
+
11
+ # 📦 Imports
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from loss import DinoSpatialLoss
15
+ from backbone import YOLOBackBone
16
+ from data import DinoDataset
17
+ from utils import *
18
+
19
+ # βš™οΈ Config
20
CFG = dict(
    # Data / optimisation settings
    imgsz=1696,
    batch_size=4,
    epochs=100,
    device="cuda" if torch.cuda.is_available() else "cpu",
    lr=1e-4,
    queue_size=1000,
    # Checkpoint to start from and file the trained student is written to
    ckpt_path="yolo11x.pt",
    save_path="dino-yolo.pt",
    # core-DINO logic parameters
    teacher_temperature=0.04,
    student_temperature=0.1,
    teacher_ema=0.998,
)
35
+
36
+ # 🔄 Sync Student → Teacher Weights
37
@torch.no_grad()
def initialize_teacher(student, teacher):
    """
    Make `teacher` an exact copy of `student`, in place.

    Copies both parameters and buffers (e.g. BatchNorm running statistics),
    so the two networks produce the same outputs right after the call.

    Args:
        student (nn.Module): Source network.
        teacher (nn.Module): Destination network with the same structure.
    """
    for src, dst in zip(student.parameters(), teacher.parameters()):
        dst.copy_(src)
    # Buffers are not returned by .parameters(); without this step the
    # teacher would keep its own normalization statistics.
    for src, dst in zip(student.buffers(), teacher.buffers()):
        dst.copy_(src)
41
+
42
@torch.no_grad()
def update_teacher(student, teacher, m=0.996):
    """
    EMA step for the teacher: ``teacher = m * teacher + (1 - m) * student``.

    Only parameters are updated; the update happens in place.

    Args:
        student (nn.Module): Network providing the new weights.
        teacher (nn.Module): Network whose parameters are updated.
        m (float): Momentum, i.e. the share of the old teacher weights kept.
    """
    blend = 1 - m
    pairs = zip(student.parameters(), teacher.parameters())
    for student_param, teacher_param in pairs:
        teacher_param.data.mul_(m)
        teacher_param.data.add_(student_param.data, alpha=blend)
46
+
47
+
48
+ # 🧠 Model + Loss + Optimizer
49
def setup_model_and_loss():
    """
    Build everything the training loop needs from `CFG`.

    Student and teacher are both created from the same checkpoint; the
    teacher's parameters are frozen, and only the student's parameters are
    handed to the optimizer.

    Returns:
        tuple: ``(student, teacher, loss_fn, optimizer)``
    """
    device = CFG["device"]
    checkpoint = CFG["ckpt_path"]

    student = YOLOBackBone(model_path=checkpoint).to(device)
    teacher = YOLOBackBone(model_path=checkpoint).to(device)
    # The teacher is never trained by backprop.
    teacher.requires_grad_(False)

    loss_fn = DinoSpatialLoss(
        teacher_temp=CFG["teacher_temperature"],
        student_temp=CFG["student_temperature"],
    ).to(device)
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG["lr"], weight_decay=0.05)
    return student, teacher, loss_fn, optimizer
57
+
58
+
59
+ # πŸ” Training Loop
60
def train():
    """
    Run DINO training of the student against the EMA teacher.

    Each epoch makes one pass over whatever batches are currently in the
    shared store filled by the `DinoDataset` workers. After every epoch the
    student's weights are saved to ``CFG["save_path"]``.
    """
    import time  # local import: this module does not import `time` itself

    student, teacher, criterion, optimizer = setup_model_and_loss()
    dataset = DinoDataset(imgsz=CFG["imgsz"], batch_size=CFG["batch_size"], queue_size=CFG["queue_size"])

    num_epochs = CFG["epochs"]
    device = CFG["device"]
    # autocast validates its device type, so derive it from the configured
    # device instead of hard-coding 'cuda'.
    device_type = torch.device(device).type

    for epoch in range(num_epochs):
        # The store starts empty and is filled asynchronously. Wait for the
        # first batch; otherwise the epoch would be empty and the averages
        # below would divide by zero.
        while len(dataset.store) == 0:
            time.sleep(1)

        running_loss = 0.0
        running_entropy = 0.0
        total_count = 0
        loop = tqdm(dataset.store, desc=f"📅 Epoch {epoch+1}/{num_epochs}")

        for batch in loop:
            # Batches are uint8; move them first (smaller transfer), then
            # scale to [0, 1].
            images_s = batch['student'].to(device).float() / 255.0
            images_t = batch['teacher'].to(device).float() / 255.0

            with torch.no_grad():
                teacher_out = teacher(images_t)

            # Mixed precision is explicitly disabled.
            with autocast(device_type=device_type, enabled=False):
                student_out = student(images_s)
                loss = criterion(student_out, teacher_out)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_teacher(student, teacher, m=CFG["teacher_ema"])

            running_loss += loss.item()
            total_count += 1

            # Entropy of the teacher's channel distribution (monitoring only)
            probs = F.softmax(teacher_out / CFG["teacher_temperature"], dim=1)
            eps = 1e-6
            entropy = -(probs * (probs + eps).log()).sum(dim=1).mean()
            running_entropy += entropy.item()

            loop.set_postfix({
                "💥 Loss": f"{loss.item():.4f}",
                "📈 Entropy": f"{entropy.item():.4f}"
            })

        if total_count == 0:
            # Nothing was processed this epoch; skip averaging and saving.
            continue

        avg_loss = running_loss / total_count
        avg_entropy = running_entropy / total_count

        # Save first so the "Saved" message below is true when printed.
        torch.save(student.state_dict(), CFG["save_path"])
        print(f"✅ Epoch {epoch+1:03} | 🧠 Avg Loss: {avg_loss:.4f} | 🔍 Teacher Entropy: {avg_entropy:.4f} | 💾 Saved → {CFG['save_path']}")
108
+
109
+
110
+
111
if __name__ == "__main__":
    # Script entry point: run the full training loop.
    train()
src/utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import torch
4
+ from torch import nn
5
+ from torch.amp import autocast
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader
8
+
9
+ import copy
10
+ import queue
11
+ import numpy as np
12
+ import pandas as pd
13
+ import geopandas as gpd
14
+
15
+ import fsspec
16
+ import xarray as xr
17
+ from tqdm.notebook import tqdm
18
+
19
+ from ultralytics import YOLO
20
+ from IPython.display import clear_output
21
+
22
+ from multiprocessing import Manager
23
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
24
+
25
+ import huggingface_hub as hf
26
+ import albumentations as A
27
+
28
+ import h5py
29
+ import requests
30
+ from io import BytesIO
31
+
32
+ import datetime
33
+ from pathlib import Path
34
+ import tempfile
35
+ import shutil
36
+
37
+ # Module-level shared state for parallel data loading, created as a side
+ # effect of importing this module: a multiprocessing Manager, a
+ # manager-backed list used as the shared sample store, and a pool of six
+ # worker processes.
38
+ manager = Manager()
39
+ shared_store = manager.list()
40
+ process_pool = ProcessPoolExecutor(max_workers=6)
41
+
42
def write_last_updated(path="store_last_updated.txt"):
    """
    Atomically write the current local timestamp (ISO 8601) to `path`.

    The text is first written to a temporary file in the same directory as
    `path` and then renamed over the target, so a reader never sees a
    partially written file.

    Args:
        path (str): Destination file; replaced if it already exists.

    Raises:
        OSError: If the temporary file cannot be written or renamed.
    """
    # Same directory as the target keeps the rename on one filesystem.
    target_dir = os.path.dirname(os.path.abspath(path))
    with tempfile.NamedTemporaryFile("w", delete=False, dir=target_dir) as tmp:
        tmp.write(f"{datetime.datetime.now().isoformat()}")
        tmp_path = tmp.name
    try:
        os.replace(tmp_path, path)
    except OSError:
        # Do not leave the temporary file behind when the rename fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
47
+
48
+
49
class AddPoissonNoise(A.ImageOnlyTransform):
    """
    Albumentations transform that replaces each pixel with a Poisson sample
    whose rate is the pixel's intensity on a 0-255 scale.

    Args:
        p (float): Probability of applying the transform.
    """

    def __init__(self, p=0.5):
        # Keyword argument: in Albumentations releases whose base
        # constructor is (always_apply, p), a positional value would be
        # taken as `always_apply` instead of the probability.
        super().__init__(p=p)

    def apply(self, image, **params):
        """
        Args:
            image (np.ndarray): uint8 image, or a non-uint8 image that is
                scaled by 255 before sampling (presumably floats in [0, 1]).

        Returns:
            np.ndarray: Noisy image as uint8, values clipped to [0, 255].
        """
        if image.dtype == np.uint8:
            rates = image.astype(np.float32)
        else:
            # Poisson rates must be non-negative.
            rates = np.clip(np.asarray(image, dtype=np.float32) * 255.0, 0.0, None)
        noisy = np.random.poisson(rates)
        return np.clip(noisy, 0, 255).astype('uint8')
57
+
58
class AddSaltPepperNoise(A.ImageOnlyTransform):
    """
    Albumentations transform that sets random elements to the brightest
    value ("salt") or to zero ("pepper").

    Args:
        amount (float): Fraction of image elements to corrupt.
        salt_vs_pepper (float): Share of corrupted elements that become salt.
        p (float): Probability of applying the transform.
    """

    def __init__(self, amount=0.02, salt_vs_pepper=0.5, p=0.5):
        # Keyword argument: in Albumentations releases whose base
        # constructor is (always_apply, p), a positional value would be
        # taken as `always_apply` instead of the probability.
        super().__init__(p=p)
        self.amount = amount
        self.salt_vs_pepper = salt_vs_pepper

    def apply(self, image, **params):
        """
        Args:
            image (np.ndarray): Input image; it is not modified.

        Returns:
            np.ndarray: Copy of `image` with salt and pepper elements set.
        """
        noisy = image.copy()
        if image.size == 0:
            return noisy

        num_salt = int(np.ceil(self.amount * image.size * self.salt_vs_pepper))
        num_pepper = int(np.ceil(self.amount * image.size * (1.0 - self.salt_vs_pepper)))

        # Salt is the brightest representable value: 255 for uint8 images,
        # 1 for non-integer images (presumably floats in [0, 1]).
        if np.issubdtype(image.dtype, np.integer):
            salt_value = np.iinfo(image.dtype).max
        else:
            salt_value = 1

        # randint's upper bound is exclusive, so `dim` (not `dim - 1`) lets
        # the last row, column and channel receive noise, and it also works
        # for size-1 dimensions.
        salt_coords = tuple(np.random.randint(0, dim, num_salt) for dim in image.shape)
        noisy[salt_coords] = salt_value

        pepper_coords = tuple(np.random.randint(0, dim, num_pepper) for dim in image.shape)
        noisy[pepper_coords] = 0

        return noisy