Spaces:
Starting
Starting
Commit
Β·
7829054
1
Parent(s):
46818a5
update readme
Browse files- README.md +3 -1
- lib/utils.py +35 -19
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: π
|
4 |
colorFrom: indigo
|
5 |
colorTo: gray
|
@@ -14,6 +14,8 @@ models:
|
|
14 |
- links-ads/gaia-growseg
|
15 |
datasets:
|
16 |
- links-ads/gaia-vineyard-uav-dataset
|
|
|
|
|
17 |
tags:
|
18 |
- agriculture
|
19 |
- viticulture
|
|
|
1 |
---
|
2 |
+
title: Vineyard Row Segmentation
|
3 |
emoji: π
|
4 |
colorFrom: indigo
|
5 |
colorTo: gray
|
|
|
14 |
- links-ads/gaia-growseg
|
15 |
datasets:
|
16 |
- links-ads/gaia-vineyard-uav-dataset
|
17 |
+
preload_from_hub:
|
18 |
+
- links-ads/gaia-growseg model.safetensors
|
19 |
tags:
|
20 |
- agriculture
|
21 |
- viticulture
|
lib/utils.py
CHANGED
@@ -80,36 +80,52 @@ def pad(img, pad, order='CHW'):
|
|
80 |
return padded_img
|
81 |
|
82 |
|
83 |
-
def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
|
84 |
"""Extract patches from an image, in the format (h_start, h_end, w_start, w_end)"""
|
85 |
assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
if order == 'HWC':
|
88 |
H, W = img.shape[:2]
|
89 |
else:
|
90 |
H, W = img.shape[1:]
|
91 |
|
92 |
-
#
|
93 |
-
|
|
|
94 |
|
95 |
-
#
|
96 |
-
patches = []
|
97 |
patches_idx = []
|
98 |
-
for i in range(
|
99 |
-
for j in range(
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
110 |
if only_return_idx:
|
111 |
return patches_idx
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
|
115 |
def segment_batch(batch, model):
|
|
|
80 |
return padded_img
|
81 |
|
82 |
|
83 |
+
def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True, include_last=True):
|
84 |
"""Extract patches from an image, in the format (h_start, h_end, w_start, w_end)"""
|
85 |
assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
|
86 |
+
assert len(img.shape) == 3, f"Got image with {len(img.shape)} dimensions, expected 3 dimensions (C,H,W) or (H,W,C)"
|
87 |
+
assert img.shape[0] >= patch_size, f"Got image with height {img.shape[0]}, expected at least {patch_size}. Maybe apply padding first?"
|
88 |
+
assert img.shape[1] >= patch_size, f"Got image with width {img.shape[1]}, expected at least {patch_size}. Maybe apply padding first?"
|
89 |
+
|
90 |
+
# Get image height and width
|
91 |
if order == 'HWC':
|
92 |
H, W = img.shape[:2]
|
93 |
else:
|
94 |
H, W = img.shape[1:]
|
95 |
|
96 |
+
# Compute the number of "proper" patches in each dimension
|
97 |
+
n_patches_H = (H - patch_size) // stride + 1
|
98 |
+
n_patches_W = (W - patch_size) // stride + 1
|
99 |
|
100 |
+
# Extract patches indices
|
|
|
101 |
patches_idx = []
|
102 |
+
for i in range(n_patches_H): # iterate over height
|
103 |
+
for j in range(n_patches_W): # iterate over width
|
104 |
+
|
105 |
+
# Get the current patch indices
|
106 |
+
patches_idx.append((i*stride, i*stride+patch_size, j*stride, j*stride+patch_size)) # (top, bottom, left, right)
|
107 |
+
|
108 |
+
# Include leftmost and lowermost patch if needed
|
109 |
+
if include_last:
|
110 |
+
if j == n_patches_W-1 and j*stride+patch_size < W:
|
111 |
+
patches_idx.append((i*stride, i*stride+patch_size, W-patch_size, W))
|
112 |
+
if i == n_patches_H-1 and i*stride+patch_size < H:
|
113 |
+
patches_idx.append((H-patch_size, H, j*stride, j*stride+patch_size))
|
114 |
+
if i == n_patches_H-1 and j == n_patches_W-1 and i*stride+patch_size < H and j*stride+patch_size < W:
|
115 |
+
patches_idx.append((H-patch_size, H, W-patch_size, W))
|
116 |
+
|
117 |
if only_return_idx:
|
118 |
return patches_idx
|
119 |
+
else:
|
120 |
+
# Extract patches
|
121 |
+
patches = []
|
122 |
+
for t,b,l,r in patches_idx:
|
123 |
+
if order == 'HWC':
|
124 |
+
patch = img[t:b, l:r, :]
|
125 |
+
else:
|
126 |
+
patch = img[:, t:b, l:r]
|
127 |
+
patches.append(patch)
|
128 |
+
return patches, patches_idx
|
129 |
|
130 |
|
131 |
def segment_batch(batch, model):
|