Commit
·
82c6c69
1
Parent(s):
88c4554
Fix inference code
Browse filesSigned-off-by: Benedikt Blumenstiel <[email protected]>
- config.yaml +1 -30
- inference.py +2 -2
config.yaml
CHANGED
@@ -6,18 +6,7 @@ trainer:
|
|
6 |
devices: auto
|
7 |
num_nodes: 1
|
8 |
precision: 16-mixed
|
9 |
-
logger:
|
10 |
-
class_path: lightning.pytorch.loggers.TensorBoardLogger
|
11 |
-
init_args:
|
12 |
-
save_dir: /dccstor/geofm-finetuning/benchmark-geo-bench-paolo/
|
13 |
-
name: test2
|
14 |
-
log_graph: false
|
15 |
-
default_hp_metric: true
|
16 |
-
prefix: ''
|
17 |
-
comment: ''
|
18 |
-
max_queue: 10
|
19 |
-
flush_secs: 120
|
20 |
-
filename_suffix: ''
|
21 |
callbacks:
|
22 |
- class_path: lightning.pytorch.callbacks.RichProgressBar
|
23 |
init_args:
|
@@ -171,24 +160,6 @@ data:
|
|
171 |
use_metadata: false
|
172 |
out_dtype: int16
|
173 |
deploy_config_file: true
|
174 |
-
ModelCheckpoint:
|
175 |
-
filename: '{epoch}'
|
176 |
-
monitor: val/loss
|
177 |
-
verbose: false
|
178 |
-
save_top_k: 1
|
179 |
-
mode: min
|
180 |
-
save_weights_only: false
|
181 |
-
auto_insert_metric_name: true
|
182 |
-
enable_version_counter: true
|
183 |
-
StateDictModelCheckpoint:
|
184 |
-
filename: '{epoch}_state_dict'
|
185 |
-
monitor: val/loss
|
186 |
-
verbose: false
|
187 |
-
save_top_k: 1
|
188 |
-
mode: min
|
189 |
-
save_weights_only: true
|
190 |
-
auto_insert_metric_name: true
|
191 |
-
enable_version_counter: true
|
192 |
optimizer:
|
193 |
class_path: torch.optim.AdamW
|
194 |
init_args:
|
|
|
6 |
devices: auto
|
7 |
num_nodes: 1
|
8 |
precision: 16-mixed
|
9 |
+
logger: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
callbacks:
|
11 |
- class_path: lightning.pytorch.callbacks.RichProgressBar
|
12 |
init_args:
|
|
|
160 |
use_metadata: false
|
161 |
out_dtype: int16
|
162 |
deploy_config_file: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
optimizer:
|
164 |
class_path: torch.optim.AdamW
|
165 |
init_args:
|
inference.py
CHANGED
@@ -185,7 +185,7 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i
|
|
185 |
for x in windows:
|
186 |
# Apply standardization
|
187 |
x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
|
188 |
-
x = datamodule.aug(x
|
189 |
|
190 |
with torch.no_grad():
|
191 |
x = x.to(model.device)
|
@@ -317,7 +317,7 @@ if __name__ == "__main__":
|
|
317 |
parser.add_argument(
|
318 |
"--checkpoint",
|
319 |
type=str,
|
320 |
-
default="Prithvi-EO-V2-300M-TL-Sen1Floods11.
|
321 |
help="Path to a checkpoint file to load from.",
|
322 |
)
|
323 |
parser.add_argument(
|
|
|
185 |
for x in windows:
|
186 |
# Apply standardization
|
187 |
x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
|
188 |
+
x = datamodule.aug(x['image'])
|
189 |
|
190 |
with torch.no_grad():
|
191 |
x = x.to(model.device)
|
|
|
317 |
parser.add_argument(
|
318 |
"--checkpoint",
|
319 |
type=str,
|
320 |
+
default="Prithvi-EO-V2-300M-TL-Sen1Floods11.pt",
|
321 |
help="Path to a checkpoint file to load from.",
|
322 |
)
|
323 |
parser.add_argument(
|