ZhengPeng7 commited on
Commit
ce634b2
·
1 Parent(s): e391bd4

Acceleration on the refine_foreground.

Browse files
Files changed (2) hide show
  1. app.py +90 -25
  2. app_local.py +90 -25
app.py CHANGED
@@ -9,6 +9,7 @@ from glob import glob
9
  from typing import Tuple
10
 
11
  from PIL import Image
 
12
  from torchvision import transforms
13
 
14
  import requests
@@ -27,39 +28,103 @@ torch.jit.script = lambda f: f
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
- ### image_proc.py
31
- def refine_foreground(image, mask, r=90):
32
- if mask.size != image.size:
33
- mask = mask.resize(image.size)
34
- image = np.array(image) / 255.0
35
- mask = np.array(mask) / 255.0
36
- estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
37
- image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
38
- return image_masked
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
 
42
  # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
43
  alpha = alpha[:, :, None]
44
- F, blur_B = FB_blur_fusion_foreground_estimator(
45
- image, image, image, alpha, r)
46
- return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
48
 
49
- def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
50
- if isinstance(image, Image.Image):
51
- image = np.array(image) / 255.0
52
- blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
53
 
54
- blurred_FA = cv2.blur(F * alpha, (r, r))
55
- blurred_F = blurred_FA / (blurred_alpha + 1e-5)
56
 
57
- blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
 
 
 
 
 
 
 
 
 
 
 
 
58
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
59
- F = blurred_F + alpha * \
60
- (image - alpha * blurred_F - (1 - alpha) * blurred_B)
61
- F = np.clip(F, 0, 1)
62
- return F, blurred_B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  class ImagePreprocessor():
@@ -167,7 +232,7 @@ def predict(images, resolution, weights_file):
167
 
168
  # Show Results
169
  pred_pil = transforms.ToPILImage()(pred)
170
- image_masked = refine_foreground(image, pred_pil)
171
  image_masked.putalpha(pred_pil.resize(image.size))
172
 
173
  torch.cuda.empty_cache()
 
9
  from typing import Tuple
10
 
11
  from PIL import Image
12
+ import torch
13
  from torchvision import transforms
14
 
15
  import requests
 
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
 
 
 
 
 
 
 
 
31
 
32
+ ## CPU version refinement
33
+ def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
34
+ if isinstance(image, Image.Image):
35
+ image = np.array(image) / 255.0
36
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
37
+
38
+ blurred_FGA = cv2.blur(FG * alpha, (r, r))
39
+ blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)
40
+
41
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
42
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
43
+ FG = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
44
+ FG = np.clip(FG, 0, 1)
45
+ return FG, blurred_B
46
 
47
+
48
+ def FB_blur_fusion_foreground_estimator_cpu_2(image, alpha, r=90):
49
  # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
50
  alpha = alpha[:, :, None]
51
+ FG, blur_B = FB_blur_fusion_foreground_estimator_cpu(image, image, image, alpha, r)
52
+ return FB_blur_fusion_foreground_estimator_cpu(image, FG, blur_B, alpha, r=6)[0]
53
+
54
+
55
+ ## GPU version refinement
56
+ def mean_blur(x, kernel_size):
57
+ """
58
+ equivalent to cv.blur
59
+ x: [B, C, H, W]
60
+ """
61
+ if kernel_size % 2 == 0:
62
+ pad_l = kernel_size // 2 - 1
63
+ pad_r = kernel_size // 2
64
+ pad_t = kernel_size // 2 - 1
65
+ pad_b = kernel_size // 2
66
+ else:
67
+ pad_l = pad_r = pad_t = pad_b = kernel_size // 2
68
 
69
+ x_padded = torch.nn.functional.pad(x, (pad_l, pad_r, pad_t, pad_b), mode='replicate')
70
 
71
+ return torch.nn.functional.avg_pool2d(x_padded, kernel_size=(kernel_size, kernel_size), stride=1, count_include_pad=False)
 
 
 
72
 
73
+ def FB_blur_fusion_foreground_estimator_gpu(image, FG, B, alpha, r=90):
74
+ as_dtype = lambda x, dtype: x.to(dtype) if x.dtype != dtype else x
75
 
76
+ input_dtype = image.dtype
77
+ # convert image to float to avoid overflow
78
+ image = as_dtype(image, torch.float32)
79
+ FG = as_dtype(FG, torch.float32)
80
+ B = as_dtype(B, torch.float32)
81
+ alpha = as_dtype(alpha, torch.float32)
82
+
83
+ blurred_alpha = mean_blur(alpha, kernel_size=r)
84
+
85
+ blurred_FGA = mean_blur(FG * alpha, kernel_size=r)
86
+ blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)
87
+
88
+ blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
89
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
90
+
91
+ FG_output = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
92
+ FG_output = torch.clamp(FG_output, 0, 1)
93
+
94
+ return as_dtype(FG_output, input_dtype), as_dtype(blurred_B, input_dtype)
95
+
96
+
97
+ def FB_blur_fusion_foreground_estimator_gpu_2(image, alpha, r=90):
98
+ # Thanks to the source: https://github.com/ZhengPeng7/BiRefNet/issues/226#issuecomment-3016433728
99
+ FG, blur_B = FB_blur_fusion_foreground_estimator_gpu(image, image, image, alpha, r)
100
+ return FB_blur_fusion_foreground_estimator_gpu(image, FG, blur_B, alpha, r=6)[0]
101
+
102
+
103
+ def refine_foreground(image, mask, r=90, device='cuda'):
104
+ """both image and mask are in range of [0, 1]"""
105
+ if mask.size != image.size:
106
+ mask = mask.resize(image.size)
107
+
108
+ if device == 'cuda':
109
+ image = transforms.functional.to_tensor(image).float().cuda()
110
+ mask = transforms.functional.to_tensor(mask).float().cuda()
111
+ image = image.unsqueeze(0)
112
+ mask = mask.unsqueeze(0)
113
+
114
+ estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
115
+
116
+ estimated_foreground = estimated_foreground.squeeze()
117
+ estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
118
+ estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
119
+ else:
120
+ image = np.array(image, dtype=np.float32) / 255.0
121
+ mask = np.array(mask, dtype=np.float32) / 255.0
122
+ estimated_foreground = FB_blur_fusion_foreground_estimator_cpu_2(image, mask, r=r)
123
+ estimated_foreground = (estimated_foreground * 255.0).astype(np.uint8)
124
+
125
+ estimated_foreground = Image.fromarray(np.ascontiguousarray(estimated_foreground))
126
+
127
+ return estimated_foreground
128
 
129
 
130
  class ImagePreprocessor():
 
232
 
233
  # Show Results
234
  pred_pil = transforms.ToPILImage()(pred)
235
+ image_masked = refine_foreground(image, pred_pil, device=device)
236
  image_masked.putalpha(pred_pil.resize(image.size))
237
 
238
  torch.cuda.empty_cache()
app_local.py CHANGED
@@ -11,6 +11,7 @@ from typing import Tuple
11
  from PIL import Image
12
  # from gradio_imageslider import ImageSlider
13
  import transformers
 
14
  from torchvision import transforms
15
 
16
  import requests
@@ -23,39 +24,103 @@ torch.set_float32_matmul_precision('high')
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
- ### image_proc.py
27
- def refine_foreground(image, mask, r=90):
28
- if mask.size != image.size:
29
- mask = mask.resize(image.size)
30
- image = np.array(image) / 255.0
31
- mask = np.array(mask) / 255.0
32
- estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
33
- image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
34
- return image_masked
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
 
38
  # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
39
  alpha = alpha[:, :, None]
40
- F, blur_B = FB_blur_fusion_foreground_estimator(
41
- image, image, image, alpha, r)
42
- return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
44
 
45
- def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
46
- if isinstance(image, Image.Image):
47
- image = np.array(image) / 255.0
48
- blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
49
 
50
- blurred_FA = cv2.blur(F * alpha, (r, r))
51
- blurred_F = blurred_FA / (blurred_alpha + 1e-5)
52
 
53
- blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
 
 
 
 
 
 
 
 
 
 
 
 
54
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
55
- F = blurred_F + alpha * \
56
- (image - alpha * blurred_F - (1 - alpha) * blurred_B)
57
- F = np.clip(F, 0, 1)
58
- return F, blurred_B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  class ImagePreprocessor():
@@ -163,7 +228,7 @@ def predict(images, resolution, weights_file):
163
 
164
  # Show Results
165
  pred_pil = transforms.ToPILImage()(pred)
166
- image_masked = refine_foreground(image, pred_pil)
167
  image_masked.putalpha(pred_pil.resize(image.size))
168
 
169
  torch.cuda.empty_cache()
 
11
  from PIL import Image
12
  # from gradio_imageslider import ImageSlider
13
  import transformers
14
+ import torch
15
  from torchvision import transforms
16
 
17
  import requests
 
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
 
 
 
 
 
 
 
 
 
27
 
28
+ ## CPU version refinement
29
+ def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
30
+ if isinstance(image, Image.Image):
31
+ image = np.array(image) / 255.0
32
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
33
+
34
+ blurred_FGA = cv2.blur(FG * alpha, (r, r))
35
+ blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)
36
+
37
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
38
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
39
+ FG = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
40
+ FG = np.clip(FG, 0, 1)
41
+ return FG, blurred_B
42
 
43
+
44
+ def FB_blur_fusion_foreground_estimator_cpu_2(image, alpha, r=90):
45
  # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
46
  alpha = alpha[:, :, None]
47
+ FG, blur_B = FB_blur_fusion_foreground_estimator_cpu(image, image, image, alpha, r)
48
+ return FB_blur_fusion_foreground_estimator_cpu(image, FG, blur_B, alpha, r=6)[0]
49
+
50
+
51
+ ## GPU version refinement
52
+ def mean_blur(x, kernel_size):
53
+ """
54
+ equivalent to cv.blur
55
+ x: [B, C, H, W]
56
+ """
57
+ if kernel_size % 2 == 0:
58
+ pad_l = kernel_size // 2 - 1
59
+ pad_r = kernel_size // 2
60
+ pad_t = kernel_size // 2 - 1
61
+ pad_b = kernel_size // 2
62
+ else:
63
+ pad_l = pad_r = pad_t = pad_b = kernel_size // 2
64
 
65
+ x_padded = torch.nn.functional.pad(x, (pad_l, pad_r, pad_t, pad_b), mode='replicate')
66
 
67
+ return torch.nn.functional.avg_pool2d(x_padded, kernel_size=(kernel_size, kernel_size), stride=1, count_include_pad=False)
 
 
 
68
 
69
+ def FB_blur_fusion_foreground_estimator_gpu(image, FG, B, alpha, r=90):
70
+ as_dtype = lambda x, dtype: x.to(dtype) if x.dtype != dtype else x
71
 
72
+ input_dtype = image.dtype
73
+ # convert image to float to avoid overflow
74
+ image = as_dtype(image, torch.float32)
75
+ FG = as_dtype(FG, torch.float32)
76
+ B = as_dtype(B, torch.float32)
77
+ alpha = as_dtype(alpha, torch.float32)
78
+
79
+ blurred_alpha = mean_blur(alpha, kernel_size=r)
80
+
81
+ blurred_FGA = mean_blur(FG * alpha, kernel_size=r)
82
+ blurred_FG = blurred_FGA / (blurred_alpha + 1e-5)
83
+
84
+ blurred_B1A = mean_blur(B * (1 - alpha), kernel_size=r)
85
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
86
+
87
+ FG_output = blurred_FG + alpha * (image - alpha * blurred_FG - (1 - alpha) * blurred_B)
88
+ FG_output = torch.clamp(FG_output, 0, 1)
89
+
90
+ return as_dtype(FG_output, input_dtype), as_dtype(blurred_B, input_dtype)
91
+
92
+
93
+ def FB_blur_fusion_foreground_estimator_gpu_2(image, alpha, r=90):
94
+ # Thanks to the source: https://github.com/ZhengPeng7/BiRefNet/issues/226#issuecomment-3016433728
95
+ FG, blur_B = FB_blur_fusion_foreground_estimator_gpu(image, image, image, alpha, r)
96
+ return FB_blur_fusion_foreground_estimator_gpu(image, FG, blur_B, alpha, r=6)[0]
97
+
98
+
99
+ def refine_foreground(image, mask, r=90, device='cuda'):
100
+ """both image and mask are in range of [0, 1]"""
101
+ if mask.size != image.size:
102
+ mask = mask.resize(image.size)
103
+
104
+ if device == 'cuda':
105
+ image = transforms.functional.to_tensor(image).float().cuda()
106
+ mask = transforms.functional.to_tensor(mask).float().cuda()
107
+ image = image.unsqueeze(0)
108
+ mask = mask.unsqueeze(0)
109
+
110
+ estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
111
+
112
+ estimated_foreground = estimated_foreground.squeeze()
113
+ estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
114
+ estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
115
+ else:
116
+ image = np.array(image, dtype=np.float32) / 255.0
117
+ mask = np.array(mask, dtype=np.float32) / 255.0
118
+ estimated_foreground = FB_blur_fusion_foreground_estimator_cpu_2(image, mask, r=r)
119
+ estimated_foreground = (estimated_foreground * 255.0).astype(np.uint8)
120
+
121
+ estimated_foreground = Image.fromarray(np.ascontiguousarray(estimated_foreground))
122
+
123
+ return estimated_foreground
124
 
125
 
126
  class ImagePreprocessor():
 
228
 
229
  # Show Results
230
  pred_pil = transforms.ToPILImage()(pred)
231
+ image_masked = refine_foreground(image, pred_pil, device=device)
232
  image_masked.putalpha(pred_pil.resize(image.size))
233
 
234
  torch.cuda.empty_cache()