Spaces:
Running
Running
Minh
commited on
Commit
·
85b3f9b
1
Parent(s):
e27d8ac
Add multiple strokes support
Browse files- app.py +68 -6
- ckpts/actor_notrans.pkl +3 -0
- ckpts/actor_round.pkl +3 -0
- ckpts/actor_triangle.pkl +3 -0
- ckpts/bezierwotrans.pkl +3 -0
- ckpts/round.pkl +3 -0
- ckpts/triangle.pkl +3 -0
app.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
|
|
6 |
import argparse
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
9 |
-
|
10 |
from baseline.DRL.actor import *
|
11 |
from baseline.Renderer.stroke_gen import *
|
12 |
from baseline.Renderer.model import *
|
@@ -23,6 +23,11 @@ canvas_cnt = divide * divide
|
|
23 |
|
24 |
Decoder = FCN()
|
25 |
Decoder.load_state_dict(torch.load(renderer_path))
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def decode(x, canvas): # b * (10 + 3)
|
28 |
x = x.view(-1, 10 + 3)
|
@@ -37,6 +42,9 @@ def decode(x, canvas): # b * (10 + 3)
|
|
37 |
for i in range(5):
|
38 |
canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
|
39 |
res.append(canvas)
|
|
|
|
|
|
|
40 |
return canvas, res
|
41 |
|
42 |
def small2large(x):
|
@@ -86,10 +94,7 @@ def save_img(res, imgid, origin_shape, output_name, divide=False):
|
|
86 |
output = cv2.resize(output, origin_shape)
|
87 |
cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', output)
|
88 |
|
89 |
-
|
90 |
-
actor.load_state_dict(torch.load(actor_path))
|
91 |
-
actor = actor.to(device).eval()
|
92 |
-
Decoder = Decoder.to(device).eval()
|
93 |
|
94 |
|
95 |
|
@@ -162,6 +167,41 @@ def paint_img(img):
|
|
162 |
|
163 |
yield output
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
examples = [
|
166 |
["image/chaoyue.png"],
|
167 |
["image/degang.png"],
|
@@ -170,6 +210,28 @@ examples = [
|
|
170 |
["image/mayun.png"],
|
171 |
]
|
172 |
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
demo.queue()
|
175 |
demo.launch(server_name="0.0.0.0")
|
|
|
6 |
import argparse
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
9 |
+
import gc
|
10 |
from baseline.DRL.actor import *
|
11 |
from baseline.Renderer.stroke_gen import *
|
12 |
from baseline.Renderer.model import *
|
|
|
23 |
|
24 |
Decoder = FCN()
|
25 |
Decoder.load_state_dict(torch.load(renderer_path))
|
26 |
+
actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
|
27 |
+
actor.load_state_dict(torch.load(actor_path))
|
28 |
+
actor = actor.to(device).eval()
|
29 |
+
Decoder = Decoder.to(device).eval()
|
30 |
+
|
31 |
|
32 |
def decode(x, canvas): # b * (10 + 3)
|
33 |
x = x.view(-1, 10 + 3)
|
|
|
42 |
for i in range(5):
|
43 |
canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
|
44 |
res.append(canvas)
|
45 |
+
gc.collect()
|
46 |
+
if torch.cuda.is_available():
|
47 |
+
torch.cuda.empty_cache()
|
48 |
return canvas, res
|
49 |
|
50 |
def small2large(x):
|
|
|
94 |
output = cv2.resize(output, origin_shape)
|
95 |
cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', output)
|
96 |
|
97 |
+
|
|
|
|
|
|
|
98 |
|
99 |
|
100 |
|
|
|
167 |
|
168 |
yield output
|
169 |
|
170 |
+
|
171 |
+
def change_model(choice: str):
|
172 |
+
global Decoder, actor
|
173 |
+
if choice == "Default":
|
174 |
+
actor_path = 'ckpts/actor.pkl'
|
175 |
+
renderer_path = 'ckpts/renderer.pkl'
|
176 |
+
elif choice == "Triangle":
|
177 |
+
actor_path = 'ckpts/actor_triangle.pkl'
|
178 |
+
renderer_path = 'ckpts/triangle.pkl'
|
179 |
+
elif choice == "Round":
|
180 |
+
actor_path = 'ckpts/actor_round.pkl'
|
181 |
+
renderer_path = 'ckpts/round.pkl'
|
182 |
+
else:
|
183 |
+
actor_path = 'ckpts/actor_notrans.pkl'
|
184 |
+
renderer_path = 'ckpts/bezierwotrans.pkl'
|
185 |
+
|
186 |
+
Decoder.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
|
187 |
+
actor.load_state_dict(torch.load(actor_path, map_location= "cpu"))
|
188 |
+
actor = actor.to(device).eval()
|
189 |
+
Decoder = Decoder.to(device).eval()
|
190 |
+
|
191 |
+
|
192 |
+
def wrapper(func):
|
193 |
+
def inner(*args, **kwargs):
|
194 |
+
val = args[0]
|
195 |
+
args_ = tuple(x for i,x in enumerate(args) if i > 0)
|
196 |
+
event = func(*args_, **kwargs)
|
197 |
+
for i in event:
|
198 |
+
if val == "Cancel":
|
199 |
+
yield i
|
200 |
+
else:
|
201 |
+
event.close()
|
202 |
+
break
|
203 |
+
return inner
|
204 |
+
|
205 |
examples = [
|
206 |
["image/chaoyue.png"],
|
207 |
["image/degang.png"],
|
|
|
210 |
["image/mayun.png"],
|
211 |
]
|
212 |
|
213 |
+
with gr.Blocks() as demo:
|
214 |
+
with gr.Row():
|
215 |
+
with gr.Column():
|
216 |
+
input_image = gr.Image(label="Input image")
|
217 |
+
with gr.Row():
|
218 |
+
dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'], value= 'Default', label= 'Stroke choice')
|
219 |
+
with gr.Row():
|
220 |
+
with gr.Column():
|
221 |
+
clr_btn = gr.ClearButton([input_image], variant= "stop")
|
222 |
+
with gr.Column():
|
223 |
+
translate_btn = gr.Button(value="Paint", variant="primary")
|
224 |
+
|
225 |
+
with gr.Column():
|
226 |
+
output = gr.Image(label="Painting Result")
|
227 |
+
dropdown.select(change_model, dropdown)
|
228 |
+
translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn).then(wrapper(paint_img), inputs=[translate_btn, input_image], outputs=output)\
|
229 |
+
.success(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
|
230 |
+
examples = gr.Examples(examples=examples,
|
231 |
+
inputs=[input_image], cache_examples = False)
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
# demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples = examples)
|
236 |
demo.queue()
|
237 |
demo.launch(server_name="0.0.0.0")
|
ckpts/actor_notrans.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9178218af4256659e4589b378389aa115d2ed293d80d86ac9a650283824b0d6e
|
3 |
+
size 44898587
|
ckpts/actor_round.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29ac4cfddc16642443f3954326055b04a9589d739ae36058ac95c54552990ad7
|
3 |
+
size 44898567
|
ckpts/actor_triangle.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68fefdc1bff96efe8220ed267442b32fd4cc1efdf02701fee0d74315df07997f
|
3 |
+
size 44898567
|
ckpts/bezierwotrans.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:82f9e5272d54bc17062f25392e8a0a6070e475852cbcf932ef32818db2cb2fea
|
3 |
+
size 44165801
|
ckpts/round.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a057a399a5f5f481e9606045f133ce79b62c5bc446dc151cb0620b2745f31913
|
3 |
+
size 44165813
|
ckpts/triangle.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59c8fbb8b2ebabf732ea03b9d6904647faa8c074c791df6730b9ddfd1f61cfeb
|
3 |
+
size 44165813
|