Tags: Earth Observation · Foundation Model · Remote Sensing
msohaildanish committed
Commit dc6ae70 · verified · 1 parent: ea2e1c4

Upload model weights
Files changed (7):
  1. .gitattributes +3 -0
  2. README.md +179 -0
  3. TerraFM-B.pth +3 -0
  4. images/arch.jpg +3 -0
  5. images/ls4s_qual.jpg +3 -0
  6. images/spider_gb.jpg +3 -0
  7. terrafm.py +374 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/arch.jpg filter=lfs diff=lfs merge=lfs -text
+images/ls4s_qual.jpg filter=lfs diff=lfs merge=lfs -text
+images/spider_gb.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,179 @@

# TerraFM: A Scalable Foundation Model for Unified Multisensor Earth Observation

<p align="center">
  <img src="https://i.imgur.com/waxVImv.png" alt="Oryx TerraFM">
</p>

#### [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/), [Muhammad Akhtar Munir](https://akhtarvision.github.io/), [Syed Roshaan Ali Shah](https://www.linkedin.com/in/syed-roshaan-ali-shah-b797b44a/), Muhammad Haris Khan, Rao Muhammad Anwer, Jorma Laaksonen, [Fahad Shahbaz Khan](https://sites.google.com/view/fahadkhans/home), and [Salman Khan](https://salman-h-khan.github.io/)

#### **Mohamed bin Zayed University of AI, University College London, Aalto University, Linköping University, Australian National University**

[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)]()
[![Model Zoo](https://img.shields.io/badge/Model%20Zoo-HuggingFace-blue)](#🧠-model-zoo)

---

## 📢 Latest Updates
- **Jun-09-25**: 🚀 Initial release of the **TerraFM codebase** and **pretrained models**
- **Jun-09-25**: 📄 Paper released on arXiv: [arxiv link](). 🔥🔥

---

## 🌍 Overview

**TerraFM** is a scalable foundation model designed for unified processing of multisensor Earth Observation (EO) data. Built on a ViT backbone and trained on **18.7M tiles (~23T pixels)** of Sentinel-1 SAR and Sentinel-2 optical imagery, TerraFM unifies modality-specific inputs using:

- 🧩 Modality-specific patch embeddings
- 🌀 Adaptive cross-attention fusion
- 🎯 Dual-centering regularization for long-tailed distributions

TerraFM sets a new benchmark on **GEO-Bench** and **Copernicus-Bench**, demonstrating strong generalization across geographies, modalities, and tasks, including classification, segmentation, and landslide detection.

---

## 🔬 Key Features

<p align="center">
  <img src="images/spider_gb.jpg" alt="TerraFM GEO-Bench comparison" width="500"/>
</p>

- **Multimodal Pretraining**: Uses Sentinel-1 (SAR) and Sentinel-2 (L1C, L2A) as natural augmentations of the same scene.
- **Large-Scale Dataset**: Trained on 18.7M global tiles from the [Major-TOM](https://huggingface.co/Major-TOM) dataset.
- **Cross-Attention Fusion**: Dynamically aggregates information across sensors at the patch level.
- **Dual-Centering**: Mitigates long-tailed land-cover bias using ESA WorldCover statistics.
- **Benchmark SOTA**: Outperforms prior foundation models (Galileo, Prithvi, DOFA) across multiple EO tasks.

---

## 🧱 Architecture

<p align="center">
  <img src="images/arch.jpg" alt="TerraFM Architecture" width="700"/>
</p>

Overall architecture of TerraFM. It combines a student-teacher contrastive framework with modality augmentation, cross-attention fusion, and a new dual-centering regularization. TerraFM is built on a ViT backbone, is pretrained on 18.7M globally distributed samples, and encodes large-tile inputs to capture broader spatial context. For illustration, RGB channels are shown for S2-L2A and S2-L1C, and S1 is visualized as a false-color RGB composite.
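
For reference, the modality-specific patch embeddings can be exercised directly from `terrafm.py`. A minimal sketch of how inputs are routed, by channel count for Sentinel-1 and by the `is_l2a` flag for the two Sentinel-2 products (shapes assume the ViT-Base defaults with 16×16 patches; the random tensors are placeholders):

```python
import torch
from terrafm import terrafm_base

model = terrafm_base()
s1 = torch.randn(1, 2, 224, 224)       # Sentinel-1: 2 channels -> conv2d_s1
s2_l1c = torch.randn(1, 12, 224, 224)  # Sentinel-2 L1C: 12 bands -> conv2d_s2_l1c
s2_l2a = torch.randn(1, 12, 224, 224)  # Sentinel-2 L2A: 12 bands -> conv2d_s2_l2a

tokens_s1 = model.patch_embed(s1)                    # routed because C == 2
tokens_l1c = model.patch_embed(s2_l1c)               # default S2 branch
tokens_l2a = model.patch_embed(s2_l2a, is_l2a=True)  # L2A branch
print(tokens_s1.shape, tokens_l1c.shape, tokens_l2a.shape)  # each torch.Size([1, 196, 768])
```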

---

## 🧠 Model Zoo

| Model | Modality | Input Size | Backbone | Link |
|-------|----------|------------|----------|------|
| TerraFM-B | Sentinel-1 RTC + Sentinel-2 Level-2A + Sentinel-2 Level-1C | 224×224 | ViT-Base | [Download](https://huggingface.co/MBZUAI/TerraFM) |
| TerraFM-L | Sentinel-1 RTC + Sentinel-2 Level-2A + Sentinel-2 Level-1C | 224×224 | ViT-Large | [Download](https://huggingface.co/MBZUAI/TerraFM) |
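
The checkpoint in this repository can also be fetched programmatically; a small sketch using `huggingface_hub` (only `TerraFM-B.pth` is uploaded in this commit):

```python
import torch
from huggingface_hub import hf_hub_download
from terrafm import terrafm_base

# Download TerraFM-B.pth from the Hugging Face Hub and load it
ckpt_path = hf_hub_download(repo_id="MBZUAI/TerraFM", filename="TerraFM-B.pth")
model = terrafm_base()
msg = model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
print(msg)  # reports any missing/unexpected keys
```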

---

## 🛠 Usage

TerraFM can be used directly via the `terrafm.py` module, which provides standalone implementations of the TerraFM-Base and TerraFM-Large models for easy integration into any codebase.

```python
from terrafm import terrafm_base, terrafm_large
import torch

# Simulated input: 1 sample, 12 channels, 224×224 resolution (e.g., Sentinel-2 L2A)
x = torch.randn(1, 12, 224, 224)

# Load the TerraFM-Base model
model = terrafm_base()

# Load pretrained weights (e.g., TerraFM-B.pth)
state_dict = torch.load("TerraFM-B.pth", map_location="cpu")
msg = model.load_state_dict(state_dict, strict=False)

# Forward pass returns the CLS embedding
y = model(x)
print(f"Output shape: {y.shape}")  # torch.Size([1, 768])
```
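
The same backbone accepts Sentinel-1 input (2 channels), which `PatchEmbed` routes to the SAR-specific embedding, and patch tokens for dense tasks are available via `get_intermediate_layers`. Continuing the snippet above (a sketch; shapes assume the ViT-Base defaults):

```python
# Sentinel-1 input: 2 channels (VV, VH), routed automatically by channel count
s1 = torch.randn(1, 2, 224, 224)
y_s1 = model(s1)  # CLS embedding, torch.Size([1, 768])

# Patch tokens from the last block, e.g., for a downstream decoder
tokens = model.get_intermediate_layers(x, n=1, norm=True)[0]
print(y_s1.shape, tokens.shape)  # torch.Size([1, 768]) torch.Size([1, 196, 768])
```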

---

## 📊 Results

### 🔍 k-NN Classification Results

We evaluate image classification with a k-nearest-neighbors (kNN) classifier on frozen features and report Top-1 accuracy for all single-label tasks. For the multilabel m-BigEarthNet benchmark, we report the F1 score.

| Model | Backbone | m-EuroSat (100%) | m-EuroSat (1%) | m-BigEarthNet (100%) | m-BigEarthNet (1%) | m-So2Sat (100%) | m-So2Sat (1%) | m-Brick-Kiln (100%) | m-Brick-Kiln (1%) |
|----------------|------------|------------------|----------------|----------------------|--------------------|------------------|----------------|----------------------|--------------------|
| SatMAE | ViT-Base | 84.1 | 34.8 | 50.6 | 29.0 | 36.0 | 23.1 | 86.1 | 73.5 |
| SatMAE++ | ViT-Large | 82.7 | 48.5 | 50.8 | 31.6 | 34.7 | 23.4 | 89.6 | 76.7 |
| CROMA | ViT-Base | 85.6 | 51.3 | 58.8 | 44.7 | 48.8 | 33.8 | 92.6 | 85.1 |
| SoftCon | ViT-Small | 89.8 | 27.2 | 64.7 | 43.3 | 51.1 | 31.4 | 89.2 | 77.8 |
| DOFA | ViT-Base | 82.8 | 49.6 | 49.4 | 29.9 | 41.4 | 29.4 | 88.3 | 78.3 |
| Satlas | Swin-Tiny | 81.7 | 35.8 | 51.9 | 29.6 | 36.6 | 27.1 | 88.2 | 73.0 |
| MMEarth | CNN-atto | 81.7 | 30.0 | 58.3 | 39.6 | 39.8 | 25.1 | 89.4 | 79.7 |
| DeCUR | ViT-Small | 89.0 | 46.6 | 63.8 | 49.6 | 45.8 | 30.9 | 83.7 | 74.2 |
| AnySat | ViT-Base | 82.2 | 47.1 | 54.9 | 33.7 | 39.8 | 29.0 | 85.3 | 72.0 |
| Galileo | ViT-Base | 93.0 | 56.6 | 59.0 | 36.5 | 54.8 | **43.2** | 90.7 | 78.0 |
| Prithvi-2.0 | ViT-Large | 80.2 | 48.0 | 49.4 | 28.8 | 29.5 | 26.1 | 87.9 | 80.6 |
| Copernicus-FM | ViT-Base | 76.0 | 47.4 | 53.8 | 33.3 | 38.4 | 23.3 | 93.0 | 83.2 |
| **TerraFM** | ViT-Base | _94.2_ | _59.3_ | _68.7_ | 49.4 | _55.1_ | _41.6_ | **94.5** | **85.6** |
| **TerraFM** | ViT-Large | **95.1** | **62.1** | **69.4** | **50.6** | **55.9** | 41.1 | _93.0_ | 82.2 |
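
For orientation, here is a minimal sketch of this kNN protocol on frozen TerraFM-B features, assuming scikit-learn and user-provided `train_loader`/`test_loader` that yield 12-band 224×224 Sentinel-2 tiles with labels (the exact benchmark settings, e.g. the value of k, may differ):

```python
import torch
from sklearn.neighbors import KNeighborsClassifier
from terrafm import terrafm_base

@torch.no_grad()
def embed(model, loader):
    feats, labels = [], []
    for x, y in loader:
        feats.append(model(x))  # forward() returns the frozen CLS embedding
        labels.append(y)
    return torch.cat(feats).numpy(), torch.cat(labels).numpy()

model = terrafm_base().eval()
model.load_state_dict(torch.load("TerraFM-B.pth", map_location="cpu"), strict=False)
X_train, y_train = embed(model, train_loader)
X_test, y_test = embed(model, test_loader)
knn = KNeighborsClassifier(n_neighbors=20).fit(X_train, y_train)
print("Top-1 accuracy:", knn.score(X_test, y_test))
```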

### 🛰 Copernicus-Bench

Comparison of TerraFM with existing supervised and self-supervised methods on **Copernicus-Bench**.
Metrics include **OA** (Overall Accuracy), **mAP** (mean Average Precision), and **mIoU** (mean Intersection over Union).

| Dataset | Metric | Supervised | Random | SoftCon | CROMA | DOFA | Copernicus-FM | **TerraFM** |
|----------------|--------|------------|--------|---------|--------|------|----------------|-------------|
| **Backbone** | -- | ViT-B/16 | ViT-B/16 | ViT-B/14 | ViT-B/8 | ViT-B/16 | ViT-B/16 | ViT-B/16 |
| **Cloud-S2** | mIoU | 59.4 | 60.4 | 66.9 | 65.0 | 65.0 | 66.7 | **67.9** |
| **EuroSAT-S1** | OA | 81.5 | 75.4 | 83.6 | 83.9 | 81.7 | 87.2 | **87.8** |
| **EuroSAT-S2** | OA | 97.6 | 92.5 | 96.7 | 97.0 | 97.2 | 97.9 | **99.1** |
| **BigEarthNet-S1** | mAP | 70.6 | 63.8 | **78.7** | 70.8 | 70.5 | 77.9 | 76.9 |
| **BigEarthNet-S2** | mAP | 80.1 | 71.6 | 83.6 | 76.4 | 75.5 | 79.0 | **84.4** |
| **DFC2020-S1** | mIoU | 50.8 | 45.4 | 52.8 | 52.7 | 49.7 | 52.4 | **55.4** |
| **DFC2020-S2** | mIoU | 66.2 | 62.3 | 64.1 | **66.5** | 61.8 | 64.5 | 63.8 |
| **LCZ-S2** | OA | 85.3 | 77.4 | 83.6 | 84.1 | 83.0 | 84.4 | **87.0** |
132
+ ### 🧪 GEO-Bench Performance
133
+
134
+ Performance comparison on GEO-Bench for both **classification** (Top-1 Accuracy), **segmentation** (mIoU), and **F1 score** (for m-BigEarthNet).
135
+ TerraFM achieves state-of-the-art results across multiple datasets, outperforming previous foundation models.
136
+
137
+ | Method | Backbone | m-EuroSat | m-BigEarthNet | m-So2Sat | m-Brick-Kiln | m-Cashew-Plant | m-SA-Crop-Type |
138
+ |--------------|------------|-----------|----------------|----------|----------------|------------------|------------------|
139
+ | SatMAE | ViT-Large | 96.6 | 68.3 | 57.2 | 98.4 | 30.8 | 24.8 |
140
+ | SatMAE++ | ViT-Large | 96.5 | 67.9 | 56.0 | 98.6 | 29.6 | 25.7 |
141
+ | CROMA | ViT-Large | 96.6 | 71.9 | 60.6 | 98.7 | 31.8 | 32.0 |
142
+ | SoftCon | ViT-Base | 97.5 | 70.3 | 61.7 | 98.7 | 29.6 | 30.8 |
143
+ | DOFA | ViT-Large | 96.9 | 68.0 | 58.7 | 98.6 | 27.7 | 25.4 |
144
+ | Satlas | Swin-Base | 97.5 | 72.8 | 61.9 | **98.9** | 25.1 | 23.4 |
145
+ | MMEarth | CNN-atto | 95.7 | 70.0 | 57.2 | 98.9 | 24.2 | 22.2 |
146
+ | DeCUR | ViT-Small | 97.9 | 70.9 | 61.7 | 98.7 | 26.2 | 21.5 |
147
+ | Prithvi 2.0 | ViT-Large | 96.5 | 69.0 | 54.6 | 98.6 | 26.7 | 22.9 |
148
+ | AnySat | ViT-Base | 95.9 | 70.3 | 51.8 | 98.6 | 26.1 | 27.1 |
149
+ | Galileo | ViT-Base | 97.7 | 70.7 | 63.3 | 98.7 | 33.0 | 30.1 |
150
+ | **TerraFM** | ViT-Base | *98.1* | 72.6 | *64.9* | 98.7 | *34.1* | *33.0* |
151
+ | **TerraFM** | ViT-Large | **98.6** | **73.1** | **66.6** | **99.0** | **37.2** | **34.5** |

### 🌋 Landslide Detection (Landslide4Sense)

Landslide detection performance on the **Landslide4Sense** test set.
Despite having significantly fewer parameters (120M vs. 300M), **TerraFM** achieves higher overall segmentation performance, especially on landslide regions.

| Model | mIoU | IoU (Landslide) |
|------------------------|------|-----------------|
| Prithvi-EO-2.0 (300M) | 65.0 | 31.5 |
| **TerraFM (120M)** | **70.8** | **43.1** |

<p align="center">
  <img src="images/ls4s_qual.jpg" alt="Landslide detection qualitative results" width="700"/>
</p>
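
For dense-prediction experiments like this one, multi-scale feature maps can be taken from the backbone with `extract_feature`; a minimal sketch (the segmentation decoder itself is not part of this repository and is only indicated):

```python
import torch
from terrafm import terrafm_base

model = terrafm_base()
x = torch.randn(1, 12, 224, 224)

# Feature maps from blocks 3, 5, 7, and 11, each reshaped to [B, C, H/16, W/16]
feats = model.extract_feature(x, out_indices=[3, 5, 7, 11])
for f in feats:
    print(f.shape)  # torch.Size([1, 768, 14, 14]); feed these to a segmentation decoder
```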

---

## 📜 Citation
If you find our work and this repository useful, please consider giving our repo a star and citing our paper as follows:
```bibtex
@article{
}
```

## 📨 Contact
If you have any questions, please create an issue on this repository or contact us at [email protected].
TerraFM-B.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8bdb6a4cfd707a09d79f4790119daa9d76bb316b6ea756c6eb305f99d1797a06
size 451862010
images/arch.jpg ADDED

Git LFS Details

  • SHA256: 66bafc4de02a78413411df3dcb9aeb706342ef3939503836f61405fdbf7c62ac
  • Pointer size: 131 Bytes
  • Size of remote file: 606 kB
images/ls4s_qual.jpg ADDED

Git LFS Details

  • SHA256: b88a97c4e04f2a5c50f57c30fa539aca4e506f48514cfa523a0244f3b055435b
  • Pointer size: 131 Bytes
  • Size of remote file: 535 kB
images/spider_gb.jpg ADDED

Git LFS Details

  • SHA256: 23f38821eea589a1e179c6bf36ed2936180c38271e5fe599207fe27bfeeb785d
  • Pointer size: 131 Bytes
  • Size of remote file: 501 kB
terrafm.py ADDED
@@ -0,0 +1,374 @@
# ------------------------------------------------------------------------------
# This file includes code copied and adapted from DINO:
# - DINO (https://github.com/facebookresearch/dino)
#
# ------------------------------------------------------------------------------
import math
import warnings
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # Returns the attention map alongside the output so callers can inspect it
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    def __init__(
        self,
        img_size: int,
        embed_dim: int,
        patch_size: int,
        in_chans_s1: int,
        in_chans_s2: int,
    ):
        super().__init__()
        attn_dim = embed_dim * 3  # from the Panopticon design
        self.img_size = img_size
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.num_patches = num_patches

        # One patch-embedding convolution per modality
        self.conv2d_s2_l2a = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
        self.conv2d_s2_l1c = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
        self.conv2d_s1 = nn.Conv2d(in_chans_s1, attn_dim, kernel_size=patch_size, stride=patch_size)

        # Shared projection back to the transformer width, plus learnable per-modality embeddings
        self.projection = TokenProjection(embed_dim=embed_dim, attn_dim=attn_dim)
        self.s2_l2a_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.s2_l1c_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.s1_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.attn_dim = attn_dim

    def forward(self, x12: Tensor, is_l2a: bool = False) -> Tensor:
        B, C, H, W = x12.shape
        if C == 2:
            # Sentinel-1 SAR input (2 channels)
            x = self.conv2d_s1(x12).flatten(2).transpose(1, 2)
            x += self.s1_embed
        elif is_l2a:
            # Sentinel-2 L2A input (12 bands)
            x = self.conv2d_s2_l2a(x12).flatten(2).transpose(1, 2)
            x += self.s2_l2a_embed
        else:
            # Sentinel-2 L1C input (12 bands)
            x = self.conv2d_s2_l1c(x12).flatten(2).transpose(1, 2)
            x += self.s2_l1c_embed

        x = self.projection(x)
        return x


class TokenProjection(nn.Module):
    def __init__(self, embed_dim: int, attn_dim: int):
        super().__init__()
        self.proj1 = nn.Linear(attn_dim, attn_dim, bias=False)
        self.norm_input = nn.LayerNorm(attn_dim)
        self.proj2 = nn.Linear(attn_dim, attn_dim)
        self.proj3 = nn.Linear(attn_dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Applies a sequence of linear projections used for Case 1 & N in modality augmentation.

        Steps:
            1. proj1 is shared between Case 1 and Case N (acts like the value projection in attention).
            2. LayerNorm stabilizes training and normalizes features.
            3. In Case N, proj2 is applied after the weighted-mean operation.
            4. proj3 projects to the final embedding dimension.

        Args:
            x (Tensor): Input tensor of shape [B, N, input_dim], where
                B = batch size, N = number of tokens.

        Returns:
            Tensor: Projected output of shape [B, N, final_dim].
        """
        x = self.proj1(x)  # V in cross-attention
        x = self.norm_input(x)
        x = self.proj2(x)
        x = self.proj3(x)  # final projection
        return x

class TerraFM(nn.Module):
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans_s1=2, in_chans_s2=12, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward_features(self, x):
        return self.forward(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1, return_class_token=False, norm=False):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(x)
        if norm:
            output = [self.norm(out) for out in output]
        class_tokens = [out[:, 0] for out in output]
        output = [out[:, 1:] for out in output]
        if return_class_token:
            return tuple(zip(output, class_tokens))
        return output

    def extract_feature(self, images, return_h_w=True, out_indices=[3, 5, 7, 11]):
        x = self.prepare_tokens(images)
        output = []
        h, w = int(images.shape[2] / self.patch_embed.patch_size), int(images.shape[3] / self.patch_embed.patch_size)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in out_indices:
                out = x[:, 1:]  # drop the [CLS] token
                out = self.norm(out)
                B, _, C = out.shape
                # reshape patch tokens to a [B, C, h, w] feature map
                out = (
                    out.reshape(B, h, w, C)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                output.append(out)

        return output


def terrafm_base(patch_size=16, **kwargs):
    model = TerraFM(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def terrafm_large(patch_size=16, **kwargs):
    model = TerraFM(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model