zcash committed on
Commit 068792f · verified · 1 Parent(s): d7e20bb

Upload DEM_SuperRes.ipynb with huggingface_hub

Files changed (1)
  1. DEM_SuperRes.ipynb +644 -0
DEM_SuperRes.ipynb ADDED
@@ -0,0 +1,644 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f541ffd4",
+ "metadata": {},
+ "source": [
+ "# Synthetic High-Resolution DEM Generation for Marrakech, Morocco\n",
+ "## Training Only on the McKinley Dataset\n",
+ "\n",
+ "This notebook implements the full pipeline: it trains on the McKinley dataset alone, then uses the resulting model to super-resolve 30 m SRTM elevation to a 10 m DEM by fusing it with Sentinel-2 imagery over Marrakech, Morocco.\n",
+ "\n",
+ "**Key assumptions:**\n",
+ "- Training target: the McKinley Mine, NM high-resolution LiDAR DEM.\n",
+ "- Inference area: mountainous terrain near Marrakech.\n",
+ "- Model: an adapted DeepDEM-style U-Net with 7 input channels.\n",
+ "\n",
+ "Run the cells sequentially."
+ ]
+ },
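+ {
+ "cell_type": "markdown",
+ "id": "1a2b3c4d",
+ "metadata": {},
+ "source": [
+ "**Added illustration (not part of the original pipeline):** the pipeline below never predicts absolute elevation; it learns the high-frequency *residual* left after removing a smooth trend, which is added back at inference. A minimal self-contained sketch of that decomposition on synthetic terrain:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2b3c4d5e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch of the residual decomposition used for the training targets (Cell 7)\n",
+ "import numpy as np\n",
+ "from scipy.ndimage import gaussian_filter\n",
+ "\n",
+ "rng = np.random.default_rng(0)\n",
+ "dem = gaussian_filter(rng.normal(size=(64, 64)), sigma=8) * 100 + 1500  # synthetic terrain (m)\n",
+ "trend = gaussian_filter(dem, sigma=5)  # smooth, low-frequency component\n",
+ "residual = dem - trend                 # high-frequency detail the model learns\n",
+ "print('residual std (m):', float(residual.std()))\n",
+ "print('max reconstruction error (m):', float(np.abs((trend + residual) - dem).max()))  # ~0"
+ ]
+ },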
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7aa9465",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 1: Install Dependencies\n",
+ "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
+ "!pip install pytorch-lightning torchgeo segmentation-models-pytorch rasterio geopandas albumentations scipy gdown earthengine-api\n",
+ "!apt-get install -y libspatialindex-dev libgdal-dev\n",
+ "!pip install gdal==$(gdal-config --version)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c4f399ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 2: Mount Google Drive and Set Up Directories\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')\n",
+ "%cd /content/drive/MyDrive/DEM_Project\n",
+ "!mkdir -p Training_Data/McKinley Inference_Data/Marrakech Models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ed508a36",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 3: Clone DeepDEM Repo and Adapt for Our Use Case\n",
+ "!git clone https://github.com/uw-cryo/DeepDEM.git\n",
+ "%cd DeepDEM\n",
+ "\n",
+ "# Adaptation plan: the model takes 7 input channels\n",
+ "# ['dsm', 'ortho_r', 'ortho_g', 'ortho_b', 'ortho_nir', 'ndvi', 'nodata_mask']\n",
+ "# and predicts 1 output channel (elevation residuals).\n",
+ "# This requires editing task_module.py, or defining the model inline as Cell 9\n",
+ "# does; a sketch of the channel adaptation follows this cell.\n",
+ "\n",
+ "import os\n",
+ "# Append rather than assume PYTHONPATH exists (it is often unset in Colab)\n",
+ "os.environ['PYTHONPATH'] = os.environ.get('PYTHONPATH', '') + ':/content/drive/MyDrive/DEM_Project/DeepDEM'"
+ ]
+ },
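+ {
+ "cell_type": "markdown",
+ "id": "3c4d5e6f",
+ "metadata": {},
+ "source": [
+ "**Added sketch (assumption, not DeepDEM's actual code):** if you adapt a pretrained 3-channel encoder instead of defining the model inline as Cell 9 does, the usual trick is to tile the pretrained first-conv weights across the 7 input channels. `segmentation_models_pytorch` already does this internally when you pass `in_channels=7`; the sketch below just shows the mechanism, which transfers to DeepDEM's own networks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4d5e6f70",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: adapt a pretrained ResNet34 encoder from 3 to 7 input channels\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import segmentation_models_pytorch as smp\n",
+ "\n",
+ "model = smp.Unet(encoder_name='resnet34', encoder_weights='imagenet', in_channels=3, classes=1)\n",
+ "old_conv = model.encoder.conv1\n",
+ "new_conv = nn.Conv2d(7, old_conv.out_channels, kernel_size=old_conv.kernel_size,\n",
+ "                     stride=old_conv.stride, padding=old_conv.padding, bias=False)\n",
+ "with torch.no_grad():\n",
+ "    w = old_conv.weight                 # shape (64, 3, 7, 7)\n",
+ "    reps = (7 + 2) // 3                 # enough 3-channel repeats to cover 7 channels\n",
+ "    # Tile pretrained weights across channels; rescale to keep activations comparable\n",
+ "    new_conv.weight.copy_(w.repeat(1, reps, 1, 1)[:, :7] * (3 / 7))\n",
+ "model.encoder.conv1 = new_conv\n",
+ "print(model(torch.randn(1, 7, 256, 256)).shape)  # torch.Size([1, 1, 256, 256])"
+ ]
+ },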
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d7bb1f40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 4: Authenticate and Initialize Earth Engine\n",
+ "from google.colab import auth\n",
+ "import ee\n",
+ "\n",
+ "# 1. Authenticate your Google user\n",
+ "auth.authenticate_user()\n",
+ "\n",
+ "# 2. Initialize Earth Engine with your Google Cloud project ID\n",
+ "# Replace 'dem-collab' below with your own project ID\n",
+ "try:\n",
+ "    ee.Initialize(project='dem-collab')\n",
+ "    print(\"Earth Engine initialized successfully!\")\n",
+ "except ee.EEException as e:\n",
+ "    print(f\"Error during initialization: {e}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "091b3f03",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 5: Data Acquisition Function (SRTM + Sentinel-2 for McKinley and Marrakech)\n",
+ "import os\n",
+ "\n",
+ "def fetch_gee_data(bbox, output_dir, dataset_name):\n",
+ "    os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ "    geom = ee.Geometry.BBox(*bbox)  # Define geometry for region\n",
+ "\n",
+ "    # SRTM (30m) -- use the 1 arc-second product so the data matches the 30m\n",
+ "    # export scale (CGIAR/SRTM90_V4 is the 90m product)\n",
+ "    srtm = ee.Image('USGS/SRTMGL1_003').clip(geom).rename('dsm')\n",
+ "    task_srtm = ee.batch.Export.image.toDrive(\n",
+ "        image=srtm,\n",
+ "        description=f'{dataset_name}_srtm',\n",
+ "        folder=output_dir.split('/')[-1],\n",
+ "        scale=30,\n",
+ "        fileFormat='GeoTIFF',\n",
+ "        region=geom\n",
+ "    )\n",
+ "    task_srtm.start()\n",
+ "\n",
+ "    # Sentinel-2 (10m, RGB + NIR): sort by cloud cover and take the clearest\n",
+ "    # scene from a low-cloud season rather than a median composite\n",
+ "    collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \\\n",
+ "        .filterBounds(geom) \\\n",
+ "        .filterDate('2023-06-01', '2023-10-31') \\\n",
+ "        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)) \\\n",
+ "        .sort('CLOUDY_PIXEL_PERCENTAGE')\n",
+ "\n",
+ "    num_images = collection.size().getInfo()\n",
+ "    print(f'For {dataset_name} bbox: {num_images} Sentinel-2 images found.')\n",
+ "\n",
+ "    # Get the best (least cloudy) image\n",
+ "    best_image = collection.first()\n",
+ "\n",
+ "    # Export the raw 4-band image (R, G, B, NIR) for the model\n",
+ "    sentinel_raw = best_image.select(['B4', 'B3', 'B2', 'B8'])\n",
+ "    task_s2_raw = ee.batch.Export.image.toDrive(\n",
+ "        image=sentinel_raw,\n",
+ "        description=f'{dataset_name}_sentinel',  # keep this name for downstream cells\n",
+ "        folder=output_dir.split('/')[-1],\n",
+ "        scale=10,\n",
+ "        fileFormat='GeoTIFF',\n",
+ "        region=geom\n",
+ "    )\n",
+ "    task_s2_raw.start()\n",
+ "\n",
+ "    # Export a separate, color-corrected visual version for inspection\n",
+ "    sentinel_viz = best_image.visualize(min=0, max=3000, bands=['B4', 'B3', 'B2'])\n",
+ "    task_s2_viz = ee.batch.Export.image.toDrive(\n",
+ "        image=sentinel_viz,\n",
+ "        description=f'{dataset_name}_sentinel_viz',\n",
+ "        folder=output_dir.split('/')[-1],\n",
+ "        scale=10,\n",
+ "        fileFormat='GeoTIFF',\n",
+ "        region=geom\n",
+ "    )\n",
+ "    task_s2_viz.start()\n",
+ "\n",
+ "\n",
+ "# Bounding boxes\n",
+ "bbox_mckinley = [-109.03892074228675, 35.58282920746211, -108.87077846472735, 35.736434167381475]\n",
+ "fetch_gee_data(bbox_mckinley, '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley', 'mckinley')\n",
+ "\n",
+ "bbox_marrakech = [-8.1, 31.5, -7.9, 31.7]\n",
+ "fetch_gee_data(bbox_marrakech, '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech', 'marrakech')"
+ ]
+ },
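+ {
+ "cell_type": "markdown",
+ "id": "5e6f7081",
+ "metadata": {},
+ "source": [
+ "**Added helper (assumption, uses the standard `ee.batch.Task` API):** the exports above run asynchronously, but later cells assume the GeoTIFFs already sit in Drive. A minimal sketch that blocks until all queued tasks finish; the 60-second polling interval is an arbitrary choice."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6f708192",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: wait for all queued Earth Engine exports to finish\n",
+ "import time\n",
+ "import ee\n",
+ "\n",
+ "def wait_for_exports(poll_seconds=60):\n",
+ "    while True:\n",
+ "        # Task.list() returns recent tasks; status()['state'] reports progress\n",
+ "        states = [t.status()['state'] for t in ee.batch.Task.list()]\n",
+ "        pending = [s for s in states if s in ('READY', 'RUNNING')]\n",
+ "        print(f'{len(pending)} export task(s) still pending...')\n",
+ "        if not pending:\n",
+ "            break\n",
+ "        time.sleep(poll_seconds)\n",
+ "\n",
+ "wait_for_exports()"
+ ]
+ },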
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5f9b0934",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 6: Download High-Resolution DEM Tiles and Merge (Only for McKinley)\n",
+ "!pip install boto3 retry\n",
+ "!apt install -y gdal-bin\n",
+ "import boto3\n",
+ "import os\n",
+ "from botocore import UNSIGNED\n",
+ "from botocore.client import Config\n",
+ "from retry import retry\n",
+ "\n",
+ "# S3 configuration (OpenTopography public bucket, unsigned access)\n",
+ "endpoint_url = 'https://opentopography.s3.sdsc.edu'\n",
+ "client = boto3.client('s3', endpoint_url=endpoint_url, config=Config(signature_version=UNSIGNED))\n",
+ "\n",
+ "# Temp and data dirs\n",
+ "temp_base = '/content/hr_temp'\n",
+ "data_base = '/content/drive/MyDrive/DEM_Project/Training_Data'\n",
+ "os.makedirs(temp_base, exist_ok=True)\n",
+ "os.makedirs(data_base, exist_ok=True)\n",
+ "\n",
+ "# Only McKinley\n",
+ "datasets = {'mckinley': 'NM23_McKinley'}\n",
+ "folder_names = {'mckinley': 'McKinley'}\n",
+ "\n",
+ "@retry(tries=3, delay=2, backoff=2)\n",
+ "def download_file_with_retry(bucket, key, filename):\n",
+ "    client.download_file(Bucket=bucket, Key=key, Filename=filename)\n",
+ "\n",
+ "!df -h /content\n",
+ "for local_name, s3_dir in datasets.items():\n",
+ "    temp_dir = os.path.join(temp_base, local_name)\n",
+ "    dataset_dir = os.path.join(data_base, folder_names[local_name])\n",
+ "    os.makedirs(temp_dir, exist_ok=True)\n",
+ "    os.makedirs(dataset_dir, exist_ok=True)\n",
+ "\n",
+ "    output_tif = os.path.join(dataset_dir, f'{local_name}_hr_dem.tif')\n",
+ "    if os.path.exists(output_tif):\n",
+ "        print(f'Merged DEM already exists for {local_name}: {output_tif}, skipping download and merge.')\n",
+ "        continue\n",
+ "\n",
+ "    paginator = client.get_paginator('list_objects_v2')\n",
+ "    prefix = f'{s3_dir}/{s3_dir}_be/'\n",
+ "    downloaded_files = []\n",
+ "    try:\n",
+ "        for page in paginator.paginate(Bucket='raster', Prefix=prefix):\n",
+ "            for obj in page.get('Contents', []):\n",
+ "                key = obj['Key']\n",
+ "                if key.endswith('.tif'):\n",
+ "                    file_path = os.path.join(temp_dir, os.path.basename(key))\n",
+ "                    # Skip tiles that are already on local disk\n",
+ "                    if os.path.exists(file_path):\n",
+ "                        print(f'Tile {os.path.basename(key)} already exists for {local_name}, skipping download.')\n",
+ "                        downloaded_files.append(file_path)\n",
+ "                        continue\n",
+ "                    try:\n",
+ "                        download_file_with_retry('raster', key, file_path)\n",
+ "                        downloaded_files.append(file_path)\n",
+ "                        print(f'Downloaded {os.path.basename(key)} for {local_name}')\n",
+ "                    except Exception as e:\n",
+ "                        print(f'Error downloading {key} for {local_name}: {e}')\n",
+ "    except Exception as e:\n",
+ "        print(f'Error listing tiles for {local_name}: {e}')\n",
+ "\n",
+ "    if not downloaded_files:\n",
+ "        print(f'No tiles downloaded for {local_name}; skipping merge.')\n",
+ "        continue\n",
+ "\n",
+ "    # Merge via VRT + translate; pass tiles through a file list to avoid\n",
+ "    # hitting the shell's argument-length limit with many tiles\n",
+ "    list_path = os.path.join(temp_dir, 'tile_list.txt')\n",
+ "    with open(list_path, 'w') as f:\n",
+ "        f.write('\\n'.join(downloaded_files))\n",
+ "    !gdalbuildvrt -input_file_list \"{list_path}\" merged.vrt\n",
+ "    !gdal_translate -of GTiff merged.vrt \"{output_tif}\" -co TILED=YES -co COMPRESS=DEFLATE -co NUM_THREADS=ALL_CPUS\n",
+ "    if os.path.exists(output_tif):\n",
+ "        print(f'Merged tiles to TIFF for {local_name}: {output_tif}')\n",
+ "    else:\n",
+ "        print(f'Merge failed for {local_name}; check the gdalbuildvrt/gdal_translate output above.')\n",
+ "    # Temp tiles are kept on local disk so reruns can skip re-downloading\n",
+ "\n",
+ "!df -h /content\n",
+ "print('Download complete for McKinley!')"
+ ]
+ },
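+ {
+ "cell_type": "markdown",
+ "id": "708192a3",
+ "metadata": {},
+ "source": [
+ "**Added quick check:** confirm the merged McKinley DEM opens cleanly and report its grid before preprocessing. The path matches the cell above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8192a3b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect the merged HR DEM written by the cell above\n",
+ "import rasterio\n",
+ "\n",
+ "hr_path = '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley/mckinley_hr_dem.tif'\n",
+ "with rasterio.open(hr_path) as src:\n",
+ "    print({'crs': str(src.crs), 'res': src.res, 'shape': (src.height, src.width),\n",
+ "           'nodata': src.nodata, 'bounds': tuple(src.bounds)})"
+ ]
+ },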
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f82f16c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 7: Data Preprocessing (McKinley and Marrakech)\n",
+ "# NOTE: this cell assumes the SRTM, Sentinel-2, and HR rasters share the same\n",
+ "# CRS and extent; if they do not, reproject first (see the sketch after this cell).\n",
+ "import rasterio\n",
+ "from rasterio.enums import Resampling\n",
+ "import numpy as np\n",
+ "from scipy.ndimage import gaussian_filter\n",
+ "import os\n",
+ "\n",
+ "folder_names = {'mckinley': 'McKinley', 'marrakech': 'Marrakech'}\n",
+ "\n",
+ "def preprocess_dataset(dataset_name, is_training=True, custom_base=None):\n",
+ "    if custom_base:\n",
+ "        base_dir = custom_base\n",
+ "    else:\n",
+ "        base_dir = f'/content/drive/MyDrive/DEM_Project/Training_Data/{folder_names[dataset_name]}'\n",
+ "\n",
+ "    srtm_path = os.path.join(base_dir, f'{dataset_name}_srtm.tif')\n",
+ "    s2_path = os.path.join(base_dir, f'{dataset_name}_sentinel.tif')\n",
+ "    hr_path = os.path.join(base_dir, f'{dataset_name}_hr_dem.tif') if is_training else None\n",
+ "    output_dir = base_dir\n",
+ "\n",
+ "    with rasterio.open(srtm_path) as srtm_src, rasterio.open(s2_path) as s2_src:\n",
+ "        target_shape = (s2_src.height, s2_src.width)\n",
+ "        # Upsample 30m SRTM onto the 10m Sentinel-2 grid\n",
+ "        srtm = srtm_src.read(1, out_shape=target_shape, resampling=Resampling.cubic).astype(np.float32)\n",
+ "\n",
+ "        # Cast to float before band math: Sentinel-2 reflectance is uint16,\n",
+ "        # and (nir - r) would silently wrap around on unsigned integers\n",
+ "        s2 = s2_src.read().astype(np.float32)\n",
+ "        r, g, b, nir = s2\n",
+ "\n",
+ "        ndvi = (nir - r) / (nir + r + 1e-10)\n",
+ "\n",
+ "        nodata = srtm_src.nodata\n",
+ "        if nodata is not None:\n",
+ "            mask = (srtm == nodata).astype(np.float32)  # 1 at nodata, 0 elsewhere\n",
+ "        else:\n",
+ "            mask = np.zeros_like(srtm, dtype=np.float32)\n",
+ "\n",
+ "        if is_training:\n",
+ "            with rasterio.open(hr_path) as hr_src:\n",
+ "                hr = hr_src.read(1, out_shape=target_shape, resampling=Resampling.cubic) if hr_src.shape != target_shape else hr_src.read(1)\n",
+ "            hr = hr.astype(np.float32)\n",
+ "\n",
+ "            # Detrend: the model learns high-frequency residuals, not raw elevation\n",
+ "            trend = gaussian_filter(hr, sigma=5)\n",
+ "            residual = hr - trend\n",
+ "\n",
+ "            target_profile = s2_src.profile.copy()\n",
+ "            target_profile.update(count=1, dtype='float32')  # S2 profile is uint16; residuals are float\n",
+ "            with rasterio.open(os.path.join(output_dir, 'target.tif'), 'w', **target_profile) as dst:\n",
+ "                dst.write(residual, 1)\n",
+ "\n",
+ "        input_profile = s2_src.profile.copy()\n",
+ "        input_profile.update(count=7, dtype='float32')\n",
+ "        with rasterio.open(os.path.join(output_dir, 'input.tif'), 'w', **input_profile) as dst:\n",
+ "            dst.write(srtm, 1)\n",
+ "            dst.write(r, 2)\n",
+ "            dst.write(g, 3)\n",
+ "            dst.write(b, 4)\n",
+ "            dst.write(nir, 5)\n",
+ "            dst.write(ndvi, 6)\n",
+ "            dst.write(mask, 7)\n",
+ "\n",
+ "# Preprocess McKinley\n",
+ "preprocess_dataset('mckinley', is_training=True)\n",
+ "\n",
+ "# Preprocess Marrakech (no HR)\n",
+ "preprocess_dataset('marrakech', is_training=False, custom_base='/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech')\n",
+ "\n",
+ "!df -h /content/drive/MyDrive"
+ ]
+ },
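+ {
+ "cell_type": "markdown",
+ "id": "92a3b4c5",
+ "metadata": {},
+ "source": [
+ "**Added sketch (assumption, not part of the original pipeline):** Cell 7 resamples by shape only, which is valid only if the HR DEM already shares the Sentinel-2 CRS and extent. If it does not, `rasterio`'s `WarpedVRT` can reproject it onto the Sentinel-2 grid; paths mirror Cell 7."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a3b4c5d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: warp the HR LiDAR DEM onto the Sentinel-2 grid\n",
+ "import rasterio\n",
+ "from rasterio.vrt import WarpedVRT\n",
+ "from rasterio.enums import Resampling\n",
+ "\n",
+ "base = '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley'\n",
+ "with rasterio.open(f'{base}/mckinley_sentinel.tif') as ref, \\\n",
+ "     rasterio.open(f'{base}/mckinley_hr_dem.tif') as src:\n",
+ "    # Reproject/resample on the fly to match the reference CRS, transform, and size\n",
+ "    vrt = WarpedVRT(src, crs=ref.crs, transform=ref.transform,\n",
+ "                    width=ref.width, height=ref.height,\n",
+ "                    resampling=Resampling.cubic)\n",
+ "    hr_on_s2_grid = vrt.read(1)  # HR elevations, now pixel-aligned with Sentinel-2\n",
+ "    print(hr_on_s2_grid.shape, (ref.height, ref.width))"
+ ]
+ },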
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7933d058",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 8: Custom Dataset Class\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import albumentations as A\n",
+ "from albumentations.pytorch import ToTensorV2\n",
+ "from torch.utils.data import Dataset\n",
+ "import rasterio\n",
+ "import rasterio.windows\n",
+ "\n",
+ "class CustomDEMDataset(Dataset):\n",
+ "    def __init__(self, data_dirs, tile_size=256, transform=None):\n",
+ "        self.pairs = []\n",
+ "        for d_dir in data_dirs:\n",
+ "            input_path = os.path.join(d_dir, 'input.tif')\n",
+ "            target_path = os.path.join(d_dir, 'target.tif')\n",
+ "            if os.path.exists(input_path) and os.path.exists(target_path):\n",
+ "                self.pairs.append((input_path, target_path))\n",
+ "        self.tile_size = tile_size\n",
+ "        self.transform = transform or A.Compose([\n",
+ "            A.RandomCrop(height=tile_size, width=tile_size),\n",
+ "            A.RandomRotate90(),\n",
+ "            A.HorizontalFlip(),\n",
+ "            A.VerticalFlip(),\n",
+ "            A.GaussNoise(var_limit=(0.01, 0.01)),\n",
+ "            ToTensorV2()\n",
+ "        ])\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        # Oversample: each scene yields 50 random crops per epoch\n",
+ "        return len(self.pairs) * 50\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        input_path, target_path = self.pairs[idx % len(self.pairs)]\n",
+ "        with rasterio.open(input_path) as inp, rasterio.open(target_path) as tgt:\n",
+ "            # Sample a random tile_size x tile_size window from the full scene\n",
+ "            max_col = inp.width - self.tile_size\n",
+ "            max_row = inp.height - self.tile_size\n",
+ "            col_off = np.random.randint(0, max_col + 1)\n",
+ "            row_off = np.random.randint(0, max_row + 1)\n",
+ "            window = rasterio.windows.Window(col_off, row_off, self.tile_size, self.tile_size)\n",
+ "            input_data = inp.read(window=window)\n",
+ "            target_data = tgt.read(1, window=window)\n",
+ "\n",
+ "        # Albumentations expects HWC images; return (C,H,W) input and (1,H,W) target tensors\n",
+ "        data = {'image': input_data.transpose(1, 2, 0).astype(np.float32), 'target': target_data.astype(np.float32)}\n",
+ "        augmented = self.transform(image=data['image'], mask=data['target'])\n",
+ "        return augmented['image'], augmented['mask'].unsqueeze(0)\n",
+ "\n",
+ "train_dirs = ['/content/drive/MyDrive/DEM_Project/Training_Data/McKinley']\n",
+ "dataset = CustomDEMDataset(train_dirs)"
+ ]
+ },
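+ {
+ "cell_type": "markdown",
+ "id": "b4c5d6e7",
+ "metadata": {},
+ "source": [
+ "**Added smoke test:** fetch one sample to confirm the dataset returns the shapes the model in Cell 9 expects."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5d6e7f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pull one random crop and verify tensor shapes/dtypes\n",
+ "x, y = dataset[0]\n",
+ "print('input:', tuple(x.shape), x.dtype)   # expected: (7, 256, 256) torch.float32\n",
+ "print('target:', tuple(y.shape), y.dtype)  # expected: (1, 256, 256) torch.float32"
+ ]
+ },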
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "46cf2c6d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 9: Model and Training (Using Only McKinley)\n",
+ "import pytorch_lightning as pl\n",
+ "import segmentation_models_pytorch as smp\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class DeepDEMRefinement(pl.LightningModule):\n",
+ "    def __init__(self, lr=1e-4):\n",
+ "        super().__init__()\n",
+ "        self.model = smp.Unet(encoder_name='resnet34', in_channels=7, classes=1, activation=None)\n",
+ "        self.loss_fn = nn.L1Loss()\n",
+ "        self.lr = lr\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return self.model(x)\n",
+ "\n",
+ "    def training_step(self, batch, batch_idx):\n",
+ "        inputs, targets = batch\n",
+ "        preds = self(inputs)\n",
+ "        loss = self.loss_fn(preds, targets)\n",
+ "        self.log('train_loss', loss)\n",
+ "        return loss\n",
+ "\n",
+ "    def configure_optimizers(self):\n",
+ "        return torch.optim.Adam(self.parameters(), lr=self.lr)\n",
+ "\n",
+ "class DEMDataModule(pl.LightningDataModule):\n",
+ "    def __init__(self, train_dirs, batch_size=4):\n",
+ "        super().__init__()\n",
+ "        self.train_dataset = CustomDEMDataset(train_dirs)\n",
+ "        self.batch_size = batch_size\n",
+ "\n",
+ "    def train_dataloader(self):\n",
+ "        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)\n",
+ "\n",
+ "model = DeepDEMRefinement()\n",
+ "datamodule = DEMDataModule(train_dirs)\n",
+ "trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=1)  # Training for 5 epochs\n",
+ "trainer.fit(model, datamodule)\n",
+ "\n",
+ "trainer.save_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "30872060",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell 10: Inference for Marrakech\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "model = DeepDEMRefinement.load_from_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')\n",
+ "model.eval()\n",
+ "model.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "input_path = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
+ "with rasterio.open(input_path) as src:\n",
+ "    input_data = src.read().astype(np.float32)\n",
+ "    trend = gaussian_filter(input_data[0], sigma=5)\n",
+ "\n",
+ "    # Whole-scene inference: for very large rasters this can exhaust GPU memory\n",
+ "    # (a tiled variant is sketched after this cell). The U-Net also needs\n",
+ "    # dimensions divisible by 32, so reflect-pad and crop back afterwards.\n",
+ "    input_tensor = torch.from_numpy(input_data).unsqueeze(0).to(model.device)\n",
+ "    h, w = input_tensor.shape[-2:]\n",
+ "    pad_h = (32 - h % 32) % 32\n",
+ "    pad_w = (32 - w % 32) % 32\n",
+ "    input_tensor = F.pad(input_tensor, (0, pad_w, 0, pad_h), mode='reflect')\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        residual_pred = model(input_tensor)[..., :h, :w]\n",
+ "\n",
+ "    synth_dem = residual_pred.squeeze().cpu().numpy() + trend\n",
+ "\n",
+ "    profile = src.profile.copy()\n",
+ "    profile['count'] = 1\n",
+ "    with rasterio.open('/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif', 'w', **profile) as dst:\n",
+ "        dst.write(synth_dem, 1)\n",
+ "\n",
+ "print('Synthetic DEM generated for Marrakech!')"
+ ]
+ },
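+ {
+ "cell_type": "markdown",
+ "id": "d6e7f809",
+ "metadata": {},
+ "source": [
+ "**Added sketch (assumption, not the original pipeline):** if the Marrakech scene is too large for whole-scene inference, predict in overlapping windows. Tile/overlap sizes here are arbitrary; overlapping writes simply overwrite earlier ones (a production version would blend), and the scene is assumed to be at least `tile` pixels in each dimension so every patch stays a multiple of 32."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7f8091a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hedged sketch: tiled inference for rasters too large for GPU memory\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "def predict_tiled(model, image, tile=512, overlap=32):\n",
+ "    # image: (C, H, W) float32 array; returns an (H, W) residual prediction.\n",
+ "    # Assumes H >= tile and W >= tile so every patch is exactly tile x tile.\n",
+ "    c, h, w = image.shape\n",
+ "    out = np.zeros((h, w), dtype=np.float32)\n",
+ "    step = tile - 2 * overlap\n",
+ "    for top in range(0, h, step):\n",
+ "        for left in range(0, w, step):\n",
+ "            t0 = min(top, h - tile)   # clamp so patches never run off the edge\n",
+ "            l0 = min(left, w - tile)\n",
+ "            patch = torch.from_numpy(image[:, t0:t0 + tile, l0:l0 + tile]).unsqueeze(0).to(model.device)\n",
+ "            with torch.no_grad():\n",
+ "                pred = model(patch).squeeze().cpu().numpy()\n",
+ "            out[t0:t0 + tile, l0:l0 + tile] = pred  # last write wins in overlaps\n",
+ "    return out\n",
+ "\n",
+ "# Usage with Cell 10's variables: residual = predict_tiled(model, input_data)"
+ ]
+ },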
+ {
+ "cell_type": "markdown",
+ "id": "8943e470",
+ "metadata": {},
+ "source": [
+ "# Quick correctness checks\n",
+ "\n",
+ "This section runs a few sanity checks on the trained model and data:\n",
+ "\n",
+ "- Validate shapes, CRS, and basic channel statistics of `input.tif` and `target.tif`\n",
+ "- Compute masked MAE/RMSE on random training crops (McKinley) to gauge training fit\n",
+ "- Flag obvious issues (e.g., all-zero bands, nodata dominance)\n",
+ "\n",
+ "Run cells in order after training has completed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c35167d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check 1: Inspect input/target rasters (McKinley)\n",
+ "import rasterio\n",
+ "import numpy as np\n",
+ "from pathlib import Path\n",
+ "\n",
+ "train_dir = Path('/content/drive/MyDrive/DEM_Project/Training_Data/McKinley')\n",
+ "input_path = train_dir / 'input.tif'\n",
+ "target_path = train_dir / 'target.tif'\n",
+ "\n",
+ "issues = []\n",
+ "\n",
+ "with rasterio.open(input_path) as src:\n",
+ "    print('INPUT:')\n",
+ "    print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
+ "    data = src.read(out_dtype='float32')\n",
+ "    nodata = src.nodata\n",
+ "    band_stats = []\n",
+ "    for i in range(src.count):\n",
+ "        b = data[i]\n",
+ "        if nodata is not None:\n",
+ "            mask = b == nodata\n",
+ "            valid = np.where(mask, np.nan, b)\n",
+ "        else:\n",
+ "            valid = b\n",
+ "            mask = np.zeros_like(b, dtype=bool)\n",
+ "        s = {\n",
+ "            'band': i + 1,\n",
+ "            'nan_frac': float(np.mean(np.isnan(valid))),\n",
+ "            'nodata_frac': float(np.mean(mask)),\n",
+ "            'min': float(np.nanmin(valid)),\n",
+ "            'max': float(np.nanmax(valid)),\n",
+ "            'mean': float(np.nanmean(valid)),\n",
+ "            'std': float(np.nanstd(valid)),\n",
+ "        }\n",
+ "        band_stats.append(s)\n",
+ "    print('Input band stats (1:dsm, 2:R, 3:G, 4:B, 5:NIR, 6:NDVI, 7:mask):')\n",
+ "    for s in band_stats:\n",
+ "        print(s)\n",
+ "    # Basic checks\n",
+ "    if band_stats[5]['min'] < -1.01 or band_stats[5]['max'] > 1.01:\n",
+ "        issues.append('NDVI out of expected [-1,1] range; check scaling and bands (R,NIR indices).')\n",
+ "    if band_stats[6]['mean'] < 0.01 and band_stats[6]['max'] < 0.5:\n",
+ "        issues.append('Mask band appears mostly zeros; ensure mask=1 at nodata pixels, 0 elsewhere.')\n",
+ "\n",
+ "with rasterio.open(target_path) as src:\n",
+ "    print('\\nTARGET:')\n",
+ "    print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
+ "    t = src.read(1, out_dtype='float32')\n",
+ "    print({'min': float(np.nanmin(t)), 'max': float(np.nanmax(t)), 'mean': float(np.nanmean(t)), 'std': float(np.nanstd(t))})\n",
+ "    if np.allclose(t, 0):\n",
+ "        issues.append('Target residual is all zeros; check HR DEM loading and detrending step.')\n",
+ "\n",
+ "print('\\nPotential issues:')\n",
+ "print(issues if issues else 'None detected')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ad78e542",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check 2: Compute quick masked MAE/RMSE on random training crops\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "import numpy as np\n",
+ "\n",
+ "# Reuse dataset and model classes already defined earlier\n",
+ "try:\n",
+ "    _ = CustomDEMDataset\n",
+ "except NameError:\n",
+ "    raise RuntimeError('CustomDEMDataset not defined; run earlier cells first.')\n",
+ "\n",
+ "try:\n",
+ "    _ = DeepDEMRefinement\n",
+ "except NameError:\n",
+ "    raise RuntimeError('DeepDEMRefinement not defined; run training cells first.')\n",
+ "\n",
+ "# Load model\n",
+ "ckpt = '/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt'\n",
+ "model = DeepDEMRefinement.load_from_checkpoint(ckpt)\n",
+ "model.eval()\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model.to(device)\n",
+ "\n",
+ "# Small eval dataset with seeded (reproducible) random crops\n",
+ "np.random.seed(42)\n",
+ "transform = A.Compose([\n",
+ "    A.RandomCrop(height=256, width=256),\n",
+ "    ToTensorV2()\n",
+ "])\n",
+ "\n",
+ "val_ds = CustomDEMDataset([str(train_dir)], tile_size=256, transform=transform)\n",
+ "val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)\n",
+ "\n",
+ "maes, rmses = [], []\n",
+ "with torch.no_grad():\n",
+ "    for i, (x, y) in enumerate(val_loader):\n",
+ "        if i >= 10:  # ~40 tiles\n",
+ "            break\n",
+ "        x = x.to(device)\n",
+ "        y = y.to(device)\n",
+ "        pred = model(x)\n",
+ "        # Exclude masked (nodata) pixels using the mask channel\n",
+ "        mask = x[:, 6:7]  # channel 7: 1 at nodata\n",
+ "        valid = (mask < 0.5).float()\n",
+ "        diff = (pred - y) * valid\n",
+ "        denom = valid.sum().clamp_min(1.0)\n",
+ "        mae = diff.abs().sum() / denom\n",
+ "        rmse = torch.sqrt(diff.pow(2).sum() / denom)\n",
+ "        maes.append(mae.item())\n",
+ "        rmses.append(rmse.item())\n",
+ "\n",
+ "print({'MAE_mean': float(np.mean(maes)), 'MAE_std': float(np.std(maes)), 'RMSE_mean': float(np.mean(rmses)), 'RMSE_std': float(np.std(rmses)), 'tiles': len(maes) * val_loader.batch_size})\n",
+ "\n",
+ "if np.mean(rmses) > 8.0:\n",
+ "    print('Warning: High RMSE for residuals. Training may be underfit or target scaling may be off.')\n",
+ "else:\n",
+ "    print('Residual error looks reasonable for the training run.')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "def70a98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check 3: Sanity-check inference output alignment vs input for Marrakech\n",
+ "from scipy.ndimage import gaussian_filter\n",
+ "\n",
+ "marrakech_input = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
+ "marrakech_out = '/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif'\n",
+ "\n",
+ "with rasterio.open(marrakech_input) as src_in, rasterio.open(marrakech_out) as src_out:\n",
+ "    print('INFERENCE INPUT:', {'shape': (src_in.count, src_in.height, src_in.width), 'crs': str(src_in.crs), 'transform': tuple(src_in.transform)})\n",
+ "    print('SYNTH OUTPUT:', {'shape': (src_out.count, src_out.height, src_out.width), 'crs': str(src_out.crs), 'transform': tuple(src_out.transform)})\n",
+ "    if src_in.crs != src_out.crs:\n",
+ "        print('Warning: CRS mismatch between input and output!')\n",
+ "    if (src_in.height != src_out.height) or (src_in.width != src_out.width):\n",
+ "        print('Warning: Dimension mismatch between input and output!')\n",
+ "\n",
+ "    out_dem = src_out.read(1).astype('float32')\n",
+ "    # Simple terrain sanity check: the synthetic DEM should correlate strongly\n",
+ "    # with the smoothed SRTM trend it was built on\n",
+ "    srtm = src_in.read(1).astype('float32')\n",
+ "    trend = gaussian_filter(srtm, sigma=5)\n",
+ "    # Guard against NaNs (e.g., nodata) before computing the correlation\n",
+ "    a, b = trend.flatten(), out_dem.flatten()\n",
+ "    finite = np.isfinite(a) & np.isfinite(b)\n",
+ "    corr = np.corrcoef(a[finite], b[finite])[0, 1]\n",
+ "    print('Correlation between SRTM trend and synthetic DEM:', float(corr))\n",
+ "    if corr < 0.5:\n",
+ "        print('Low correlation; output may be noisy or misaligned.')"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }