tommonopoli commited on
Commit
7829054
Β·
1 Parent(s): 46818a5

update readme

Browse files
Files changed (2) hide show
  1. README.md +3 -1
  2. lib/utils.py +35 -19
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: GRowSeg demo
3
  emoji: πŸ‡
4
  colorFrom: indigo
5
  colorTo: gray
@@ -14,6 +14,8 @@ models:
14
  - links-ads/gaia-growseg
15
  datasets:
16
  - links-ads/gaia-vineyard-uav-dataset
 
 
17
  tags:
18
  - agriculture
19
  - viticulture
 
1
  ---
2
+ title: Vineyard Row Segmentation
3
  emoji: πŸ‡
4
  colorFrom: indigo
5
  colorTo: gray
 
14
  - links-ads/gaia-growseg
15
  datasets:
16
  - links-ads/gaia-vineyard-uav-dataset
17
+ preload_from_hub:
18
+ - links-ads/gaia-growseg model.safetensors
19
  tags:
20
  - agriculture
21
  - viticulture
lib/utils.py CHANGED
@@ -80,36 +80,52 @@ def pad(img, pad, order='CHW'):
80
  return padded_img
81
 
82
 
83
- def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
84
  """Extract patches from an image, in the format (h_start, h_end, w_start, w_end)"""
85
  assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
86
-
 
 
 
 
87
  if order == 'HWC':
88
  H, W = img.shape[:2]
89
  else:
90
  H, W = img.shape[1:]
91
 
92
- # compute the number of patches
93
- n_patches = ((H - patch_size) // stride + 1) * ((W - patch_size) // stride + 1)
 
94
 
95
- # extract patches
96
- patches = []
97
  patches_idx = []
98
- for i in range(0, H-patch_size+1, stride):
99
- for j in range(0, W-patch_size+1, stride):
100
-
101
- patches_idx.append((i, i+patch_size, j, j+patch_size))
102
-
103
- if not only_return_idx:
104
- if order == 'HWC':
105
- patch = img[i:i+patch_size, j:j+patch_size, :]
106
- else:
107
- patch = img[:, i:i+patch_size, j:j+patch_size]
108
- patches.append(patch)
109
-
 
 
 
110
  if only_return_idx:
111
  return patches_idx
112
- return patches, patches_idx
 
 
 
 
 
 
 
 
 
113
 
114
 
115
  def segment_batch(batch, model):
 
80
  return padded_img
81
 
82
 
83
+ def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True, include_last=True):
84
  """Extract patches from an image, in the format (h_start, h_end, w_start, w_end)"""
85
  assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
86
+ assert len(img.shape) == 3, f"Got image with {len(img.shape)} dimensions, expected 3 dimensions (C,H,W) or (H,W,C)"
87
+ assert img.shape[0] >= patch_size, f"Got image with height {img.shape[0]}, expected at least {patch_size}. Maybe apply padding first?"
88
+ assert img.shape[1] >= patch_size, f"Got image with width {img.shape[1]}, expected at least {patch_size}. Maybe apply padding first?"
89
+
90
+ # Get image height and width
91
  if order == 'HWC':
92
  H, W = img.shape[:2]
93
  else:
94
  H, W = img.shape[1:]
95
 
96
+ # Compute the number of "proper" patches in each dimension
97
+ n_patches_H = (H - patch_size) // stride + 1
98
+ n_patches_W = (W - patch_size) // stride + 1
99
 
100
+ # Extract patches indices
 
101
  patches_idx = []
102
+ for i in range(n_patches_H): # iterate over height
103
+ for j in range(n_patches_W): # iterate over width
104
+
105
+ # Get the current patch indices
106
+ patches_idx.append((i*stride, i*stride+patch_size, j*stride, j*stride+patch_size)) # (top, bottom, left, right)
107
+
108
+ # Include leftmost and lowermost patch if needed
109
+ if include_last:
110
+ if j == n_patches_W-1 and j*stride+patch_size < W:
111
+ patches_idx.append((i*stride, i*stride+patch_size, W-patch_size, W))
112
+ if i == n_patches_H-1 and i*stride+patch_size < H:
113
+ patches_idx.append((H-patch_size, H, j*stride, j*stride+patch_size))
114
+ if i == n_patches_H-1 and j == n_patches_W-1 and i*stride+patch_size < H and j*stride+patch_size < W:
115
+ patches_idx.append((H-patch_size, H, W-patch_size, W))
116
+
117
  if only_return_idx:
118
  return patches_idx
119
+ else:
120
+ # Extract patches
121
+ patches = []
122
+ for t,b,l,r in patches_idx:
123
+ if order == 'HWC':
124
+ patch = img[t:b, l:r, :]
125
+ else:
126
+ patch = img[:, t:b, l:r]
127
+ patches.append(patch)
128
+ return patches, patches_idx
129
 
130
 
131
  def segment_batch(batch, model):