Upload DEM_SuperRes.ipynb with huggingface_hub
Browse files- DEM_SuperRes.ipynb +644 -0
DEM_SuperRes.ipynb
ADDED
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "f541ffd4",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Synthetic High-Resolution DEM Generation for Marrakech, Morocco\n",
|
9 |
+
"# Using Only McKinley Dataset for Training\n",
|
10 |
+
"\n",
|
11 |
+
"This notebook implements the full pipeline, training only on the McKinley dataset to generate a model for super-resolving 30m SRTM to 10m DEMs fused with Sentinel-2 imagery for Marrakech, Morocco.\n",
|
12 |
+
"\n",
|
13 |
+
"**Key Assumptions:**\n",
|
14 |
+
"- Training on McKinley Mine NM high-res LiDAR DEM.\n",
|
15 |
+
"- Inference on Marrakech mountain area.\n",
|
16 |
+
"- DeepDEM-inspired U-Net (ResNet-34 encoder, built with segmentation-models-pytorch) with 7 input channels.\n",
|
17 |
+
"\n",
|
18 |
+
"Run cells sequentially."
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": null,
|
24 |
+
"id": "b7aa9465",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"# Cell 1: Install Dependencies\n",
|
29 |
+
"!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
|
30 |
+
"!pip install pytorch-lightning torchgeo segmentation-models-pytorch rasterio geopandas albumentations scipy gdown earthengine-api\n",
|
31 |
+
"!apt-get install -y libspatialindex-dev libgdal-dev\n",
|
32 |
+
"!pip install gdal==$(gdal-config --version)"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": null,
|
38 |
+
"id": "c4f399ac",
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"# Cell 2: Mount Google Drive and Set Up Directories\n",
|
43 |
+
"from google.colab import drive\n",
|
44 |
+
"drive.mount('/content/drive')\n",
|
45 |
+
"%cd /content/drive/MyDrive/DEM_Project\n",
|
46 |
+
"!mkdir -p Training_Data/McKinley Inference_Data/Marrackech Models"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"id": "ed508a36",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"# Cell 3: Clone DeepDEM Repo and Adapt for Our Use Case\n",
|
57 |
+
"!git clone https://github.com/uw-cryo/DeepDEM.git\n",
|
58 |
+
"%cd DeepDEM\n",
|
59 |
+
"\n",
|
60 |
+
"# Adapt model for our inputs: Modify task_module.py to accept ['dsm', 'ortho_r', 'ortho_g', 'ortho_b', 'ortho_nir', 'ndvi', 'nodata_mask'] (7 channels)\n",
|
61 |
+
"# Set model in_channels=7, out_channels=1 (residuals)\n",
|
62 |
+
"# For simplicity, assume manual edit or duplicate code here.\n",
|
63 |
+
"\n",
|
64 |
+
"import os\n",
|
65 |
+
"os.environ['PYTHONPATH'] += ':/content/drive/MyDrive/DEM_Project/DeepDEM'"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"execution_count": null,
|
71 |
+
"id": "d7bb1f40",
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"# Cell 4: Authenticate and Initialize Earth Engine\n",
|
76 |
+
"from google.colab import auth\n",
|
77 |
+
"import ee\n",
|
78 |
+
"\n",
|
79 |
+
"# 1. Authenticate your Google user\n",
|
80 |
+
"auth.authenticate_user()\n",
|
81 |
+
"\n",
|
82 |
+
"# 2. Initialize Earth Engine with your Google Cloud Project ID\n",
|
83 |
+
"# REPLACE 'dem-collab' below with the ID of your own Google Cloud project\n",
|
84 |
+
"try:\n",
|
85 |
+
" ee.Initialize(project='dem-collab')\n",
|
86 |
+
" print(\"Earth Engine initialized successfully!\")\n",
|
87 |
+
"except ee.EEException as e:\n",
|
88 |
+
" print(f\"Error during initialization: {e}\")"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": null,
|
94 |
+
"id": "091b3f03",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"# Cell 5: Data Acquisition Function (SRTM + Sentinel-2 for McKinley and Marrakech)\n",
|
99 |
+
"import os\n",
|
100 |
+
"\n",
|
101 |
+
"def fetch_gee_data(bbox, output_dir, dataset_name):\n",
|
102 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
103 |
+
"\n",
|
104 |
+
" geom = ee.Geometry.BBox(*bbox) # Define geometry for region\n",
|
105 |
+
"\n",
|
106 |
+
" # SRTM: the CGIAR/SRTM90_V4 asset is a ~90 m product; it is exported here resampled onto a 30 m grid\n",
|
107 |
+
" srtm = ee.Image('CGIAR/SRTM90_V4').clip(geom).rename('dsm')\n",
|
108 |
+
" task_srtm = ee.batch.Export.image.toDrive(\n",
|
109 |
+
" image=srtm,\n",
|
110 |
+
" description=f'{dataset_name}_srtm',\n",
|
111 |
+
" folder=output_dir.split('/')[-1],\n",
|
112 |
+
" scale=30,\n",
|
113 |
+
" fileFormat='GeoTIFF',\n",
|
114 |
+
" region=geom\n",
|
115 |
+
" )\n",
|
116 |
+
" task_srtm.start()\n",
|
117 |
+
"\n",
|
118 |
+
" # Sentinel-2 (10m, single least-cloudy scene rather than a median composite, RGB + NIR)\n",
|
119 |
+
" # Sharper image: sort by cloud cover and take the best one from a good season\n",
|
120 |
+
" collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \\\n",
|
121 |
+
" .filterBounds(geom) \\\n",
|
122 |
+
" .filterDate('2023-06-01', '2023-10-31') \\\n",
|
123 |
+
" .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)) \\\n",
|
124 |
+
" .sort('CLOUDY_PIXEL_PERCENTAGE')\n",
|
125 |
+
"\n",
|
126 |
+
" num_images = collection.size().getInfo()\n",
|
127 |
+
" print(f'For {dataset_name} bbox: {num_images} Sentinel-2 images found.')\n",
|
128 |
+
"\n",
|
129 |
+
" # Get the best image\n",
|
130 |
+
" best_image = collection.first()\n",
|
131 |
+
"\n",
|
132 |
+
" # Export the raw 4-band image for the model\n",
|
133 |
+
" sentinel_raw = best_image.select(['B4','B3','B2','B8'])\n",
|
134 |
+
" task_s2_raw = ee.batch.Export.image.toDrive(\n",
|
135 |
+
" image=sentinel_raw,\n",
|
136 |
+
" description=f'{dataset_name}_sentinel', # Keep original name for downstream tasks\n",
|
137 |
+
" folder=output_dir.split('/')[-1],\n",
|
138 |
+
" scale=10,\n",
|
139 |
+
" fileFormat='GeoTIFF',\n",
|
140 |
+
" region=geom\n",
|
141 |
+
" )\n",
|
142 |
+
" task_s2_raw.start()\n",
|
143 |
+
"\n",
|
144 |
+
" # Export a separate, color-corrected visual version for inspection\n",
|
145 |
+
" sentinel_viz = best_image.visualize(min=0, max=3000, bands=['B4', 'B3', 'B2'])\n",
|
146 |
+
" task_s2_viz = ee.batch.Export.image.toDrive(\n",
|
147 |
+
" image=sentinel_viz,\n",
|
148 |
+
" description=f'{dataset_name}_sentinel_viz',\n",
|
149 |
+
" folder=output_dir.split('/')[-1],\n",
|
150 |
+
" scale=10,\n",
|
151 |
+
" fileFormat='GeoTIFF',\n",
|
152 |
+
" region=geom\n",
|
153 |
+
" )\n",
|
154 |
+
" task_s2_viz.start()\n",
|
155 |
+
"\n",
|
156 |
+
"\n",
|
157 |
+
"# Bounding boxes\n",
|
158 |
+
"bbox_mckinley = [-109.03892074228675, 35.58282920746211, -108.87077846472735, 35.736434167381475]\n",
|
159 |
+
"fetch_gee_data(bbox_mckinley, '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley', 'mckinley')\n",
|
160 |
+
"\n",
|
161 |
+
"bbox_marrakech = [-8.1, 31.5, -7.9, 31.7]\n",
|
162 |
+
"fetch_gee_data(bbox_marrakech, '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech', 'marrakech')"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": null,
|
168 |
+
"id": "5f9b0934",
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"# Cell 6: Download High-Resolution DEM Tiles and Merge (Only for McKinley)\n",
|
173 |
+
"!pip install boto3 gdal retry\n",
|
174 |
+
"!apt install gdal-bin\n",
|
175 |
+
"import boto3\n",
|
176 |
+
"import os\n",
|
177 |
+
"import shutil\n",
|
178 |
+
"from botocore import UNSIGNED\n",
|
179 |
+
"from botocore.client import Config\n",
|
180 |
+
"from retry import retry\n",
|
181 |
+
"\n",
|
182 |
+
"# S3 Configuration\n",
|
183 |
+
"endpoint_url = 'https://opentopography.s3.sdsc.edu'\n",
|
184 |
+
"client = boto3.client('s3', endpoint_url=endpoint_url, config=Config(signature_version=UNSIGNED))\n",
|
185 |
+
"\n",
|
186 |
+
"# Temp and data dirs\n",
|
187 |
+
"temp_base = '/content/hr_temp'\n",
|
188 |
+
"data_base = '/content/drive/MyDrive/DEM_Project/Training_Data'\n",
|
189 |
+
"os.makedirs(temp_base, exist_ok=True)\n",
|
190 |
+
"os.makedirs(data_base, exist_ok=True)\n",
|
191 |
+
"\n",
|
192 |
+
"# Only McKinley\n",
|
193 |
+
"datasets = {'mckinley': 'NM23_McKinley'}\n",
|
194 |
+
"folder_names = {'mckinley': 'McKinley'}\n",
|
195 |
+
"\n",
|
196 |
+
"@retry(tries=3, delay=2, backoff=2)\n",
|
197 |
+
"def download_file_with_retry(bucket, key, filename):\n",
|
198 |
+
" client.download_file(Bucket=bucket, Key=key, Filename=filename)\n",
|
199 |
+
"\n",
|
200 |
+
"!df -h /content\n",
|
201 |
+
"for local_name, s3_dir in datasets.items():\n",
|
202 |
+
" temp_dir = os.path.join(temp_base, local_name)\n",
|
203 |
+
" dataset_dir = os.path.join(data_base, folder_names[local_name])\n",
|
204 |
+
" os.makedirs(temp_dir, exist_ok=True)\n",
|
205 |
+
" os.makedirs(dataset_dir, exist_ok=True)\n",
|
206 |
+
"\n",
|
207 |
+
" output_tif = os.path.join(dataset_dir, f'{local_name}_hr_dem.tif')\n",
|
208 |
+
" if os.path.exists(output_tif):\n",
|
209 |
+
" print(f'Merged DEM already exists for {local_name}: {output_tif}, skipping download and merge.')\n",
|
210 |
+
" continue\n",
|
211 |
+
"\n",
|
212 |
+
" paginator = client.get_paginator('list_objects_v2')\n",
|
213 |
+
" prefix = f'{s3_dir}/{s3_dir}_be/'\n",
|
214 |
+
" downloaded_files = []\n",
|
215 |
+
" try:\n",
|
216 |
+
" for page in paginator.paginate(Bucket='raster', Prefix=prefix):\n",
|
217 |
+
" for obj in page.get('Contents', []):\n",
|
218 |
+
" key = obj['Key']\n",
|
219 |
+
" if key.endswith('.tif'):\n",
|
220 |
+
" file_path = os.path.join(temp_dir, os.path.basename(key))\n",
|
221 |
+
" # Check if the tile file already exists\n",
|
222 |
+
" if os.path.exists(file_path):\n",
|
223 |
+
" print(f'Tile {os.path.basename(key)} already exists for {local_name}, skipping download.')\n",
|
224 |
+
" downloaded_files.append(file_path)\n",
|
225 |
+
" continue\n",
|
226 |
+
" try:\n",
|
227 |
+
" download_file_with_retry('raster', key, file_path)\n",
|
228 |
+
" downloaded_files.append(file_path)\n",
|
229 |
+
" print(f'Downloaded {os.path.basename(key)} for {local_name}')\n",
|
230 |
+
" except Exception as e:\n",
|
231 |
+
" print(f'Error downloading {key} for {local_name}: {e}')\n",
|
232 |
+
" except Exception as e:\n",
|
233 |
+
" print(f'Error listing tiles for {local_name}: {e}')\n",
|
234 |
+
"\n",
|
235 |
+
" if not downloaded_files:\n",
|
236 |
+
" print(f'No tiles downloaded for {local_name}; skipping merge.')\n",
|
237 |
+
" continue\n",
|
238 |
+
"\n",
|
239 |
+
" try:\n",
|
240 |
+
" # Using gdalbuildvrt and gdal_translate for better performance\n",
|
241 |
+
" !gdalbuildvrt merged.vrt {\" \".join(downloaded_files)}\n",
|
242 |
+
" !gdal_translate -of GTiff merged.vrt \"{output_tif}\" -co TILED=YES -co COMPRESS=DEFLATE -co NUM_THREADS=ALL_CPUS\n",
|
243 |
+
" print(f'Merged tiles to TIFF for {local_name}: {output_tif}')\n",
|
244 |
+
" except Exception as e:\n",
|
245 |
+
" print(f'Error merging tiles for {local_name}: {e}')\n",
|
246 |
+
" continue\n",
|
247 |
+
" # temp_dir is intentionally not deleted (no shutil.rmtree), so a rerun can reuse already-downloaded tiles\n",
|
248 |
+
"\n",
|
249 |
+
"!df -h /content\n",
|
250 |
+
"print('Download complete for McKinley!')"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": null,
|
256 |
+
"id": "7f82f16c",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [],
|
259 |
+
"source": [
|
260 |
+
"# Cell 7: Data Preprocessing (Only for McKinley and Marrakech)\n",
|
261 |
+
"import rasterio\n",
|
262 |
+
"from rasterio.enums import Resampling\n",
|
263 |
+
"import numpy as np\n",
|
264 |
+
"from scipy.ndimage import gaussian_filter\n",
|
265 |
+
"import os\n",
|
266 |
+
"\n",
|
267 |
+
"folder_names = {'mckinley': 'McKinley', 'marrakech': 'Marrakech'}\n",
|
268 |
+
"\n",
|
269 |
+
"def preprocess_dataset(dataset_name, is_training=True, custom_base=None):\n",
|
270 |
+
" if custom_base:\n",
|
271 |
+
" base_dir = custom_base\n",
|
272 |
+
" else:\n",
|
273 |
+
" base_dir = f'/content/drive/MyDrive/DEM_Project/Training_Data/{folder_names[dataset_name]}'\n",
|
274 |
+
" \n",
|
275 |
+
" srtm_path = os.path.join(base_dir, f'{dataset_name}_srtm.tif')\n",
|
276 |
+
" s2_path = os.path.join(base_dir, f'{dataset_name}_sentinel.tif')\n",
|
277 |
+
" hr_path = os.path.join(base_dir, f'{dataset_name}_hr_dem.tif') if is_training else None\n",
|
278 |
+
" output_dir = base_dir\n",
|
279 |
+
"\n",
|
280 |
+
" with rasterio.open(srtm_path) as srtm_src, rasterio.open(s2_path) as s2_src:\n",
|
281 |
+
" target_shape = (s2_src.height, s2_src.width)\n",
|
282 |
+
" srtm = srtm_src.read(1, out_shape=target_shape, resampling=Resampling.cubic)\n",
|
283 |
+
"\n",
|
284 |
+
" s2 = s2_src.read()\n",
|
285 |
+
" r, g, b, nir = s2\n",
|
286 |
+
"\n",
|
287 |
+
" ndvi = (nir - r) / (nir + r + 1e-10)\n",
|
288 |
+
"\n",
|
289 |
+
" mask = np.where(srtm == srtm_src.nodata, 1, 0).astype(np.float32)\n",
|
290 |
+
"\n",
|
291 |
+
" if is_training:\n",
|
292 |
+
" with rasterio.open(hr_path) as hr_src:\n",
|
293 |
+
" hr = hr_src.read(1, out_shape=target_shape, resampling=Resampling.cubic) if hr_src.shape != target_shape else hr_src.read(1)\n",
|
294 |
+
"\n",
|
295 |
+
" trend = gaussian_filter(hr, sigma=5)\n",
|
296 |
+
" residual = hr - trend\n",
|
297 |
+
"\n",
|
298 |
+
" target_profile = s2_src.profile\n",
|
299 |
+
" target_profile['count'] = 1\n",
|
300 |
+
" with rasterio.open(os.path.join(output_dir, 'target.tif'), 'w', **target_profile) as dst:\n",
|
301 |
+
" dst.write(residual, 1)\n",
|
302 |
+
"\n",
|
303 |
+
" input_profile = s2_src.profile\n",
|
304 |
+
" input_profile['count'] = 7\n",
|
305 |
+
" with rasterio.open(os.path.join(output_dir, 'input.tif'), 'w', **input_profile) as dst:\n",
|
306 |
+
" dst.write(srtm, 1)\n",
|
307 |
+
" dst.write(r, 2)\n",
|
308 |
+
" dst.write(g, 3)\n",
|
309 |
+
" dst.write(b, 4)\n",
|
310 |
+
" dst.write(nir, 5)\n",
|
311 |
+
" dst.write(ndvi, 6)\n",
|
312 |
+
" dst.write(mask, 7)\n",
|
313 |
+
"\n",
|
314 |
+
"# Preprocess McKinley\n",
|
315 |
+
"preprocess_dataset('mckinley', is_training=True)\n",
|
316 |
+
"\n",
|
317 |
+
"# Preprocess Marrakech (no HR)\n",
|
318 |
+
"preprocess_dataset('marrakech', is_training=False, custom_base='/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech')\n",
|
319 |
+
"\n",
|
320 |
+
"!df -h /content/drive/MyDrive"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "code",
|
325 |
+
"execution_count": null,
|
326 |
+
"id": "7933d058",
|
327 |
+
"metadata": {},
|
328 |
+
"outputs": [],
|
329 |
+
"source": [
|
330 |
+
"# Cell 8: Custom Dataset Class\n",
|
331 |
+
"import albumentations as A\n",
|
332 |
+
"from albumentations.pytorch import ToTensorV2\n",
|
333 |
+
"from torch.utils.data import Dataset\n",
|
334 |
+
"import rasterio.windows\n",
|
335 |
+
"\n",
|
336 |
+
"class CustomDEMDataset(Dataset):\n",
|
337 |
+
" def __init__(self, data_dirs, tile_size=256, transform=None):\n",
|
338 |
+
" self.pairs = []\n",
|
339 |
+
" for d_dir in data_dirs:\n",
|
340 |
+
" input_path = os.path.join(d_dir, 'input.tif')\n",
|
341 |
+
" target_path = os.path.join(d_dir, 'target.tif')\n",
|
342 |
+
" if os.path.exists(input_path) and os.path.exists(target_path):\n",
|
343 |
+
" self.pairs.append((input_path, target_path))\n",
|
344 |
+
" self.tile_size = tile_size\n",
|
345 |
+
" self.transform = transform or A.Compose([\n",
|
346 |
+
" A.RandomCrop(height=tile_size, width=tile_size),\n",
|
347 |
+
" A.RandomRotate90(),\n",
|
348 |
+
" A.HorizontalFlip(),\n",
|
349 |
+
" A.VerticalFlip(),\n",
|
350 |
+
" A.GaussNoise(var_limit=(0.01, 0.01)),\n",
|
351 |
+
" ToTensorV2()\n",
|
352 |
+
" ])\n",
|
353 |
+
"\n",
|
354 |
+
" def __len__(self):\n",
|
355 |
+
" return len(self.pairs) * 50\n",
|
356 |
+
"\n",
|
357 |
+
" def __getitem__(self, idx):\n",
|
358 |
+
" input_path, target_path = self.pairs[idx % len(self.pairs)]\n",
|
359 |
+
" with rasterio.open(input_path) as inp, rasterio.open(target_path) as tgt:\n",
|
360 |
+
" max_col = inp.width - self.tile_size\n",
|
361 |
+
" max_row = inp.height - self.tile_size\n",
|
362 |
+
" col_off = np.random.randint(0, max_col + 1)\n",
|
363 |
+
" row_off = np.random.randint(0, max_row + 1)\n",
|
364 |
+
" window = rasterio.windows.Window(col_off, row_off, self.tile_size, self.tile_size)\n",
|
365 |
+
" input_data = inp.read(window=window)\n",
|
366 |
+
" target_data = tgt.read(1, window=window)\n",
|
367 |
+
"\n",
|
368 |
+
" data = {'image': input_data.transpose(1,2,0).astype(np.float32), 'target': target_data.astype(np.float32)}\n",
|
369 |
+
" augmented = self.transform(image=data['image'], mask=data['target'])\n",
|
370 |
+
" return augmented['image'], augmented['mask'].unsqueeze(0)\n",
|
371 |
+
"\n",
|
372 |
+
"train_dirs = ['/content/drive/MyDrive/DEM_Project/Training_Data/McKinley']\n",
|
373 |
+
"dataset = CustomDEMDataset(train_dirs)"
|
374 |
+
]
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"cell_type": "code",
|
378 |
+
"execution_count": null,
|
379 |
+
"id": "46cf2c6d",
|
380 |
+
"metadata": {},
|
381 |
+
"outputs": [],
|
382 |
+
"source": [
|
383 |
+
"# Cell 9: Model and Training (Using Only McKinley)\n",
|
384 |
+
"import pytorch_lightning as pl\n",
|
385 |
+
"import segmentation_models_pytorch as smp\n",
|
386 |
+
"from torch.utils.data import DataLoader\n",
|
387 |
+
"import torch\n",
|
388 |
+
"import torch.nn as nn\n",
|
389 |
+
"\n",
|
390 |
+
"class DeepDEMRefinement(pl.LightningModule):\n",
|
391 |
+
" def __init__(self, lr=1e-4):\n",
|
392 |
+
" super().__init__()\n",
|
393 |
+
" self.model = smp.Unet(encoder_name='resnet34', in_channels=7, classes=1, activation=None)\n",
|
394 |
+
" self.loss_fn = nn.L1Loss()\n",
|
395 |
+
" self.lr = lr\n",
|
396 |
+
"\n",
|
397 |
+
" def forward(self, x):\n",
|
398 |
+
" return self.model(x)\n",
|
399 |
+
"\n",
|
400 |
+
" def training_step(self, batch, batch_idx):\n",
|
401 |
+
" inputs, targets = batch\n",
|
402 |
+
" preds = self(inputs)\n",
|
403 |
+
" loss = self.loss_fn(preds, targets)\n",
|
404 |
+
" self.log('train_loss', loss)\n",
|
405 |
+
" return loss\n",
|
406 |
+
"\n",
|
407 |
+
" def configure_optimizers(self):\n",
|
408 |
+
" return torch.optim.Adam(self.parameters(), lr=self.lr)\n",
|
409 |
+
"\n",
|
410 |
+
"class DEMDataModule(pl.LightningDataModule):\n",
|
411 |
+
" def __init__(self, train_dirs, batch_size=4):\n",
|
412 |
+
" super().__init__()\n",
|
413 |
+
" self.train_dataset = CustomDEMDataset(train_dirs)\n",
|
414 |
+
" self.batch_size = batch_size\n",
|
415 |
+
"\n",
|
416 |
+
" def train_dataloader(self):\n",
|
417 |
+
" return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)\n",
|
418 |
+
"\n",
|
419 |
+
"model = DeepDEMRefinement()\n",
|
420 |
+
"datamodule = DEMDataModule(train_dirs)\n",
|
421 |
+
"trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=1) # Training for 5 epochs\n",
|
422 |
+
"trainer.fit(model, datamodule)\n",
|
423 |
+
"\n",
|
424 |
+
"trainer.save_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"cell_type": "code",
|
429 |
+
"execution_count": null,
|
430 |
+
"id": "30872060",
|
431 |
+
"metadata": {},
|
432 |
+
"outputs": [],
|
433 |
+
"source": [
|
434 |
+
"# Cell 10: Inference for Marrakech\n",
|
435 |
+
"model = DeepDEMRefinement.load_from_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')\n",
|
436 |
+
"model.eval()\n",
|
437 |
+
"model.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
438 |
+
"\n",
|
439 |
+
"input_path = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
|
440 |
+
"with rasterio.open(input_path) as src:\n",
|
441 |
+
" input_data = src.read().astype(np.float32)\n",
|
442 |
+
" trend = gaussian_filter(input_data[0], sigma=5)\n",
|
443 |
+
"\n",
|
444 |
+
" input_tensor = torch.from_numpy(input_data).unsqueeze(0).to(model.device)\n",
|
445 |
+
"\n",
|
446 |
+
" with torch.no_grad():\n",
|
447 |
+
" residual_pred = model(input_tensor)\n",
|
448 |
+
"\n",
|
449 |
+
" synth_dem = residual_pred.squeeze().cpu().numpy() + trend\n",
|
450 |
+
"\n",
|
451 |
+
" profile = src.profile\n",
|
452 |
+
" profile['count'] = 1\n",
|
453 |
+
" with rasterio.open('/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif', 'w', **profile) as dst:\n",
|
454 |
+
" dst.write(synth_dem, 1)\n",
|
455 |
+
"\n",
|
456 |
+
"print('Synthetic DEM generated for Marrakech!')"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "markdown",
|
461 |
+
"id": "8943e470",
|
462 |
+
"metadata": {},
|
463 |
+
"source": [
|
464 |
+
"# Quick correctness checks\n",
|
465 |
+
"\n",
|
466 |
+
"This section runs a few sanity checks on the trained model and data:\n",
|
467 |
+
"\n",
|
468 |
+
"- Validate shapes, CRS, and basic channel statistics of `input.tif` and `target.tif`\n",
|
469 |
+
"- Compute masked MAE/RMSE on random training crops (McKinley) to gauge training fit\n",
|
470 |
+
"- Flag obvious issues (e.g., all-zeros bands, nodata dominance)\n",
|
471 |
+
"\n",
|
472 |
+
"Run cells in order after training has completed."
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"cell_type": "code",
|
477 |
+
"execution_count": null,
|
478 |
+
"id": "3c35167d",
|
479 |
+
"metadata": {},
|
480 |
+
"outputs": [],
|
481 |
+
"source": [
|
482 |
+
"# Check 1: Inspect input/target rasters (McKinley)\n",
|
483 |
+
"import rasterio\n",
|
484 |
+
"import numpy as np\n",
|
485 |
+
"from pathlib import Path\n",
|
486 |
+
"\n",
|
487 |
+
"train_dir = Path('/content/drive/MyDrive/DEM_Project/Training_Data/McKinley')\n",
|
488 |
+
"input_path = train_dir / 'input.tif'\n",
|
489 |
+
"target_path = train_dir / 'target.tif'\n",
|
490 |
+
"\n",
|
491 |
+
"issues = []\n",
|
492 |
+
"\n",
|
493 |
+
"with rasterio.open(input_path) as src:\n",
|
494 |
+
" print('INPUT:')\n",
|
495 |
+
" print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
|
496 |
+
" data = src.read(out_dtype='float32')\n",
|
497 |
+
" nodata = src.nodata\n",
|
498 |
+
" band_stats = []\n",
|
499 |
+
" for i in range(src.count):\n",
|
500 |
+
" b = data[i]\n",
|
501 |
+
" if nodata is not None:\n",
|
502 |
+
" mask = b == nodata\n",
|
503 |
+
" valid = np.where(mask, np.nan, b)\n",
|
504 |
+
" else:\n",
|
505 |
+
" valid = b\n",
|
506 |
+
" mask = np.zeros_like(b, dtype=bool)\n",
|
507 |
+
" s = {\n",
|
508 |
+
" 'band': i+1,\n",
|
509 |
+
" 'nan_frac': float(np.mean(np.isnan(valid))),\n",
|
510 |
+
" 'nodata_frac': float(np.mean(mask)),\n",
|
511 |
+
" 'min': float(np.nanmin(valid)),\n",
|
512 |
+
" 'max': float(np.nanmax(valid)),\n",
|
513 |
+
" 'mean': float(np.nanmean(valid)),\n",
|
514 |
+
" 'std': float(np.nanstd(valid)),\n",
|
515 |
+
" }\n",
|
516 |
+
" band_stats.append(s)\n",
|
517 |
+
" print('Input band stats (1:dsm, 2:R, 3:G, 4:B, 5:NIR, 6:NDVI, 7:mask):')\n",
|
518 |
+
" for s in band_stats:\n",
|
519 |
+
" print(s)\n",
|
520 |
+
" # Basic checks\n",
|
521 |
+
" if band_stats[5]['min'] < -1.01 or band_stats[5]['max'] > 1.01:\n",
|
522 |
+
" issues.append('NDVI out of expected [-1,1] range; check scaling and bands (R,NIR indices).')\n",
|
523 |
+
" if band_stats[6]['mean'] < 0.01 and band_stats[6]['max'] < 0.5:\n",
|
524 |
+
" issues.append('Mask band appears mostly zeros; ensure mask=1 at nodata pixels, 0 elsewhere.')\n",
|
525 |
+
"\n",
|
526 |
+
"with rasterio.open(target_path) as src:\n",
|
527 |
+
" print('\\nTARGET:')\n",
|
528 |
+
" print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
|
529 |
+
" t = src.read(1, out_dtype='float32')\n",
|
530 |
+
" print({'min': float(np.nanmin(t)), 'max': float(np.nanmax(t)), 'mean': float(np.nanmean(t)), 'std': float(np.nanstd(t))})\n",
|
531 |
+
" if np.allclose(t, 0):\n",
|
532 |
+
" issues.append('Target residual is all zeros; check HR DEM loading and detrending step.')\n",
|
533 |
+
"\n",
|
534 |
+
"print('\\nPotential issues:')\n",
|
535 |
+
"print(issues if issues else 'None detected')"
|
536 |
+
]
|
537 |
+
},
|
538 |
+
{
|
539 |
+
"cell_type": "code",
|
540 |
+
"execution_count": null,
|
541 |
+
"id": "ad78e542",
|
542 |
+
"metadata": {},
|
543 |
+
"outputs": [],
|
544 |
+
"source": [
|
545 |
+
"# Check 2: Compute quick masked MAE/RMSE on random training crops\n",
|
546 |
+
"import torch\n",
|
547 |
+
"import torch.nn.functional as F\n",
|
548 |
+
"from torch.utils.data import DataLoader\n",
|
549 |
+
"import numpy as np\n",
|
550 |
+
"\n",
|
551 |
+
"# Reuse dataset and model classes already defined earlier\n",
|
552 |
+
"try:\n",
|
553 |
+
" _ = CustomDEMDataset\n",
|
554 |
+
"except NameError:\n",
|
555 |
+
" raise RuntimeError('CustomDEMDataset not defined; run earlier cells first.')\n",
|
556 |
+
"\n",
|
557 |
+
"try:\n",
|
558 |
+
" _ = DeepDEMRefinement\n",
|
559 |
+
"except NameError:\n",
|
560 |
+
" raise RuntimeError('DeepDEMRefinement not defined; run training cells first.')\n",
|
561 |
+
"\n",
|
562 |
+
"# Load model\n",
|
563 |
+
"ckpt = '/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt'\n",
|
564 |
+
"model = DeepDEMRefinement.load_from_checkpoint(ckpt)\n",
|
565 |
+
"model.eval()\n",
|
566 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
567 |
+
"model.to(device)\n",
|
568 |
+
"\n",
|
569 |
+
"# Small eval dataset with deterministic crops\n",
|
570 |
+
"np.random.seed(42)\n",
|
571 |
+
"transform = A.Compose([\n",
|
572 |
+
" A.RandomCrop(height=256, width=256),\n",
|
573 |
+
" ToTensorV2()\n",
|
574 |
+
"])\n",
|
575 |
+
"\n",
|
576 |
+
"val_ds = CustomDEMDataset([str(train_dir)], tile_size=256, transform=transform)\n",
|
577 |
+
"val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)\n",
|
578 |
+
"\n",
|
579 |
+
"maes, rmses = [], []\n",
|
580 |
+
"with torch.no_grad():\n",
|
581 |
+
" for i, (x, y) in enumerate(val_loader):\n",
|
582 |
+
" if i >= 10: # ~40 tiles\n",
|
583 |
+
" break\n",
|
584 |
+
" x = x.to(device)\n",
|
585 |
+
" y = y.to(device)\n",
|
586 |
+
" pred = model(x)\n",
|
587 |
+
" # Exclude nodata pixels (mask channel >= 0.5) from the MAE/RMSE so only valid pixels are scored\n",
|
588 |
+
" mask = x[:, 6:7] # channel 7\n",
|
589 |
+
" valid = (mask < 0.5).float()\n",
|
590 |
+
" diff = (pred - y) * valid\n",
|
591 |
+
" denom = valid.sum().clamp_min(1.0)\n",
|
592 |
+
" mae = diff.abs().sum() / denom\n",
|
593 |
+
" rmse = torch.sqrt((diff.pow(2).sum() / denom))\n",
|
594 |
+
" maes.append(mae.item())\n",
|
595 |
+
" rmses.append(rmse.item())\n",
|
596 |
+
"\n",
|
597 |
+
"print({'MAE_mean': float(np.mean(maes)), 'MAE_std': float(np.std(maes)), 'RMSE_mean': float(np.mean(rmses)), 'RMSE_std': float(np.std(rmses)), 'tiles': len(maes)*val_loader.batch_size})\n",
|
598 |
+
"\n",
|
599 |
+
"if np.mean(rmses) > 8.0:\n",
|
600 |
+
" print('Warning: High RMSE for residuals. Training may be underfit or target scaling may be off.')\n",
|
601 |
+
"else:\n",
|
602 |
+
" print('Residual error looks reasonable for the training run.')"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"cell_type": "code",
|
607 |
+
"execution_count": null,
|
608 |
+
"id": "def70a98",
|
609 |
+
"metadata": {},
|
610 |
+
"outputs": [],
|
611 |
+
"source": [
|
612 |
+
"# Check 3: Sanity-check inference output alignment vs input for Marrakech\n",
|
613 |
+
"from scipy.ndimage import gaussian_filter\n",
|
614 |
+
"\n",
|
615 |
+
"marrakech_input = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
|
616 |
+
"marrakech_out = '/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif'\n",
|
617 |
+
"\n",
|
618 |
+
"with rasterio.open(marrakech_input) as src_in, rasterio.open(marrakech_out) as src_out:\n",
|
619 |
+
" print('INFERENCE INPUT:', {'shape': (src_in.count, src_in.height, src_in.width), 'crs': str(src_in.crs), 'transform': tuple(src_in.transform)})\n",
|
620 |
+
" print('SYNTH OUTPUT:', {'shape': (src_out.count, src_out.height, src_out.width), 'crs': str(src_out.crs), 'transform': tuple(src_out.transform)})\n",
|
621 |
+
" if src_in.crs != src_out.crs:\n",
|
622 |
+
" print('Warning: CRS mismatch between input and output!')\n",
|
623 |
+
" if (src_in.height != src_out.height) or (src_in.width != src_out.width):\n",
|
624 |
+
" print('Warning: Dimension mismatch between input and output!')\n",
|
625 |
+
"\n",
|
626 |
+
" out_dem = src_out.read(1).astype('float32')\n",
|
627 |
+
" # Simple terrain sanity: residual-added trend should correlate with SRTM trend\n",
|
628 |
+
" srtm = src_in.read(1).astype('float32')\n",
|
629 |
+
" trend = gaussian_filter(srtm, sigma=5)\n",
|
630 |
+
" corr = np.corrcoef(trend.flatten(), out_dem.flatten())[0,1]\n",
|
631 |
+
" print('Correlation between SRTM trend and synthetic DEM:', float(corr))\n",
|
632 |
+
" if corr < 0.5:\n",
|
633 |
+
" print('Low correlation; output may be noisy or misaligned.')"
|
634 |
+
]
|
635 |
+
}
|
636 |
+
],
|
637 |
+
"metadata": {
|
638 |
+
"language_info": {
|
639 |
+
"name": "python"
|
640 |
+
}
|
641 |
+
},
|
642 |
+
"nbformat": 4,
|
643 |
+
"nbformat_minor": 5
|
644 |
+
}
|