Upload DEM_SuperRes.ipynb with huggingface_hub
Browse files- DEM_SuperRes.ipynb +644 -0
DEM_SuperRes.ipynb
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "f541ffd4",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Synthetic High-Resolution DEM Generation for Marrakech, Morocco\n",
|
| 9 |
+
"# Using Only McKinley Dataset for Training\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook implements the full pipeline: a model is trained only on the McKinley dataset and then applied to Marrakech, Morocco, where it super-resolves 30m SRTM elevation, fused with Sentinel-2 imagery, into a 10m DEM.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"**Key Assumptions:**\n",
|
| 14 |
+
"- Training on McKinley Mine NM high-res LiDAR DEM.\n",
|
| 15 |
+
"- Inference on Marrakech mountain area.\n",
|
| 16 |
+
"- Adapted DeepDEM model with 7 input channels.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"Run cells sequentially."
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"id": "b7aa9465",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"# Cell 1: Install Dependencies\n",
|
| 29 |
+
"!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
|
| 30 |
+
"!pip install pytorch-lightning torchgeo segmentation-models-pytorch rasterio geopandas albumentations scipy gdown earthengine-api\n",
|
| 31 |
+
"!apt-get install -y libspatialindex-dev libgdal-dev\n",
|
| 32 |
+
"!pip install gdal==$(gdal-config --version)"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "code",
|
| 37 |
+
"execution_count": null,
|
| 38 |
+
"id": "c4f399ac",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [],
|
| 41 |
+
"source": [
|
| 42 |
+
"# Cell 2: Mount Google Drive and Set Up Directories\n",
|
| 43 |
+
"from google.colab import drive\n",
|
| 44 |
+
"drive.mount('/content/drive')\n",
|
| 45 |
+
"%cd /content/drive/MyDrive/DEM_Project\n",
|
| 46 |
+
"!mkdir -p Training_Data/McKinley Inference_Data/Marrakech Models"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": null,
|
| 52 |
+
"id": "ed508a36",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"# Cell 3: Clone DeepDEM Repo and Adapt for Our Use Case\n",
|
| 57 |
+
"!git clone https://github.com/uw-cryo/DeepDEM.git\n",
|
| 58 |
+
"%cd DeepDEM\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# Adapt model for our inputs: Modify task_module.py to accept ['dsm', 'ortho_r', 'ortho_g', 'ortho_b', 'ortho_nir', 'ndvi', 'nodata_mask'] (7 channels)\n",
|
| 61 |
+
"# Set model in_channels=7, out_channels=1 (residuals)\n",
|
| 62 |
+
"# For simplicity, assume manual edit or duplicate code here.\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"import os\n",
|
| 65 |
+
"os.environ['PYTHONPATH'] += ':/content/drive/MyDrive/DEM_Project/DeepDEM'"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "code",
|
| 70 |
+
"execution_count": null,
|
| 71 |
+
"id": "d7bb1f40",
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"outputs": [],
|
| 74 |
+
"source": [
|
| 75 |
+
"# Cell 4: Authenticate and Initialize Earth Engine\n",
|
| 76 |
+
"from google.colab import auth\n",
|
| 77 |
+
"import ee\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# 1. Authenticate your Google user\n",
|
| 80 |
+
"auth.authenticate_user()\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"# 2. Initialize Earth Engine with your Google Cloud Project ID\n",
|
| 83 |
+
"# REPLACE 'dem-collab' below with the actual ID of your own Google Cloud project\n",
|
| 84 |
+
"try:\n",
|
| 85 |
+
" ee.Initialize(project='dem-collab')\n",
|
| 86 |
+
" print(\"Earth Engine initialized successfully!\")\n",
|
| 87 |
+
"except ee.EEException as e:\n",
|
| 88 |
+
" print(f\"Error during initialization: {e}\")"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": null,
|
| 94 |
+
"id": "091b3f03",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"# Cell 5: Data Acquisition Function (SRTM + Sentinel-2 for McKinley and Marrakech)\n",
|
| 99 |
+
"import os\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def fetch_gee_data(bbox, output_dir, dataset_name):\n",
|
| 102 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" geom = ee.Geometry.BBox(*bbox) # Define geometry for region\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" # SRTM: the CGIAR/SRTM90_V4 asset is a ~90m product; it is exported below at scale=30\n",
|
| 107 |
+
" srtm = ee.Image('CGIAR/SRTM90_V4').clip(geom).rename('dsm')\n",
|
| 108 |
+
" task_srtm = ee.batch.Export.image.toDrive(\n",
|
| 109 |
+
" image=srtm,\n",
|
| 110 |
+
" description=f'{dataset_name}_srtm',\n",
|
| 111 |
+
" folder=output_dir.split('/')[-1],\n",
|
| 112 |
+
" scale=30,\n",
|
| 113 |
+
" fileFormat='GeoTIFF',\n",
|
| 114 |
+
" region=geom\n",
|
| 115 |
+
" )\n",
|
| 116 |
+
" task_srtm.start()\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" # Sentinel-2 (10m, single least-cloudy scene rather than a median composite, RGB + NIR)\n",
|
| 119 |
+
" # Sharper image: sort by cloud cover and take the best one from a good season\n",
|
| 120 |
+
" collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \\\n",
|
| 121 |
+
" .filterBounds(geom) \\\n",
|
| 122 |
+
" .filterDate('2023-06-01', '2023-10-31') \\\n",
|
| 123 |
+
" .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)) \\\n",
|
| 124 |
+
" .sort('CLOUDY_PIXEL_PERCENTAGE')\n",
|
| 125 |
+
"\n",
|
| 126 |
+
" num_images = collection.size().getInfo()\n",
|
| 127 |
+
" print(f'For {dataset_name} bbox: {num_images} Sentinel-2 images found.')\n",
|
| 128 |
+
"\n",
|
| 129 |
+
" # Get the best image\n",
|
| 130 |
+
" best_image = collection.first()\n",
|
| 131 |
+
"\n",
|
| 132 |
+
" # Export the raw 4-band image for the model\n",
|
| 133 |
+
" sentinel_raw = best_image.select(['B4','B3','B2','B8'])\n",
|
| 134 |
+
" task_s2_raw = ee.batch.Export.image.toDrive(\n",
|
| 135 |
+
" image=sentinel_raw,\n",
|
| 136 |
+
" description=f'{dataset_name}_sentinel', # Keep original name for downstream tasks\n",
|
| 137 |
+
" folder=output_dir.split('/')[-1],\n",
|
| 138 |
+
" scale=10,\n",
|
| 139 |
+
" fileFormat='GeoTIFF',\n",
|
| 140 |
+
" region=geom\n",
|
| 141 |
+
" )\n",
|
| 142 |
+
" task_s2_raw.start()\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" # Export a separate, color-corrected visual version for inspection\n",
|
| 145 |
+
" sentinel_viz = best_image.visualize(min=0, max=3000, bands=['B4', 'B3', 'B2'])\n",
|
| 146 |
+
" task_s2_viz = ee.batch.Export.image.toDrive(\n",
|
| 147 |
+
" image=sentinel_viz,\n",
|
| 148 |
+
" description=f'{dataset_name}_sentinel_viz',\n",
|
| 149 |
+
" folder=output_dir.split('/')[-1],\n",
|
| 150 |
+
" scale=10,\n",
|
| 151 |
+
" fileFormat='GeoTIFF',\n",
|
| 152 |
+
" region=geom\n",
|
| 153 |
+
" )\n",
|
| 154 |
+
" task_s2_viz.start()\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"# Bounding boxes\n",
|
| 158 |
+
"bbox_mckinley = [-109.03892074228675, 35.58282920746211, -108.87077846472735, 35.736434167381475]\n",
|
| 159 |
+
"fetch_gee_data(bbox_mckinley, '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley', 'mckinley')\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"bbox_marrakech = [-8.1, 31.5, -7.9, 31.7]\n",
|
| 162 |
+
"fetch_gee_data(bbox_marrakech, '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech', 'marrakech')"
|
| 163 |
+
]
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"cell_type": "code",
|
| 167 |
+
"execution_count": null,
|
| 168 |
+
"id": "5f9b0934",
|
| 169 |
+
"metadata": {},
|
| 170 |
+
"outputs": [],
|
| 171 |
+
"source": [
|
| 172 |
+
"# Cell 6: Download High-Resolution DEM Tiles and Merge (Only for McKinley)\n",
|
| 173 |
+
"!pip install boto3 gdal retry\n",
|
| 174 |
+
"!apt install gdal-bin\n",
|
| 175 |
+
"import boto3\n",
|
| 176 |
+
"import os\n",
|
| 177 |
+
"import shutil\n",
|
| 178 |
+
"from botocore import UNSIGNED\n",
|
| 179 |
+
"from botocore.client import Config\n",
|
| 180 |
+
"from retry import retry\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"# S3 Configuration\n",
|
| 183 |
+
"endpoint_url = 'https://opentopography.s3.sdsc.edu'\n",
|
| 184 |
+
"client = boto3.client('s3', endpoint_url=endpoint_url, config=Config(signature_version=UNSIGNED))\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"# Temp and data dirs\n",
|
| 187 |
+
"temp_base = '/content/hr_temp'\n",
|
| 188 |
+
"data_base = '/content/drive/MyDrive/DEM_Project/Training_Data'\n",
|
| 189 |
+
"os.makedirs(temp_base, exist_ok=True)\n",
|
| 190 |
+
"os.makedirs(data_base, exist_ok=True)\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"# Only McKinley\n",
|
| 193 |
+
"datasets = {'mckinley': 'NM23_McKinley'}\n",
|
| 194 |
+
"folder_names = {'mckinley': 'McKinley'}\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"@retry(tries=3, delay=2, backoff=2)\n",
|
| 197 |
+
"def download_file_with_retry(bucket, key, filename):\n",
|
| 198 |
+
" client.download_file(Bucket=bucket, Key=key, Filename=filename)\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"!df -h /content\n",
|
| 201 |
+
"for local_name, s3_dir in datasets.items():\n",
|
| 202 |
+
" temp_dir = os.path.join(temp_base, local_name)\n",
|
| 203 |
+
" dataset_dir = os.path.join(data_base, folder_names[local_name])\n",
|
| 204 |
+
" os.makedirs(temp_dir, exist_ok=True)\n",
|
| 205 |
+
" os.makedirs(dataset_dir, exist_ok=True)\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" output_tif = os.path.join(dataset_dir, f'{local_name}_hr_dem.tif')\n",
|
| 208 |
+
" if os.path.exists(output_tif):\n",
|
| 209 |
+
" print(f'Merged DEM already exists for {local_name}: {output_tif}, skipping download and merge.')\n",
|
| 210 |
+
" continue\n",
|
| 211 |
+
"\n",
|
| 212 |
+
" paginator = client.get_paginator('list_objects_v2')\n",
|
| 213 |
+
" prefix = f'{s3_dir}/{s3_dir}_be/'\n",
|
| 214 |
+
" downloaded_files = []\n",
|
| 215 |
+
" try:\n",
|
| 216 |
+
" for page in paginator.paginate(Bucket='raster', Prefix=prefix):\n",
|
| 217 |
+
" for obj in page.get('Contents', []):\n",
|
| 218 |
+
" key = obj['Key']\n",
|
| 219 |
+
" if key.endswith('.tif'):\n",
|
| 220 |
+
" file_path = os.path.join(temp_dir, os.path.basename(key))\n",
|
| 221 |
+
" # Check if the tile file already exists\n",
|
| 222 |
+
" if os.path.exists(file_path):\n",
|
| 223 |
+
" print(f'Tile {os.path.basename(key)} already exists for {local_name}, skipping download.')\n",
|
| 224 |
+
" downloaded_files.append(file_path)\n",
|
| 225 |
+
" continue\n",
|
| 226 |
+
" try:\n",
|
| 227 |
+
" download_file_with_retry('raster', key, file_path)\n",
|
| 228 |
+
" downloaded_files.append(file_path)\n",
|
| 229 |
+
" print(f'Downloaded {os.path.basename(key)} for {local_name}')\n",
|
| 230 |
+
" except Exception as e:\n",
|
| 231 |
+
" print(f'Error downloading {key} for {local_name}: {e}')\n",
|
| 232 |
+
" except Exception as e:\n",
|
| 233 |
+
" print(f'Error listing tiles for {local_name}: {e}')\n",
|
| 234 |
+
"\n",
|
| 235 |
+
" if not downloaded_files:\n",
|
| 236 |
+
" print(f'No tiles downloaded for {local_name}; skipping merge.')\n",
|
| 237 |
+
" continue\n",
|
| 238 |
+
"\n",
|
| 239 |
+
" try:\n",
|
| 240 |
+
" # Using gdalbuildvrt and gdal_translate for better performance\n",
|
| 241 |
+
" !gdalbuildvrt merged.vrt {\" \".join(downloaded_files)}\n",
|
| 242 |
+
" !gdal_translate -of GTiff merged.vrt \"{output_tif}\" -co TILED=YES -co COMPRESS=DEFLATE -co NUM_THREADS=ALL_CPUS\n",
|
| 243 |
+
" print(f'Merged tiles to TIFF for {local_name}: {output_tif}')\n",
|
| 244 |
+
" except Exception as e:\n",
|
| 245 |
+
" print(f'Error merging tiles for {local_name}: {e}')\n",
|
| 246 |
+
" continue\n",
|
| 247 |
+
" # Downloaded tiles in temp_dir are intentionally kept (no shutil.rmtree) so a rerun can skip tiles that already exist\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"!df -h /content\n",
|
| 250 |
+
"print('Download complete for McKinley!')"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "code",
|
| 255 |
+
"execution_count": null,
|
| 256 |
+
"id": "7f82f16c",
|
| 257 |
+
"metadata": {},
|
| 258 |
+
"outputs": [],
|
| 259 |
+
"source": [
|
| 260 |
+
"# Cell 7: Data Preprocessing (Only for McKinley and Marrakech)\n",
|
| 261 |
+
"import rasterio\n",
|
| 262 |
+
"from rasterio.enums import Resampling\n",
|
| 263 |
+
"import numpy as np\n",
|
| 264 |
+
"from scipy.ndimage import gaussian_filter\n",
|
| 265 |
+
"import os\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"folder_names = {'mckinley': 'McKinley', 'marrakech': 'Marrakech'}\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"def preprocess_dataset(dataset_name, is_training=True, custom_base=None):\n",
|
| 270 |
+
" if custom_base:\n",
|
| 271 |
+
" base_dir = custom_base\n",
|
| 272 |
+
" else:\n",
|
| 273 |
+
" base_dir = f'/content/drive/MyDrive/DEM_Project/Training_Data/{folder_names[dataset_name]}'\n",
|
| 274 |
+
" \n",
|
| 275 |
+
" srtm_path = os.path.join(base_dir, f'{dataset_name}_srtm.tif')\n",
|
| 276 |
+
" s2_path = os.path.join(base_dir, f'{dataset_name}_sentinel.tif')\n",
|
| 277 |
+
" hr_path = os.path.join(base_dir, f'{dataset_name}_hr_dem.tif') if is_training else None\n",
|
| 278 |
+
" output_dir = base_dir\n",
|
| 279 |
+
"\n",
|
| 280 |
+
" with rasterio.open(srtm_path) as srtm_src, rasterio.open(s2_path) as s2_src:\n",
|
| 281 |
+
" target_shape = (s2_src.height, s2_src.width)\n",
|
| 282 |
+
" srtm = srtm_src.read(1, out_shape=target_shape, resampling=Resampling.cubic)\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" s2 = s2_src.read()\n",
|
| 285 |
+
" r, g, b, nir = s2\n",
|
| 286 |
+
"\n",
|
| 287 |
+
" ndvi = (nir - r) / (nir + r + 1e-10)\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" mask = np.where(srtm == srtm_src.nodata, 1, 0).astype(np.float32)\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" if is_training:\n",
|
| 292 |
+
" with rasterio.open(hr_path) as hr_src:\n",
|
| 293 |
+
" hr = hr_src.read(1, out_shape=target_shape, resampling=Resampling.cubic) if hr_src.shape != target_shape else hr_src.read(1)\n",
|
| 294 |
+
"\n",
|
| 295 |
+
" trend = gaussian_filter(hr, sigma=5)\n",
|
| 296 |
+
" residual = hr - trend\n",
|
| 297 |
+
"\n",
|
| 298 |
+
" target_profile = s2_src.profile\n",
|
| 299 |
+
" target_profile['count'] = 1\n",
|
| 300 |
+
" with rasterio.open(os.path.join(output_dir, 'target.tif'), 'w', **target_profile) as dst:\n",
|
| 301 |
+
" dst.write(residual, 1)\n",
|
| 302 |
+
"\n",
|
| 303 |
+
" input_profile = s2_src.profile\n",
|
| 304 |
+
" input_profile['count'] = 7\n",
|
| 305 |
+
" with rasterio.open(os.path.join(output_dir, 'input.tif'), 'w', **input_profile) as dst:\n",
|
| 306 |
+
" dst.write(srtm, 1)\n",
|
| 307 |
+
" dst.write(r, 2)\n",
|
| 308 |
+
" dst.write(g, 3)\n",
|
| 309 |
+
" dst.write(b, 4)\n",
|
| 310 |
+
" dst.write(nir, 5)\n",
|
| 311 |
+
" dst.write(ndvi, 6)\n",
|
| 312 |
+
" dst.write(mask, 7)\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"# Preprocess McKinley\n",
|
| 315 |
+
"preprocess_dataset('mckinley', is_training=True)\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"# Preprocess Marrakech (no HR)\n",
|
| 318 |
+
"preprocess_dataset('marrakech', is_training=False, custom_base='/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech')\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"!df -h /content/drive/MyDrive"
|
| 321 |
+
]
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
"cell_type": "code",
|
| 325 |
+
"execution_count": null,
|
| 326 |
+
"id": "7933d058",
|
| 327 |
+
"metadata": {},
|
| 328 |
+
"outputs": [],
|
| 329 |
+
"source": [
|
| 330 |
+
"# Cell 8: Custom Dataset Class\n",
|
| 331 |
+
"import albumentations as A\n",
|
| 332 |
+
"from albumentations.pytorch import ToTensorV2\n",
|
| 333 |
+
"from torch.utils.data import Dataset\n",
|
| 334 |
+
"import rasterio.windows\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"class CustomDEMDataset(Dataset):\n",
|
| 337 |
+
" def __init__(self, data_dirs, tile_size=256, transform=None):\n",
|
| 338 |
+
" self.pairs = []\n",
|
| 339 |
+
" for d_dir in data_dirs:\n",
|
| 340 |
+
" input_path = os.path.join(d_dir, 'input.tif')\n",
|
| 341 |
+
" target_path = os.path.join(d_dir, 'target.tif')\n",
|
| 342 |
+
" if os.path.exists(input_path) and os.path.exists(target_path):\n",
|
| 343 |
+
" self.pairs.append((input_path, target_path))\n",
|
| 344 |
+
" self.tile_size = tile_size\n",
|
| 345 |
+
" self.transform = transform or A.Compose([\n",
|
| 346 |
+
" A.RandomCrop(height=tile_size, width=tile_size),\n",
|
| 347 |
+
" A.RandomRotate90(),\n",
|
| 348 |
+
" A.HorizontalFlip(),\n",
|
| 349 |
+
" A.VerticalFlip(),\n",
|
| 350 |
+
" A.GaussNoise(var_limit=(0.01, 0.01)),\n",
|
| 351 |
+
" ToTensorV2()\n",
|
| 352 |
+
" ])\n",
|
| 353 |
+
"\n",
|
| 354 |
+
" def __len__(self):\n",
|
| 355 |
+
" return len(self.pairs) * 50\n",
|
| 356 |
+
"\n",
|
| 357 |
+
" def __getitem__(self, idx):\n",
|
| 358 |
+
" input_path, target_path = self.pairs[idx % len(self.pairs)]\n",
|
| 359 |
+
" with rasterio.open(input_path) as inp, rasterio.open(target_path) as tgt:\n",
|
| 360 |
+
" max_col = inp.width - self.tile_size\n",
|
| 361 |
+
" max_row = inp.height - self.tile_size\n",
|
| 362 |
+
" col_off = np.random.randint(0, max_col + 1)\n",
|
| 363 |
+
" row_off = np.random.randint(0, max_row + 1)\n",
|
| 364 |
+
" window = rasterio.windows.Window(col_off, row_off, self.tile_size, self.tile_size)\n",
|
| 365 |
+
" input_data = inp.read(window=window)\n",
|
| 366 |
+
" target_data = tgt.read(1, window=window)\n",
|
| 367 |
+
"\n",
|
| 368 |
+
" data = {'image': input_data.transpose(1,2,0).astype(np.float32), 'target': target_data.astype(np.float32)}\n",
|
| 369 |
+
" augmented = self.transform(image=data['image'], mask=data['target'])\n",
|
| 370 |
+
" return augmented['image'], augmented['mask'].unsqueeze(0)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"train_dirs = ['/content/drive/MyDrive/DEM_Project/Training_Data/McKinley']\n",
|
| 373 |
+
"dataset = CustomDEMDataset(train_dirs)"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"cell_type": "code",
|
| 378 |
+
"execution_count": null,
|
| 379 |
+
"id": "46cf2c6d",
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"outputs": [],
|
| 382 |
+
"source": [
|
| 383 |
+
"# Cell 9: Model and Training (Using Only McKinley)\n",
|
| 384 |
+
"import pytorch_lightning as pl\n",
|
| 385 |
+
"import segmentation_models_pytorch as smp\n",
|
| 386 |
+
"from torch.utils.data import DataLoader\n",
|
| 387 |
+
"import torch\n",
|
| 388 |
+
"import torch.nn as nn\n",
|
| 389 |
+
"\n",
|
| 390 |
+
"class DeepDEMRefinement(pl.LightningModule):\n",
|
| 391 |
+
" def __init__(self, lr=1e-4):\n",
|
| 392 |
+
" super().__init__()\n",
|
| 393 |
+
" self.model = smp.Unet(encoder_name='resnet34', in_channels=7, classes=1, activation=None)\n",
|
| 394 |
+
" self.loss_fn = nn.L1Loss()\n",
|
| 395 |
+
" self.lr = lr\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" def forward(self, x):\n",
|
| 398 |
+
" return self.model(x)\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" def training_step(self, batch, batch_idx):\n",
|
| 401 |
+
" inputs, targets = batch\n",
|
| 402 |
+
" preds = self(inputs)\n",
|
| 403 |
+
" loss = self.loss_fn(preds, targets)\n",
|
| 404 |
+
" self.log('train_loss', loss)\n",
|
| 405 |
+
" return loss\n",
|
| 406 |
+
"\n",
|
| 407 |
+
" def configure_optimizers(self):\n",
|
| 408 |
+
" return torch.optim.Adam(self.parameters(), lr=self.lr)\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"class DEMDataModule(pl.LightningDataModule):\n",
|
| 411 |
+
" def __init__(self, train_dirs, batch_size=4):\n",
|
| 412 |
+
" super().__init__()\n",
|
| 413 |
+
" self.train_dataset = CustomDEMDataset(train_dirs)\n",
|
| 414 |
+
" self.batch_size = batch_size\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" def train_dataloader(self):\n",
|
| 417 |
+
" return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"model = DeepDEMRefinement()\n",
|
| 420 |
+
"datamodule = DEMDataModule(train_dirs)\n",
|
| 421 |
+
"trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=1) # Training for 5 epochs\n",
|
| 422 |
+
"trainer.fit(model, datamodule)\n",
|
| 423 |
+
"\n",
|
| 424 |
+
"trainer.save_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "code",
|
| 429 |
+
"execution_count": null,
|
| 430 |
+
"id": "30872060",
|
| 431 |
+
"metadata": {},
|
| 432 |
+
"outputs": [],
|
| 433 |
+
"source": [
|
| 434 |
+
"# Cell 10: Inference for Marrakech\n",
|
| 435 |
+
"model = DeepDEMRefinement.load_from_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')\n",
|
| 436 |
+
"model.eval()\n",
|
| 437 |
+
"model.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 438 |
+
"\n",
|
| 439 |
+
"input_path = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
|
| 440 |
+
"with rasterio.open(input_path) as src:\n",
|
| 441 |
+
" input_data = src.read().astype(np.float32)\n",
|
| 442 |
+
" trend = gaussian_filter(input_data[0], sigma=5)\n",
|
| 443 |
+
"\n",
|
| 444 |
+
" input_tensor = torch.from_numpy(input_data).unsqueeze(0).to(model.device)\n",
|
| 445 |
+
"\n",
|
| 446 |
+
" with torch.no_grad():\n",
|
| 447 |
+
" residual_pred = model(input_tensor)\n",
|
| 448 |
+
"\n",
|
| 449 |
+
" synth_dem = residual_pred.squeeze().cpu().numpy() + trend\n",
|
| 450 |
+
"\n",
|
| 451 |
+
" profile = src.profile\n",
|
| 452 |
+
" profile['count'] = 1\n",
|
| 453 |
+
" with rasterio.open('/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif', 'w', **profile) as dst:\n",
|
| 454 |
+
" dst.write(synth_dem, 1)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"print('Synthetic DEM generated for Marrakech!')"
|
| 457 |
+
]
|
| 458 |
+
},
|
| 459 |
+
{
|
| 460 |
+
"cell_type": "markdown",
|
| 461 |
+
"id": "8943e470",
|
| 462 |
+
"metadata": {},
|
| 463 |
+
"source": [
|
| 464 |
+
"# Quick correctness checks\n",
|
| 465 |
+
"\n",
|
| 466 |
+
"This section runs a few sanity checks on the trained model and data:\n",
|
| 467 |
+
"\n",
|
| 468 |
+
"- Validate shapes, CRS, and basic channel statistics of `input.tif` and `target.tif`\n",
|
| 469 |
+
"- Compute masked MAE/RMSE on random training crops (McKinley) to gauge training fit\n",
|
| 470 |
+
"- Flag obvious issues (e.g., all-zeros bands, nodata dominance)\n",
|
| 471 |
+
"\n",
|
| 472 |
+
"Run cells in order after training has completed."
|
| 473 |
+
]
|
| 474 |
+
},
|
| 475 |
+
{
|
| 476 |
+
"cell_type": "code",
|
| 477 |
+
"execution_count": null,
|
| 478 |
+
"id": "3c35167d",
|
| 479 |
+
"metadata": {},
|
| 480 |
+
"outputs": [],
|
| 481 |
+
"source": [
|
| 482 |
+
"# Check 1: Inspect input/target rasters (McKinley)\n",
|
| 483 |
+
"import rasterio\n",
|
| 484 |
+
"import numpy as np\n",
|
| 485 |
+
"from pathlib import Path\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"train_dir = Path('/content/drive/MyDrive/DEM_Project/Training_Data/McKinley')\n",
|
| 488 |
+
"input_path = train_dir / 'input.tif'\n",
|
| 489 |
+
"target_path = train_dir / 'target.tif'\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"issues = []\n",
|
| 492 |
+
"\n",
|
| 493 |
+
"with rasterio.open(input_path) as src:\n",
|
| 494 |
+
" print('INPUT:')\n",
|
| 495 |
+
" print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
|
| 496 |
+
" data = src.read(out_dtype='float32')\n",
|
| 497 |
+
" nodata = src.nodata\n",
|
| 498 |
+
" band_stats = []\n",
|
| 499 |
+
" for i in range(src.count):\n",
|
| 500 |
+
" b = data[i]\n",
|
| 501 |
+
" if nodata is not None:\n",
|
| 502 |
+
" mask = b == nodata\n",
|
| 503 |
+
" valid = np.where(mask, np.nan, b)\n",
|
| 504 |
+
" else:\n",
|
| 505 |
+
" valid = b\n",
|
| 506 |
+
" mask = np.zeros_like(b, dtype=bool)\n",
|
| 507 |
+
" s = {\n",
|
| 508 |
+
" 'band': i+1,\n",
|
| 509 |
+
" 'nan_frac': float(np.mean(np.isnan(valid))),\n",
|
| 510 |
+
" 'nodata_frac': float(np.mean(mask)),\n",
|
| 511 |
+
" 'min': float(np.nanmin(valid)),\n",
|
| 512 |
+
" 'max': float(np.nanmax(valid)),\n",
|
| 513 |
+
" 'mean': float(np.nanmean(valid)),\n",
|
| 514 |
+
" 'std': float(np.nanstd(valid)),\n",
|
| 515 |
+
" }\n",
|
| 516 |
+
" band_stats.append(s)\n",
|
| 517 |
+
" print('Input band stats (1:dsm, 2:R, 3:G, 4:B, 5:NIR, 6:NDVI, 7:mask):')\n",
|
| 518 |
+
" for s in band_stats:\n",
|
| 519 |
+
" print(s)\n",
|
| 520 |
+
" # Basic checks\n",
|
| 521 |
+
" if band_stats[5]['min'] < -1.01 or band_stats[5]['max'] > 1.01:\n",
|
| 522 |
+
" issues.append('NDVI out of expected [-1,1] range; check scaling and bands (R,NIR indices).')\n",
|
| 523 |
+
" if band_stats[6]['mean'] < 0.01 and band_stats[6]['max'] < 0.5:\n",
|
| 524 |
+
" issues.append('Mask band appears mostly zeros; ensure mask=1 at nodata pixels, 0 elsewhere.')\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"with rasterio.open(target_path) as src:\n",
|
| 527 |
+
" print('\\nTARGET:')\n",
|
| 528 |
+
" print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
|
| 529 |
+
" t = src.read(1, out_dtype='float32')\n",
|
| 530 |
+
" print({'min': float(np.nanmin(t)), 'max': float(np.nanmax(t)), 'mean': float(np.nanmean(t)), 'std': float(np.nanstd(t))})\n",
|
| 531 |
+
" if np.allclose(t, 0):\n",
|
| 532 |
+
" issues.append('Target residual is all zeros; check HR DEM loading and detrending step.')\n",
|
| 533 |
+
"\n",
|
| 534 |
+
"print('\\nPotential issues:')\n",
|
| 535 |
+
"print(issues if issues else 'None detected')"
|
| 536 |
+
]
|
| 537 |
+
},
|
| 538 |
+
{
|
| 539 |
+
"cell_type": "code",
|
| 540 |
+
"execution_count": null,
|
| 541 |
+
"id": "ad78e542",
|
| 542 |
+
"metadata": {},
|
| 543 |
+
"outputs": [],
|
| 544 |
+
"source": [
|
| 545 |
+
"# Check 2: Compute quick masked MAE/RMSE on random training crops\n",
|
| 546 |
+
"import torch\n",
|
| 547 |
+
"import torch.nn.functional as F\n",
|
| 548 |
+
"from torch.utils.data import DataLoader\n",
|
| 549 |
+
"import numpy as np\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"# Reuse dataset and model classes already defined earlier\n",
|
| 552 |
+
"try:\n",
|
| 553 |
+
" _ = CustomDEMDataset\n",
|
| 554 |
+
"except NameError:\n",
|
| 555 |
+
" raise RuntimeError('CustomDEMDataset not defined; run earlier cells first.')\n",
|
| 556 |
+
"\n",
|
| 557 |
+
"try:\n",
|
| 558 |
+
" _ = DeepDEMRefinement\n",
|
| 559 |
+
"except NameError:\n",
|
| 560 |
+
" raise RuntimeError('DeepDEMRefinement not defined; run training cells first.')\n",
|
| 561 |
+
"\n",
|
| 562 |
+
"# Load model\n",
|
| 563 |
+
"ckpt = '/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt'\n",
|
| 564 |
+
"model = DeepDEMRefinement.load_from_checkpoint(ckpt)\n",
|
| 565 |
+
"model.eval()\n",
|
| 566 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 567 |
+
"model.to(device)\n",
|
| 568 |
+
"\n",
|
| 569 |
+
"# Small eval dataset with deterministic crops\n",
|
| 570 |
+
"np.random.seed(42)\n",
|
| 571 |
+
"transform = A.Compose([\n",
|
| 572 |
+
" A.RandomCrop(height=256, width=256),\n",
|
| 573 |
+
" ToTensorV2()\n",
|
| 574 |
+
"])\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"val_ds = CustomDEMDataset([str(train_dir)], tile_size=256, transform=transform)\n",
|
| 577 |
+
"val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)\n",
|
| 578 |
+
"\n",
|
| 579 |
+
"maes, rmses = [], []\n",
|
| 580 |
+
"with torch.no_grad():\n",
|
| 581 |
+
" for i, (x, y) in enumerate(val_loader):\n",
|
| 582 |
+
" if i >= 10: # ~40 tiles\n",
|
| 583 |
+
" break\n",
|
| 584 |
+
" x = x.to(device)\n",
|
| 585 |
+
" y = y.to(device)\n",
|
| 586 |
+
" pred = model(x)\n",
|
| 587 |
+
" # Exclude pixels flagged by the nodata mask (channel 7) from the error metrics\n",
|
| 588 |
+
" mask = x[:, 6:7] # channel 7\n",
|
| 589 |
+
" valid = (mask < 0.5).float()\n",
|
| 590 |
+
" diff = (pred - y) * valid\n",
|
| 591 |
+
" denom = valid.sum().clamp_min(1.0)\n",
|
| 592 |
+
" mae = diff.abs().sum() / denom\n",
|
| 593 |
+
" rmse = torch.sqrt((diff.pow(2).sum() / denom))\n",
|
| 594 |
+
" maes.append(mae.item())\n",
|
| 595 |
+
" rmses.append(rmse.item())\n",
|
| 596 |
+
"\n",
|
| 597 |
+
"print({'MAE_mean': float(np.mean(maes)), 'MAE_std': float(np.std(maes)), 'RMSE_mean': float(np.mean(rmses)), 'RMSE_std': float(np.std(rmses)), 'tiles': len(maes)*val_loader.batch_size})\n",
|
| 598 |
+
"\n",
|
| 599 |
+
"if np.mean(rmses) > 8.0:\n",
|
| 600 |
+
" print('Warning: High RMSE for residuals. Training may be underfit or target scaling may be off.')\n",
|
| 601 |
+
"else:\n",
|
| 602 |
+
" print('Residual error looks reasonable for the training run.')"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"cell_type": "code",
|
| 607 |
+
"execution_count": null,
|
| 608 |
+
"id": "def70a98",
|
| 609 |
+
"metadata": {},
|
| 610 |
+
"outputs": [],
|
| 611 |
+
"source": [
|
| 612 |
+
"# Check 3: Sanity-check inference output alignment vs input for Marrakech\n",
|
| 613 |
+
"from scipy.ndimage import gaussian_filter\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"marrakech_input = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
|
| 616 |
+
"marrakech_out = '/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif'\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"with rasterio.open(marrakech_input) as src_in, rasterio.open(marrakech_out) as src_out:\n",
|
| 619 |
+
" print('INFERENCE INPUT:', {'shape': (src_in.count, src_in.height, src_in.width), 'crs': str(src_in.crs), 'transform': tuple(src_in.transform)})\n",
|
| 620 |
+
" print('SYNTH OUTPUT:', {'shape': (src_out.count, src_out.height, src_out.width), 'crs': str(src_out.crs), 'transform': tuple(src_out.transform)})\n",
|
| 621 |
+
" if src_in.crs != src_out.crs:\n",
|
| 622 |
+
" print('Warning: CRS mismatch between input and output!')\n",
|
| 623 |
+
" if (src_in.height != src_out.height) or (src_in.width != src_out.width):\n",
|
| 624 |
+
" print('Warning: Dimension mismatch between input and output!')\n",
|
| 625 |
+
"\n",
|
| 626 |
+
" out_dem = src_out.read(1).astype('float32')\n",
|
| 627 |
+
" # Simple terrain sanity: residual-added trend should correlate with SRTM trend\n",
|
| 628 |
+
" srtm = src_in.read(1).astype('float32')\n",
|
| 629 |
+
" trend = gaussian_filter(srtm, sigma=5)\n",
|
| 630 |
+
" corr = np.corrcoef(trend.flatten(), out_dem.flatten())[0,1]\n",
|
| 631 |
+
" print('Correlation between SRTM trend and synthetic DEM:', float(corr))\n",
|
| 632 |
+
" if corr < 0.5:\n",
|
| 633 |
+
" print('Low correlation; output may be noisy or misaligned.')"
|
| 634 |
+
]
|
| 635 |
+
}
|
| 636 |
+
],
|
| 637 |
+
"metadata": {
|
| 638 |
+
"language_info": {
|
| 639 |
+
"name": "python"
|
| 640 |
+
}
|
| 641 |
+
},
|
| 642 |
+
"nbformat": 4,
|
| 643 |
+
"nbformat_minor": 5
|
| 644 |
+
}
|