blumenstiel commited on
Commit
82c6c69
·
1 Parent(s): 88c4554

Fix inference code

Browse files

Signed-off-by: Benedikt Blumenstiel <[email protected]>

Files changed (2) hide show
  1. config.yaml +1 -30
  2. inference.py +2 -2
config.yaml CHANGED
@@ -6,18 +6,7 @@ trainer:
6
  devices: auto
7
  num_nodes: 1
8
  precision: 16-mixed
9
- logger:
10
- class_path: lightning.pytorch.loggers.TensorBoardLogger
11
- init_args:
12
- save_dir: /dccstor/geofm-finetuning/benchmark-geo-bench-paolo/
13
- name: test2
14
- log_graph: false
15
- default_hp_metric: true
16
- prefix: ''
17
- comment: ''
18
- max_queue: 10
19
- flush_secs: 120
20
- filename_suffix: ''
21
  callbacks:
22
  - class_path: lightning.pytorch.callbacks.RichProgressBar
23
  init_args:
@@ -171,24 +160,6 @@ data:
171
  use_metadata: false
172
  out_dtype: int16
173
  deploy_config_file: true
174
- ModelCheckpoint:
175
- filename: '{epoch}'
176
- monitor: val/loss
177
- verbose: false
178
- save_top_k: 1
179
- mode: min
180
- save_weights_only: false
181
- auto_insert_metric_name: true
182
- enable_version_counter: true
183
- StateDictModelCheckpoint:
184
- filename: '{epoch}_state_dict'
185
- monitor: val/loss
186
- verbose: false
187
- save_top_k: 1
188
- mode: min
189
- save_weights_only: true
190
- auto_insert_metric_name: true
191
- enable_version_counter: true
192
  optimizer:
193
  class_path: torch.optim.AdamW
194
  init_args:
 
6
  devices: auto
7
  num_nodes: 1
8
  precision: 16-mixed
9
+ logger: true
 
 
 
 
 
 
 
 
 
 
 
10
  callbacks:
11
  - class_path: lightning.pytorch.callbacks.RichProgressBar
12
  init_args:
 
160
  use_metadata: false
161
  out_dtype: int16
162
  deploy_config_file: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  optimizer:
164
  class_path: torch.optim.AdamW
165
  init_args:
inference.py CHANGED
@@ -185,7 +185,7 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i
185
  for x in windows:
186
  # Apply standardization
187
  x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
188
- x = datamodule.aug(x)['image']
189
 
190
  with torch.no_grad():
191
  x = x.to(model.device)
@@ -317,7 +317,7 @@ if __name__ == "__main__":
317
  parser.add_argument(
318
  "--checkpoint",
319
  type=str,
320
- default="Prithvi-EO-V2-300M-TL-Sen1Floods11.ckpt",
321
  help="Path to a checkpoint file to load from.",
322
  )
323
  parser.add_argument(
 
185
  for x in windows:
186
  # Apply standardization
187
  x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
188
+ x = datamodule.aug(x['image'])
189
 
190
  with torch.no_grad():
191
  x = x.to(model.device)
 
317
  parser.add_argument(
318
  "--checkpoint",
319
  type=str,
320
+ default="Prithvi-EO-V2-300M-TL-Sen1Floods11.pt",
321
  help="Path to a checkpoint file to load from.",
322
  )
323
  parser.add_argument(