Commit
Β·
1e6fe0a
1
Parent(s):
75fecef
core-dino src added
Browse files- src/backbone.py +79 -0
- src/data.py +158 -0
- src/loss.py +144 -0
- src/train.py +112 -0
- src/utils.py +77 -0
src/backbone.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
𦴠core-dino | YOLO Backbone Wrapper for Feature Extraction π
|
3 |
+
|
4 |
+
Wraps a YOLO model to extract intermediate feature maps for DINO-style
|
5 |
+
self-supervised training. Optionally applies an MLP projection head.
|
6 |
+
|
7 |
+
Author: Gajesh Ladhar
|
8 |
+
π LinkedIn: https://www.linkedin.com/in/gajeshladhar/
|
9 |
+
π€ Hugging Face: https://huggingface.co/gajeshladhar
|
10 |
+
"""
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from ultralytics import YOLO
|
15 |
+
|
16 |
+
|
17 |
+
class YOLOBackBone(nn.Module):
    """
    Run the first layers of a YOLO model and return the last feature map.

    Layers are executed following the ``from`` field (first element) of each
    entry of the model's YAML definition (``backbone`` followed by ``head``).

    Args:
        model_path (str): Path to YOLO weights (.pt)
        stop_at (int): Slice bound; layers ``[:stop_at]`` of the YOLO model are kept
        use_mlp (bool): Whether to route the output through ``mlp_head``
        mlp_dim (int): Intended output dim of the projection head. NOTE: it has
            no effect at the moment because ``init_mlp`` installs ``nn.Identity``.
    """

    def __init__(self, model_path='yolo11x.pt', stop_at=23, use_mlp=True, mlp_dim=512):
        super().__init__()
        raw_model = YOLO(model_path).model.train()
        self.layers = nn.ModuleList(raw_model.model[:stop_at])
        # One entry per layer; entry[0] is the "from" spec (int or list of ints).
        self.layer_defs = raw_model.yaml["backbone"] + raw_model.yaml["head"]

        self.use_mlp = use_mlp
        if use_mlp:
            self.init_mlp(self._get_out_channels(self.layers[-1]), mlp_dim)

        # Make every wrapped weight trainable.
        for p in self.parameters():
            p.requires_grad = True

    def _get_out_channels(self, layer):
        """Return the channel count of the final feature map.

        NOTE(review): hard-coded to 768 and independent of ``layer``; confirm
        this matches the chosen checkpoint / ``stop_at`` before enabling a
        real projection head.
        """
        return 768

    def init_mlp(self, in_channels, out_channels):
        """Install the projection head.

        The head is currently a pass-through. The disabled design was a stack
        of 1x1 convolutions with GELU activations:
        ``in_channels -> 2048 -> out_channels -> in_channels``.
        """
        self.mlp_head = nn.Identity()

    def apply_mlp(self, x):
        """Apply ``mlp_head`` when ``use_mlp`` is set, otherwise return ``x``."""
        return self.mlp_head(x) if self.use_mlp else x

    def _layer_sources(self):
        """Return the "from" spec of every kept layer as a list of ints."""
        sources = []
        for i in range(len(self.layers)):
            from_ids = self.layer_defs[i][0]
            sources.append([from_ids] if isinstance(from_ids, int) else list(from_ids))
        return sources

    def forward(self, x):
        """
        Forward pass through the selected YOLO layers and the optional head.

        Args:
            x (Tensor): Input tensor fed to the first layer

        Returns:
            Tensor: Output of the last kept layer (after ``apply_mlp``)
        """
        sources = self._layer_sources()

        # Only outputs that a later layer references explicitly (anything other
        # than -1, i.e. "previous output") have to be kept around. Negative
        # references index the outputs produced so far, hence ``i + j``.
        needed = set()
        for i, ids in enumerate(sources):
            for j in ids:
                if j != -1:
                    needed.add(j if j >= 0 else i + j)

        outputs = []
        for i, layer in enumerate(self.layers):
            inputs = [x if j == -1 else outputs[j] for j in sources[i]]
            x = layer(inputs if len(inputs) > 1 else inputs[0])
            # Drop references to feature maps nobody reads again so they can be
            # freed early (mainly helps no-grad passes such as the teacher).
            outputs.append(x if i in needed else None)
        return self.apply_mlp(x)

    def count_params(self):
        """Return ``(total, trainable)`` parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total, trainable
|
79 |
+
|
src/data.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""π¦ core-dino | Data Loader for Self-Supervised DINO Training on Core-Five π
|
2 |
+
|
3 |
+
This module defines the `DinoDataset` which streams multi-resolution
|
4 |
+
satellite patches from the Core-Five dataset, preparing teacher-student
|
5 |
+
views for resolution-agnostic self-supervised learning.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import io
|
10 |
+
import time
|
11 |
+
import torch
|
12 |
+
import random
|
13 |
+
import requests
|
14 |
+
import numpy as np
|
15 |
+
import geopandas as gpd
|
16 |
+
import h5py
|
17 |
+
import xarray as xr
|
18 |
+
from torch import nn
|
19 |
+
from torch.utils.data import Dataset
|
20 |
+
import albumentations as A
|
21 |
+
import fsspec
|
22 |
+
|
23 |
+
from utils import (
|
24 |
+
shared_store, process_pool, write_last_updated,
|
25 |
+
AddPoissonNoise, AddSaltPepperNoise
|
26 |
+
)
|
27 |
+
|
28 |
+
class DinoDataset(Dataset):
    """
    DinoDataset - resolution-agnostic loader for Core-Five.

    Streams random crops of HR satellite images from Hugging Face and
    creates clean (teacher) and augmented (student) views using
    Albumentations & torch. Batches are produced by background workers and
    pushed into the shared store imported from ``utils``.

    ---
    Author: Gajesh Ladhar
    LinkedIn: https://www.linkedin.com/in/gajeshladhar/
    Hugging Face: https://huggingface.co/gajeshladhar
    """

    # Root under which the relative paths of the metadata table are resolved.
    _BASE_URL = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/"
    # Seconds before a stalled download is abandoned (the worker then retries).
    _REQUEST_TIMEOUT = 60

    def __init__(self, imgsz, batch_size=1, queue_size=50):
        """
        Load the remote Core-Five metadata and start async patch fetching.

        Args:
            imgsz (int): Patch size (min 320)
            batch_size (int): Number of patches per batch
            queue_size (int): Max queue length for shared store

        Raises:
            ValueError: If ``imgsz`` is smaller than 320.
        """
        if imgsz < 320:
            raise ValueError("imgsz must be >= 320 for stable patch extraction - got {}".format(imgsz))
        self.imgsz = imgsz
        metadata_url = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/metadata.parquet"
        self.df_metadata = gpd.read_parquet(fsspec.open(metadata_url).open())
        self.batch_size = batch_size
        self.queue_size = queue_size
        self.store = shared_store

        # Six long-running fetch loops, one per submitted task.
        for _ in range(6):
            process_pool.submit(self.fetch_and_store)

    @staticmethod
    def transform(batch):
        """
        Build the student/teacher views for one batch of patches.

        The student view is resized to a random side length (multiple of 32,
        from 320 up to but excluding the teacher size) and degraded with the
        augmentation pipeline; the teacher view is the untouched patch.

        Args:
            batch (list): Arrays indexed (band, y, x); all are assumed to share
                the shape of the first one.

        Returns:
            Dict with 'student' and 'teacher' uint8 tensors
        """
        augment_satellite = A.Compose([
            A.GaussNoise(std_range=(0.01, 0.1), p=0.3),
            AddPoissonNoise(p=0.3),
            AddSaltPepperNoise(amount=0.02, p=0.3),
            A.MultiplicativeNoise(multiplier=(0.9, 1.1), elementwise=True, p=0.3),
            A.MotionBlur(blur_limit=(3, 11), p=0.3),
            A.GaussianBlur(blur_limit=(3, 11), p=0.3),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=0.1),
            A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.3),
            A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30, p=0.3),
            A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30, val_shift_limit=30, p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2),
            A.CoarseDropout(num_holes_range=(1, 4), hole_height_range=(0.05, 0.2),
                            hole_width_range=(0.05, 0.2), fill='random_uniform', p=0.1)
        ])

        teacher_size = batch[0].shape[-1]
        candidates = np.arange(32 * 10, teacher_size, 32)
        # The candidate list is empty when the patch is not larger than 320 px
        # (np.random.choice would raise); keep the teacher size in that case.
        size = int(np.random.choice(candidates)) if candidates.size else int(teacher_size)

        resize = nn.Upsample(size=size, mode='bilinear')
        student, teacher = [], []
        for img in batch:
            student_data = resize(torch.tensor(img[np.newaxis, :]))[0].data.numpy().astype("uint8")
            # Albumentations works on (H, W, C); convert there and back.
            student_data = augment_satellite(image=student_data.transpose(1, 2, 0))['image'].transpose(2, 0, 1)
            student.append(torch.tensor(student_data))
            teacher.append(torch.tensor(img))

        return {
            "student": torch.stack(student).to(torch.uint8),
            "teacher": torch.stack(teacher).to(torch.uint8)
        }

    def fetch_and_store(self):
        """
        Worker loop: sample random crops from Core-Five, build the views via
        `transform`, and push the batch into the shared store.

        Once the store holds ``queue_size`` batches, a random existing entry is
        overwritten instead of appending. Any error is printed and the loop
        retries; only KeyboardInterrupt ends it.
        """
        # Fresh entropy per worker so forked processes do not sample the same crops.
        np.random.seed(int.from_bytes(os.urandom(4), 'little'))
        while True:
            try:
                batch = []
                for _ in range(self.batch_size):
                    rel_path = str(self.df_metadata.sample(n=1).path.iloc[0])
                    # Plain concatenation: os.path.join is platform dependent and
                    # would drop the base if the relative path began with "/".
                    url = self._BASE_URL + rel_path.lstrip("/")
                    response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"},
                                            timeout=self._REQUEST_TIMEOUT)
                    # Fail here on HTTP errors instead of handing an error page to h5py.
                    response.raise_for_status()
                    buffer = io.BytesIO(response.content)
                    with h5py.File(buffer, "r") as f:
                        x = f["hr/x"][:]
                        y = f["hr/y"][:]
                        data = f["/hr/data"][:]
                        bands = list(range(data.shape[0]))

                    ds = xr.DataArray(data, dims=['band', 'y', 'x'], coords=[bands, y, x]).astype("uint8")

                    # Random imgsz x imgsz window fully inside the tile.
                    imgsz_half = self.imgsz // 2
                    yid = np.random.randint(imgsz_half, len(ds.y) - imgsz_half)
                    xid = np.random.randint(imgsz_half, len(ds.x) - imgsz_half)
                    ds = ds.isel(y=range(yid - imgsz_half, yid + imgsz_half),
                                 x=range(xid - imgsz_half, xid + imgsz_half)).compute()
                    ds['y'], ds['x'] = np.linspace(ds.y.values[0], ds.y.values[-1], ds.shape[1]), \
                                       np.linspace(ds.x.values[0], ds.x.values[-1], ds.shape[2])

                    batch.append(ds.data)

                result = DinoDataset.transform(batch)
                if len(self.store) >= self.queue_size:
                    # randint's upper bound is exclusive, so this covers every slot.
                    index = np.random.randint(0, self.queue_size)
                    self.store[index] = result
                else:
                    self.store.append(result)

                # Occasionally refresh the "last updated" marker file.
                if np.random.random() < 0.20:
                    write_last_updated()
            except KeyboardInterrupt:
                break
            except Exception as e:
                # Best effort: report and move on to the next batch.
                print("ERROR:", e)
                continue
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
    # Manual smoke test: start the background workers and report, every five
    # seconds, how many batches the shared store currently holds.
    monitor = DinoDataset(imgsz=1696, batch_size=3, queue_size=1000)
    while True:
        queued = len(monitor.store)
        print(queued)
        time.sleep(5)
|
src/loss.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
core-dino | DINO-style Loss Functions
|
3 |
+
|
4 |
+
Defines the cross-view contrastive loss used in DINO setups,
|
5 |
+
including temperature scaling, centering, and teacher-student divergence.
|
6 |
+
|
7 |
+
Includes:
|
8 |
+
- DinoSpatialLoss: Temp-scaled CE loss with center momentum π
|
9 |
+
- DinoSinkhornSpatialLoss: Sinkhorn-based balanced assignment loss βοΈ
|
10 |
+
|
11 |
+
Author: Gajesh Ladhar
|
12 |
+
π LinkedIn: https://www.linkedin.com/in/gajeshladhar/
|
13 |
+
π€ Hugging Face: https://huggingface.co/gajeshladhar
|
14 |
+
"""
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
|
21 |
+
class DinoSpatialLoss(nn.Module):
    """
    DINO loss using temperature-scaled cross-entropy over spatial tokens.

    - Aligns teacher & student spatial features (B, C, H, W)
    - Subtracts an EMA "center" from the teacher outputs before the softmax

    Args:
        teacher_temp (float): Temperature for teacher softmax
        student_temp (float): Temperature for student softmax
        center_momentum (float): EMA factor for center update
    """

    def __init__(self, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Placeholder of shape (1, 1); resized to (1, C) on the first forward.
        self.register_buffer("center", torch.zeros(1, 1))

    def forward(self, student_feat, teacher_feat):
        """
        Compute loss over (B, C, H, W) features.

        Args:
            student_feat (Tensor): Student output, shape (B, C, Hs, Ws)
            teacher_feat (Tensor): Teacher output, shape (B, C, Ht, Wt)

        Returns:
            Tensor: Scalar DINO loss

        Raises:
            ValueError: If student and teacher channel counts differ.
        """
        channels = teacher_feat.shape[1]
        if student_feat.shape[1] != channels:
            # reshape(-1, C) below would otherwise silently scramble the tokens
            raise ValueError(
                f"channel mismatch: student has {student_feat.shape[1]}, teacher has {channels}"
            )

        # Lazily size the center. The extra `channels != 1` test keeps a
        # genuine single-channel center from being re-zeroed on every call.
        if self.center.shape[1] == 1 and channels != 1:
            self.center = self.center.new_zeros((1, channels), device=teacher_feat.device)

        # Resize student to teacher resolution
        student_resized = F.interpolate(
            student_feat, size=teacher_feat.shape[2:], mode='bilinear', align_corners=False
        )

        # Flatten spatial dims: (B, C, H, W) -> (B*H*W, C)
        student_flat = student_resized.permute(0, 2, 3, 1).reshape(-1, channels)
        teacher_flat = teacher_feat.permute(0, 2, 3, 1).reshape(-1, channels)

        # Temperature scaling; only the teacher is centered
        student_logits = student_flat / self.student_temp
        teacher_logits = (teacher_flat - self.center) / self.teacher_temp

        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1).detach()

        # Cross-entropy between teacher targets and student predictions
        loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()

        # EMA of the raw teacher outputs. The center is subtracted from the
        # teacher outputs above, so it has to be tracked in that same space
        # (not in probability space).
        with torch.no_grad():
            batch_center = teacher_flat.detach().mean(dim=0, keepdim=True)
            self.center = (
                self.center * self.center_momentum
                + batch_center * (1 - self.center_momentum)
            )

        return loss
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class SinkhornKnopp(nn.Module):
    """
    Sinkhorn-Knopp normalization for balanced assignments.

    Takes a 2-D score matrix with one row per token and one column per
    channel and alternately normalizes columns and rows.

    Args:
        num_iters (int): Number of normalization iterations
        eps (float): Stabilizer to avoid div-by-zero
    """

    def __init__(self, num_iters: int = 3, eps: float = 1e-6):
        super().__init__()
        self.num_iters = num_iters
        self.eps = eps

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the normalized assignment matrix (same shape as ``logits``).

        With ``num_iters >= 1`` every row sums to 1 (up to ``eps``), so each
        row can be used directly as a target distribution.
        """
        logits = logits - logits.max(dim=1, keepdim=True)[0]  # stabilize exp
        Q = torch.exp(logits)
        Q = Q / Q.sum()

        for _ in range(self.num_iters):
            # Column step first: balance the total mass given to each channel.
            Q = Q / (Q.sum(dim=0, keepdim=True) + self.eps)
            # Row step last, so the returned rows are proper distributions.
            Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)

        return Q
|
105 |
+
|
106 |
+
class DinoSinkhornSpatialLoss(nn.Module):
    """
    DINO loss whose teacher targets come from Sinkhorn normalization
    (no center, no teacher temperature).

    Args:
        student_temp (float): Temperature for student softmax
        sinkhorn_iters (int): Iterations for Sinkhorn normalization
    """

    def __init__(self, student_temp=0.1, sinkhorn_iters=3):
        super().__init__()
        self.student_temp = student_temp
        self.sinkhorn = SinkhornKnopp(sinkhorn_iters)

    def forward(self, student_feat, teacher_feat):
        """
        Cross-entropy between Sinkhorn teacher targets and student predictions.

        student_feat: (B, C, Hs, Ws)
        teacher_feat: (B, C, Ht, Wt)

        Returns a scalar tensor.
        """
        # Bring the student map to the teacher's spatial size.
        target_hw = teacher_feat.shape[2:]
        student_resized = F.interpolate(
            student_feat, size=target_hw, mode='bilinear', align_corners=False
        )
        _, channels, _, _ = student_resized.shape

        def to_tokens(feat):
            # (B, C, H, W) -> (B*H*W, C)
            return feat.permute(0, 2, 3, 1).reshape(-1, channels)

        student_tokens = to_tokens(student_resized)
        teacher_tokens = to_tokens(teacher_feat)

        # Targets carry no gradient.
        targets = self.sinkhorn(teacher_tokens).detach()
        log_preds = F.log_softmax(student_tokens / self.student_temp, dim=-1)

        per_token = (targets * log_preds).sum(dim=-1)
        return -per_token.mean()
|
src/train.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
core-dino | Training Script for Resolution-Agnostic SSL on Satellite Imagery
|
3 |
+
|
4 |
+
Trains DINO with a YOLO backbone using multi-resolution Core-Five patches.
|
5 |
+
|
6 |
+
π¨βπ» Author: Gajesh Ladhar
|
7 |
+
π LinkedIn: https://www.linkedin.com/in/gajeshladhar/
|
8 |
+
π€ Hugging Face: https://huggingface.co/gajeshladhar
|
9 |
+
"""
|
10 |
+
|
11 |
+
# Imports
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from loss import DinoSpatialLoss
|
15 |
+
from backbone import YOLOBackBone
|
16 |
+
from data import DinoDataset
|
17 |
+
from utils import *
|
18 |
+
|
19 |
+
# Config
|
20 |
+
CFG = dict(
    # data / optimisation settings
    imgsz=1696,
    batch_size=4,
    epochs=100,
    device="cuda" if torch.cuda.is_available() else "cpu",
    lr=1e-4,
    queue_size=1000,
    # checkpoint to start from and file the student weights are written to
    ckpt_path="yolo11x.pt",
    save_path="dino-yolo.pt",
    # core-DINO logic parameters
    teacher_temperature=0.04,
    student_temperature=0.1,
    teacher_ema=0.998,
)
|
35 |
+
|
36 |
+
# Sync Student -> Teacher Weights
|
37 |
+
@torch.no_grad()
def initialize_teacher(student, teacher):
    """Copy the student's state into the teacher, in place.

    Both parameters and buffers are copied, so normalization statistics
    (e.g. BatchNorm running mean/var) start out identical as well.
    The two modules are walked in lockstep and are expected to share the
    same architecture.

    Args:
        student (nn.Module): Source model.
        teacher (nn.Module): Destination model, modified in place.
    """
    for ps, pt in zip(student.parameters(), teacher.parameters()):
        pt.copy_(ps)
    # Buffers are not part of .parameters(); without this the teacher keeps
    # whatever running statistics it was constructed with.
    for bs, bt in zip(student.buffers(), teacher.buffers()):
        bt.copy_(bs)
|
41 |
+
|
42 |
+
@torch.no_grad()
def update_teacher(student, teacher, m=0.996):
    """Exponential-moving-average update of the teacher parameters.

    Each teacher parameter becomes ``m * teacher + (1 - m) * student``.
    Only parameters are averaged; buffers are left untouched.

    Args:
        student (nn.Module): Model providing the new values.
        teacher (nn.Module): Model updated in place.
        m (float): EMA momentum in [0, 1]; higher means a slower teacher.

    Raises:
        ValueError: If ``m`` is outside [0, 1].
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError(f"EMA momentum must be in [0, 1], got {m}")
    for ps, pt in zip(student.parameters(), teacher.parameters()):
        # In-place ops on the parameter itself are fine under no_grad; no need
        # to reach through `.data`.
        pt.mul_(m).add_(ps.detach(), alpha=1 - m)
|
46 |
+
|
47 |
+
|
48 |
+
# Model + Loss + Optimizer
|
49 |
+
def setup_model_and_loss():
    """Build the student, the frozen teacher, the DINO loss and the optimizer.

    Both networks are created from ``CFG["ckpt_path"]`` and moved to
    ``CFG["device"]``. Only the student's parameters are handed to AdamW.

    Returns:
        tuple: ``(student, teacher, loss_fn, optimizer)``
    """
    device = CFG["device"]
    checkpoint = CFG["ckpt_path"]

    student = YOLOBackBone(model_path=checkpoint).to(device)
    teacher = YOLOBackBone(model_path=checkpoint).to(device)
    # The teacher is never optimized directly, so switch its gradients off.
    teacher.requires_grad_(False)

    loss_fn = DinoSpatialLoss(
        teacher_temp=CFG["teacher_temperature"],
        student_temp=CFG["student_temperature"],
    ).to(device)
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG["lr"], weight_decay=0.05)
    return student, teacher, loss_fn, optimizer
|
57 |
+
|
58 |
+
|
59 |
+
# Training Loop
|
60 |
+
def train():
    """Run DINO-style training of the student against the EMA teacher.

    Every epoch iterates over the batches currently held in the dataset's
    shared store, which background workers fill asynchronously. After each
    epoch the student weights are written to ``CFG["save_path"]``.
    """
    import time  # local import: `time` is not imported at file level here

    student, teacher, criterion, optimizer = setup_model_and_loss()
    dataset = DinoDataset(imgsz=CFG["imgsz"], batch_size=CFG["batch_size"], queue_size=CFG["queue_size"])

    num_epochs = CFG["epochs"]
    device = CFG["device"]
    for epoch in range(num_epochs):
        # The store starts out empty; an epoch over zero batches would end in
        # a division by zero below, so wait for the first batch to arrive.
        if len(dataset.store) == 0:
            print("Waiting for the first batch from the data workers...")
            while len(dataset.store) == 0:
                time.sleep(1)

        running_loss = 0.0
        running_entropy = 0.0
        total_count = 0
        loop = tqdm(dataset.store, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in loop:
            # uint8 -> float in [0, 1]
            images_s = torch.nan_to_num(batch['student'].float() / 255.0, nan=0.0).to(device)
            images_t = torch.nan_to_num(batch['teacher'].float() / 255.0, nan=0.0).to(device)

            with torch.no_grad():
                teacher_out = teacher(images_t).detach()

            # Mixed precision is explicitly disabled for the student pass and loss.
            with autocast(device_type='cuda', enabled=False):
                student_out = student(images_s)
                loss = criterion(student_out, teacher_out)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_teacher(student, teacher, m=CFG["teacher_ema"])

            running_loss += loss.item()
            total_count += 1

            # Entropy of the teacher's channel distribution (monitoring only)
            probs = F.softmax(teacher_out / CFG["teacher_temperature"], dim=1)
            eps = 1e-6
            entropy = -(probs * (probs + eps).log()).sum(dim=1).mean()
            running_entropy += entropy.item()

            # Live progress-bar update
            loop.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Entropy": f"{entropy.item():.4f}"
            })

        # Guard kept in case the iteration yields nothing despite the wait above.
        seen = max(total_count, 1)
        avg_loss = running_loss / seen
        avg_entropy = running_entropy / seen

        # Save first, then report, so "Saved" is only printed once it is true.
        torch.save(student.state_dict(), CFG["save_path"])
        print(f"Epoch {epoch+1:03} | Avg Loss: {avg_loss:.4f} | Teacher Entropy: {avg_entropy:.4f} | Saved -> {CFG['save_path']}")


if __name__ == "__main__":
    train()
|
src/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.amp import autocast
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
|
9 |
+
import copy
|
10 |
+
import queue
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
import geopandas as gpd
|
14 |
+
|
15 |
+
import fsspec
|
16 |
+
import xarray as xr
|
17 |
+
from tqdm.notebook import tqdm
|
18 |
+
|
19 |
+
from ultralytics import YOLO
|
20 |
+
from IPython.display import clear_output
|
21 |
+
|
22 |
+
from multiprocessing import Manager
|
23 |
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
24 |
+
|
25 |
+
import huggingface_hub as hf
|
26 |
+
import albumentations as A
|
27 |
+
|
28 |
+
import h5py
|
29 |
+
import requests
|
30 |
+
from io import BytesIO
|
31 |
+
|
32 |
+
import datetime
|
33 |
+
from pathlib import Path
|
34 |
+
import tempfile
|
35 |
+
import shutil
|
36 |
+
|
37 |
+
# Shared infrastructure for background data loading:
# a manager-backed list that worker processes fill, and the pool that runs them.
_NUM_WORKERS = 6

manager = Manager()
shared_store = manager.list()
process_pool = ProcessPoolExecutor(max_workers=_NUM_WORKERS)
|
41 |
+
|
42 |
+
def write_last_updated(path="store_last_updated.txt"):
    """Write the current local time (ISO 8601) to ``path``.

    The timestamp is written to a temporary file next to the target and then
    moved over it with ``os.replace``, so readers never observe a partially
    written file.

    Args:
        path (str): Destination file; overwritten if it exists.

    Raises:
        OSError: If the temporary file cannot be created or moved.
    """
    # Same directory as the target keeps the final replace on one filesystem.
    target_dir = os.path.dirname(os.path.abspath(path))
    with tempfile.NamedTemporaryFile("w", delete=False, dir=target_dir) as tmp:
        tmp.write(f"{datetime.datetime.now().isoformat()}")
        tmp_path = tmp.name
    try:
        os.replace(tmp_path, path)
    except OSError:
        # Do not leave the temporary file behind when the move fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
47 |
+
|
48 |
+
|
49 |
+
class AddPoissonNoise(A.ImageOnlyTransform):
    """Albumentations transform that resamples every pixel from a Poisson
    distribution whose rate is the pixel intensity on the 0-255 scale.

    Args:
        p (float): Probability of applying the transform.
    """

    def __init__(self, p=0.5):
        # Keyword form: does not depend on where `p` sits in the base-class
        # signature (NOTE(review): older Albumentations releases appear to
        # take `always_apply` first - confirm against the pinned version).
        super().__init__(p=p)

    def apply(self, image, **params):
        """Return a noisy uint8 copy of ``image``.

        uint8 input is used as is; any other dtype is assumed to be scaled
        to [0, 1] and is multiplied by 255 first.
        """
        image = image.astype(np.float32) / 255.0 if image.dtype == np.uint8 else image.copy()
        # Poisson rates must be non-negative; clip so slightly negative float
        # inputs do not raise.
        rate = np.clip(image * 255.0, 0, None)
        noisy = np.random.poisson(rate)
        return np.clip(noisy, 0, 255).astype('uint8')
|
57 |
+
|
58 |
+
class AddSaltPepperNoise(A.ImageOnlyTransform):
    """Albumentations transform that sets random array elements to the
    maximum value ("salt") or to zero ("pepper").

    Args:
        amount (float): Fraction of array elements to corrupt.
        salt_vs_pepper (float): Share of the corrupted elements that become salt.
        p (float): Probability of applying the transform.
    """

    def __init__(self, amount=0.02, salt_vs_pepper=0.5, p=0.5):
        # Keyword form: independent of the base-class parameter order.
        super(AddSaltPepperNoise, self).__init__(p=p)
        self.amount = amount
        self.salt_vs_pepper = salt_vs_pepper

    def apply(self, image, **params):
        """Return a corrupted copy of ``image`` (same shape and dtype).

        Coordinates are drawn independently along every axis, so for a
        multi-channel image individual channel values are hit rather than
        whole pixels.
        """
        noisy = image.copy()
        if image.size == 0:
            return noisy

        num_salt = int(np.ceil(self.amount * image.size * self.salt_vs_pepper))
        num_pepper = int(np.ceil(self.amount * image.size * (1.0 - self.salt_vs_pepper)))

        # "Salt" is the brightest representable value: 255 for uint8 images,
        # 1 for float images (assumed to be scaled to [0, 1]).
        if np.issubdtype(image.dtype, np.integer):
            salt_value = np.iinfo(image.dtype).max
        else:
            salt_value = 1

        # randint's upper bound is exclusive, so `dim` (not `dim - 1`) covers
        # the last index and also works for axes of length 1.
        coords = [np.random.randint(0, dim, num_salt) for dim in image.shape]
        noisy[tuple(coords)] = salt_value

        coords = [np.random.randint(0, dim, num_pepper) for dim in image.shape]
        noisy[tuple(coords)] = 0

        return noisy
|