Minh commited on
Commit
85b3f9b
·
1 Parent(s): e27d8ac

Add multiple strokes support

Browse files
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  import argparse
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
-
10
  from baseline.DRL.actor import *
11
  from baseline.Renderer.stroke_gen import *
12
  from baseline.Renderer.model import *
@@ -23,6 +23,11 @@ canvas_cnt = divide * divide
23
 
24
  Decoder = FCN()
25
  Decoder.load_state_dict(torch.load(renderer_path))
 
 
 
 
 
26
 
27
  def decode(x, canvas): # b * (10 + 3)
28
  x = x.view(-1, 10 + 3)
@@ -37,6 +42,9 @@ def decode(x, canvas): # b * (10 + 3)
37
  for i in range(5):
38
  canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
39
  res.append(canvas)
 
 
 
40
  return canvas, res
41
 
42
  def small2large(x):
@@ -86,10 +94,7 @@ def save_img(res, imgid, origin_shape, output_name, divide=False):
86
  output = cv2.resize(output, origin_shape)
87
  cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', output)
88
 
89
- actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
90
- actor.load_state_dict(torch.load(actor_path))
91
- actor = actor.to(device).eval()
92
- Decoder = Decoder.to(device).eval()
93
 
94
 
95
 
@@ -162,6 +167,41 @@ def paint_img(img):
162
 
163
  yield output
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  examples = [
166
  ["image/chaoyue.png"],
167
  ["image/degang.png"],
@@ -170,6 +210,28 @@ examples = [
170
  ["image/mayun.png"],
171
  ]
172
 
173
- demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples = examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  demo.queue()
175
  demo.launch(server_name="0.0.0.0")
 
6
  import argparse
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ import gc
10
  from baseline.DRL.actor import *
11
  from baseline.Renderer.stroke_gen import *
12
  from baseline.Renderer.model import *
 
23
 
24
  Decoder = FCN()
25
  Decoder.load_state_dict(torch.load(renderer_path))
26
+ actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
27
+ actor.load_state_dict(torch.load(actor_path))
28
+ actor = actor.to(device).eval()
29
+ Decoder = Decoder.to(device).eval()
30
+
31
 
32
  def decode(x, canvas): # b * (10 + 3)
33
  x = x.view(-1, 10 + 3)
 
42
  for i in range(5):
43
  canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
44
  res.append(canvas)
45
+ gc.collect()
46
+ if torch.cuda.is_available():
47
+ torch.cuda.empty_cache()
48
  return canvas, res
49
 
50
  def small2large(x):
 
94
  output = cv2.resize(output, origin_shape)
95
  cv2.imwrite(output_name +"/" + str(imgid) + '.jpg', output)
96
 
97
+
 
 
 
98
 
99
 
100
 
 
167
 
168
  yield output
169
 
170
+
171
+ def change_model(choice: str):
172
+ global Decoder, actor
173
+ if choice == "Default":
174
+ actor_path = 'ckpts/actor.pkl'
175
+ renderer_path = 'ckpts/renderer.pkl'
176
+ elif choice == "Triangle":
177
+ actor_path = 'ckpts/actor_triangle.pkl'
178
+ renderer_path = 'ckpts/triangle.pkl'
179
+ elif choice == "Round":
180
+ actor_path = 'ckpts/actor_round.pkl'
181
+ renderer_path = 'ckpts/round.pkl'
182
+ else:
183
+ actor_path = 'ckpts/actor_notrans.pkl'
184
+ renderer_path = 'ckpts/bezierwotrans.pkl'
185
+
186
+ Decoder.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
187
+ actor.load_state_dict(torch.load(actor_path, map_location= "cpu"))
188
+ actor = actor.to(device).eval()
189
+ Decoder = Decoder.to(device).eval()
190
+
191
+
192
+ def wrapper(func):
193
+ def inner(*args, **kwargs):
194
+ val = args[0]
195
+ args_ = tuple(x for i,x in enumerate(args) if i > 0)
196
+ event = func(*args_, **kwargs)
197
+ for i in event:
198
+ if val == "Cancel":
199
+ yield i
200
+ else:
201
+ event.close()
202
+ break
203
+ return inner
204
+
205
  examples = [
206
  ["image/chaoyue.png"],
207
  ["image/degang.png"],
 
210
  ["image/mayun.png"],
211
  ]
212
 
213
+ with gr.Blocks() as demo:
214
+ with gr.Row():
215
+ with gr.Column():
216
+ input_image = gr.Image(label="Input image")
217
+ with gr.Row():
218
+ dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'], value= 'Default', label= 'Stroke choice')
219
+ with gr.Row():
220
+ with gr.Column():
221
+ clr_btn = gr.ClearButton([input_image], variant= "stop")
222
+ with gr.Column():
223
+ translate_btn = gr.Button(value="Paint", variant="primary")
224
+
225
+ with gr.Column():
226
+ output = gr.Image(label="Painting Result")
227
+ dropdown.select(change_model, dropdown)
228
+ translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn).then(wrapper(paint_img), inputs=[translate_btn, input_image], outputs=output)\
229
+ .success(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
230
+ examples = gr.Examples(examples=examples,
231
+ inputs=[input_image], cache_examples = False)
232
+
233
+
234
+
235
+ # demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples = examples)
236
  demo.queue()
237
  demo.launch(server_name="0.0.0.0")
ckpts/actor_notrans.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9178218af4256659e4589b378389aa115d2ed293d80d86ac9a650283824b0d6e
3
+ size 44898587
ckpts/actor_round.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29ac4cfddc16642443f3954326055b04a9589d739ae36058ac95c54552990ad7
3
+ size 44898567
ckpts/actor_triangle.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68fefdc1bff96efe8220ed267442b32fd4cc1efdf02701fee0d74315df07997f
3
+ size 44898567
ckpts/bezierwotrans.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82f9e5272d54bc17062f25392e8a0a6070e475852cbcf932ef32818db2cb2fea
3
+ size 44165801
ckpts/round.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a057a399a5f5f481e9606045f133ce79b62c5bc446dc151cb0620b2745f31913
3
+ size 44165813
ckpts/triangle.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59c8fbb8b2ebabf732ea03b9d6904647faa8c074c791df6730b9ddfd1f61cfeb
3
+ size 44165813