import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2


class CroplandHandler:
    """Runs ONNX inference for cropland segmentation and optional polygonization."""

    def __init__(self, input_dir: str, output_dir: str, device: str = "cpu") -> None:
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        assert self.input_dir.exists(), "Input directory doesn't exist"
        assert self.output_dir.exists(), "Output directory doesn't exist"
        assert device == "cpu" or device.startswith("cuda"), f"{device} is not a valid device."
        model_path = "model_repository/utae.onnx"
        provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"
        self.session = ort.InferenceSession(str(model_path), providers=[provider])
        # Per-patch acquisition months, keyed by the Zarr file stem.
        with open("months_per_patch.json") as f:
            self.dates = json.load(f)

    def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:
        assert file is not None, "Missing input file for inference"
        file_path = self.input_dir / file
        data = read_zarr(file_path)
        data = impute_nan(data)
        data = normalize_s2(data)
        profile = read_zarr_profile(file_path)
        dates = self.dates[file_path.stem]
        # Add a leading batch dimension to both the image stack and the dates.
        batch = np.expand_dims(data, axis=0)
        dates = np.expand_dims(np.array(dates), axis=0)
        return batch, profile, dates

    def postprocess(self, outputs: Any, file: str, profile: Dict, save_raster: bool = False) -> np.ndarray:
        outputs = np.array(outputs)
        if save_raster:
            # Collapse class scores to a label map, binarize it, and label
            # connected components before polygonizing.
            out_class = np.argmax(outputs[0][0], axis=0)
            out_bin = (out_class != 0).astype(np.uint8)
            components = measure.label(out_bin, connectivity=1)
            gdf = polygonize_raster(
                out_class,
                components,
                tolerance=0.0001,
                transform=profile["transform"],
                crs=profile["crs"],
            )
            data_path = self.input_dir / file
            save_path = self.output_dir / (data_path.stem + ".parquet")
            gdf.to_parquet(save_path)
        return outputs

    def predict(self, files: List[str], save_raster: bool = False) -> List[np.ndarray]:
        # preprocess() and postprocess() operate on a single file, so iterate
        # over the list rather than passing it through whole.
        results = []
        for file in files:
            # Preprocessing
            batch, profile, dates = self.preprocess(file)
            # Inference
            outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
            # Postprocessing
            outputs = self.postprocess(outputs, file, profile, save_raster)
            results.append(outputs)
        return results
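

# A minimal usage sketch, assuming "model_repository/utae.onnx" and
# "months_per_patch.json" sit in the working directory. The directory names
# and the file name "patch_0001.zarr" are illustrative, not from the repo.
if __name__ == "__main__":
    handler = CroplandHandler(input_dir="data/in", output_dir="data/out", device="cpu")
    preds = handler.predict(["patch_0001.zarr"], save_raster=True)
    print(f"Ran inference on {len(preds)} file(s)")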