Spaces:
Build error
Build error
Commit
·
88bec8b
1
Parent(s):
7f88488
Update utils/shared_utils.py
Browse files- utils/shared_utils.py +4 -2
utils/shared_utils.py
CHANGED
@@ -21,9 +21,9 @@ from utils.photo_smooth import Propagator
|
|
21 |
root = Path.cwd()
|
22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
23 |
# Load model
|
24 |
-
p_wct = PhotoWCT()
|
25 |
p_wct.load_state_dict(torch.load(root/"models/components/photo_wct.pth"))
|
26 |
-
p_pro = Propagator()
|
27 |
stylization_module=p_wct
|
28 |
smoothing_module=p_pro
|
29 |
|
@@ -115,6 +115,8 @@ def style_transfer(cont_img,styl_img):
|
|
115 |
return stylized_img
|
116 |
|
117 |
def smoother(stylized_img, over_img):
|
|
|
|
|
118 |
final_img = smoothing_module.process(stylized_img, over_img)
|
119 |
#final_img = smooth_filter(stylized_img, over_img, f_radius=15, f_edge=1e-1)
|
120 |
return final_img
|
|
|
21 |
root = Path.cwd()
|
22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
23 |
# Load model
|
24 |
+
p_wct = PhotoWCT().to(device)
|
25 |
p_wct.load_state_dict(torch.load(root/"models/components/photo_wct.pth"))
|
26 |
+
p_pro = Propagator().to(device)
|
27 |
stylization_module=p_wct
|
28 |
smoothing_module=p_pro
|
29 |
|
|
|
115 |
return stylized_img
|
116 |
|
117 |
def smoother(stylized_img, over_img):
|
118 |
+
if device == 'cuda':
|
119 |
+
smoothing_module.to(device)
|
120 |
final_img = smoothing_module.process(stylized_img, over_img)
|
121 |
#final_img = smooth_filter(stylized_img, over_img, f_radius=15, f_edge=1e-1)
|
122 |
return final_img
|