svjack commited on
Commit
13bc08e
·
1 Parent(s): 9f9b215

Upload visualizer_drag_gradio_local.py

Browse files
Files changed (1) hide show
  1. visualizer_drag_gradio_local.py +975 -0
visualizer_drag_gradio_local.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/spaces/DragGan/DragGan
2
+ # https://huggingface.co/DragGan/DragGan-Models
3
+ # https://arxiv.org/abs/2305.10973
4
+
5
+ #### !git clone https://huggingface.co/spaces/DragGan/DragGan
6
+ #### !pip install -r DragGan/requirements.txt
7
+ #### !wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
8
+ #### !wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-afhqv2-512x512.pkl
9
+
10
+ import os
11
+
12
+ os.chdir("DragGan")
13
+
14
+ #os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
15
+
16
+ import os.path as osp
17
+ from argparse import ArgumentParser
18
+ from functools import partial
19
+ from pathlib import Path
20
+ import time
21
+
22
+ import psutil
23
+
24
+ import gradio as gr
25
+ import numpy as np
26
+ import torch
27
+ from PIL import Image
28
+
29
+ import dnnlib
30
+ from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
31
+ get_latest_points_pair, get_valid_mask,
32
+ on_change_single_global_state)
33
+ from viz.renderer import Renderer, add_watermark_np
34
+
35
+
36
+ # download models from Hugging Face hub
37
+ from huggingface_hub import snapshot_download
38
+
39
+ model_dir = Path('./checkpoints')
40
+ os.mkdir(model_dir)
41
+ snapshot_download('DragGan/DragGan-Models',
42
+ repo_type='model', local_dir=model_dir)
43
+
44
+ #### !wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
45
+ #### !wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-afhqv2-512x512.pkl
46
+ #### !cp stylegan3-t-ffhq-1024x1024.pkl checkpoints
47
+ #### !cp stylegan3-t-afhqv2-512x512.pkl checkpoints
48
+
49
+ #### !git clone https://huggingface.co/svjack/stylegan2_cat_400
50
+ #### !cp stylegan2_cat_400/cat_512_stylegan2.pkl checkpoints/
51
+
52
+ '''
53
+ parser = ArgumentParser()
54
+ parser.add_argument('--share', action='store_true')
55
+ parser.add_argument('--cache-dir', type=str, default='./checkpoints')
56
+ args = parser.parse_args()
57
+ '''
58
+
59
+ #cache_dir = args.cache_dir
60
+ cache_dir = './checkpoints'
61
+
62
+ device = 'cuda'
63
+ #device = "cpu"
64
+ IS_SPACE = "DragGan/DragGan" in os.environ.get('SPACE_ID', '')
65
+ TIMEOUT = 80
66
+
67
+
68
+ def reverse_point_pairs(points):
69
+ new_points = []
70
+ for p in points:
71
+ new_points.append([p[1], p[0]])
72
+ return new_points
73
+
74
+
75
+ def clear_state(global_state, target=None):
76
+ """Clear target history state from global_state
77
+ If target is not defined, points and mask will be both removed.
78
+ 1. set global_state['points'] as empty dict
79
+ 2. set global_state['mask'] as full-one mask.
80
+ """
81
+ if target is None:
82
+ target = ['point', 'mask']
83
+ if not isinstance(target, list):
84
+ target = [target]
85
+ if 'point' in target:
86
+ global_state['points'] = dict()
87
+ print('Clear Points State!')
88
+ if 'mask' in target:
89
+ image_raw = global_state["images"]["image_raw"]
90
+ global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
91
+ dtype=np.uint8)
92
+ print('Clear mask State!')
93
+
94
+ return global_state
95
+
96
+
97
+ def init_images(global_state):
98
+ """This function is called only ones with Gradio App is started.
99
+ 0. pre-process global_state, unpack value from global_state of need
100
+ 1. Re-init renderer
101
+ 2. run `renderer._render_drag_impl` with `is_drag=False` to generate
102
+ new image
103
+ 3. Assign images to global state and re-generate mask
104
+ """
105
+
106
+ if isinstance(global_state, gr.State):
107
+ state = global_state.value
108
+ else:
109
+ state = global_state
110
+
111
+ state['renderer'].init_network(
112
+ state['generator_params'], # res
113
+ valid_checkpoints_dict[state['pretrained_weight']], # pkl
114
+ state['params']['seed'], # w0_seed,
115
+ None, # w_load
116
+ state['params']['latent_space'] == 'w+', # w_plus
117
+ 'const',
118
+ state['params']['trunc_psi'], # trunc_psi,
119
+ state['params']['trunc_cutoff'], # trunc_cutoff,
120
+ None, # input_transform
121
+ state['params']['lr'] # lr,
122
+ )
123
+
124
+ state['renderer']._render_drag_impl(state['generator_params'],
125
+ is_drag=False,
126
+ to_pil=True)
127
+
128
+ init_image = state['generator_params'].image
129
+ state['images']['image_orig'] = init_image
130
+ state['images']['image_raw'] = init_image
131
+ state['images']['image_show'] = Image.fromarray(
132
+ add_watermark_np(np.array(init_image)))
133
+ state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
134
+ dtype=np.uint8)
135
+ return global_state
136
+
137
+
138
+ def update_image_draw(image, points, mask, show_mask, global_state=None):
139
+
140
+ image_draw = draw_points_on_image(image, points)
141
+ if show_mask and mask is not None and not (mask == 0).all() and not (
142
+ mask == 1).all():
143
+ image_draw = draw_mask_on_image(image_draw, mask)
144
+
145
+ image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
146
+ if global_state is not None:
147
+ global_state['images']['image_show'] = image_draw
148
+ return image_draw
149
+
150
+
151
+ def preprocess_mask_info(global_state, image):
152
+ """Function to handle mask information.
153
+ 1. last_mask is None: Do not need to change mask, return mask
154
+ 2. last_mask is not None:
155
+ 2.1 global_state is remove_mask:
156
+ 2.2 global_state is add_mask:
157
+ """
158
+ if isinstance(image, dict):
159
+ last_mask = get_valid_mask(image['mask'])
160
+ else:
161
+ last_mask = None
162
+ mask = global_state['mask']
163
+
164
+ # mask in global state is a placeholder with all 1.
165
+ if (mask == 1).all():
166
+ mask = last_mask
167
+
168
+ # last_mask = global_state['last_mask']
169
+ editing_mode = global_state['editing_state']
170
+
171
+ if last_mask is None:
172
+ return global_state
173
+
174
+ if editing_mode == 'remove_mask':
175
+ updated_mask = np.clip(mask - last_mask, 0, 1)
176
+ print(f'Last editing_state is {editing_mode}, do remove.')
177
+ elif editing_mode == 'add_mask':
178
+ updated_mask = np.clip(mask + last_mask, 0, 1)
179
+ print(f'Last editing_state is {editing_mode}, do add.')
180
+ else:
181
+ updated_mask = mask
182
+ print(f'Last editing_state is {editing_mode}, '
183
+ 'do nothing to mask.')
184
+
185
+ global_state['mask'] = updated_mask
186
+ # global_state['last_mask'] = None # clear buffer
187
+ return global_state
188
+
189
+
190
+ def print_memory_usage():
191
+ # Print system memory usage
192
+ print(f"System memory usage: {psutil.virtual_memory().percent}%")
193
+
194
+ # Print GPU memory usage
195
+ if torch.cuda.is_available():
196
+ device = torch.device("cuda")
197
+ print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
198
+ print(
199
+ f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
200
+ device_properties = torch.cuda.get_device_properties(device)
201
+ available_memory = device_properties.total_memory - \
202
+ torch.cuda.max_memory_allocated()
203
+ print(f"Available GPU memory: {available_memory / 1e9} GB")
204
+ else:
205
+ print("No GPU available")
206
+
207
+
208
+ # filter large models running on SPACES
209
+ allowed_checkpoints = [] # all checkpoints
210
+ if IS_SPACE:
211
+ '''
212
+ allowed_checkpoints = ["stylegan_human_v2_512.pkl",
213
+ "stylegan2_dogs_1024_pytorch.pkl", "stylegan3-t-ffhq-1024x1024.pkl"]
214
+ '''
215
+ #allowed_checkpoints = ["stylegan3-t-ffhq-1024x1024.pkl"]
216
+ #allowed_checkpoints = ['stylegan3-t-afhqv2-512x512.pkl']
217
+ allowed_checkpoints = ["cat_512_stylegan2.pkl"]
218
+
219
+ valid_checkpoints_dict = {
220
+ f.name.split('.')[0]: str(f)
221
+ for f in Path(cache_dir).glob('*.pkl')
222
+ if f.name in allowed_checkpoints or not IS_SPACE
223
+ }
224
+ print('Valid checkpoint file:')
225
+ print(valid_checkpoints_dict)
226
+
227
+ ###init_pkl = 'stylegan_human_v2_512'
228
+ ###init_pkl = "stylegan3-t-ffhq-1024x1024"
229
+ #init_pkl = "stylegan3-t-afhqv2-512x512"
230
+ init_pkl = "cat_512_stylegan2"
231
+
232
+ with gr.Blocks() as app:
233
+ gr.Markdown("""
234
+ # DragGAN - Drag Your GAN
235
+ ## Interactive Point-based Manipulation on the Generative Image Manifold
236
+ ### Unofficial Gradio Demo
237
+
238
+ **Due to high demand, only one model can be run at a time, or you can duplicate the space and run your own copy.**
239
+
240
+ <a href="https://huggingface.co/spaces/radames/DragGan?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
241
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
242
+
243
+ * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
244
+ * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) © [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
245
+ """)
246
+
247
+ # renderer = Renderer()
248
+ global_state = gr.State({
249
+ "images": {
250
+ # image_orig: the original image, change with seed/model is changed
251
+ # image_raw: image with mask and points, change durning optimization
252
+ # image_show: image showed on screen
253
+ },
254
+ "temporal_params": {
255
+ # stop
256
+ },
257
+ 'mask':
258
+ None, # mask for visualization, 1 for editing and 0 for unchange
259
+ 'last_mask': None, # last edited mask
260
+ 'show_mask': True, # add button
261
+ "generator_params": dnnlib.EasyDict(),
262
+ "params": {
263
+ "seed": int(np.random.randint(0, 2**32 - 1)),
264
+ "motion_lambda": 20,
265
+ "r1_in_pixels": 3,
266
+ "r2_in_pixels": 12,
267
+ "magnitude_direction_in_pixels": 1.0,
268
+ "latent_space": "w+",
269
+ "trunc_psi": 0.7,
270
+ "trunc_cutoff": None,
271
+ "lr": 0.001,
272
+ },
273
+ "device": device,
274
+ "draw_interval": 1,
275
+ "renderer": Renderer(disable_timing=True),
276
+ "points": {},
277
+ "curr_point": None,
278
+ "curr_type_point": "start",
279
+ 'editing_state': 'add_points',
280
+ 'pretrained_weight': init_pkl
281
+ })
282
+
283
+ # init image
284
+ global_state = init_images(global_state)
285
+ with gr.Row():
286
+
287
+ with gr.Row():
288
+
289
+ # Left --> tools
290
+ with gr.Column(scale=3):
291
+
292
+ # Pickle
293
+ with gr.Row():
294
+
295
+ with gr.Column(scale=1, min_width=10):
296
+ gr.Markdown(value='Pickle', show_label=False)
297
+
298
+ with gr.Column(scale=4, min_width=10):
299
+ form_pretrained_dropdown = gr.Dropdown(
300
+ choices=list(valid_checkpoints_dict.keys()),
301
+ label="Pretrained Model",
302
+ value=init_pkl,
303
+ )
304
+
305
+ # Latent
306
+ with gr.Row():
307
+ with gr.Column(scale=1, min_width=10):
308
+ gr.Markdown(value='Latent', show_label=False)
309
+
310
+ with gr.Column(scale=4, min_width=10):
311
+ form_seed_number = gr.Slider(
312
+ mininium=0,
313
+ maximum=2**32-1,
314
+ step=1,
315
+ value=global_state.value['params']['seed'],
316
+ interactive=True,
317
+ # randomize=True,
318
+ label="Seed",
319
+ )
320
+ form_lr_number = gr.Number(
321
+ value=global_state.value["params"]["lr"],
322
+ interactive=True,
323
+ label="Step Size")
324
+
325
+ with gr.Row():
326
+ with gr.Column(scale=2, min_width=10):
327
+ form_reset_image = gr.Button("Reset Image")
328
+ with gr.Column(scale=3, min_width=10):
329
+ form_latent_space = gr.Radio(
330
+ ['w', 'w+'],
331
+ value=global_state.value['params']
332
+ ['latent_space'],
333
+ interactive=True,
334
+ label='Latent space to optimize',
335
+ show_label=False,
336
+ )
337
+
338
+ # Drag
339
+ with gr.Row():
340
+ with gr.Column(scale=1, min_width=10):
341
+ gr.Markdown(value='Drag', show_label=False)
342
+ with gr.Column(scale=4, min_width=10):
343
+ with gr.Row():
344
+ with gr.Column(scale=1, min_width=10):
345
+ enable_add_points = gr.Button('Add Points')
346
+ with gr.Column(scale=1, min_width=10):
347
+ undo_points = gr.Button('Reset Points')
348
+ with gr.Row():
349
+ with gr.Column(scale=1, min_width=10):
350
+ form_start_btn = gr.Button("Start")
351
+ with gr.Column(scale=1, min_width=10):
352
+ form_stop_btn = gr.Button("Stop")
353
+
354
+ form_steps_number = gr.Number(value=0,
355
+ label="Steps",
356
+ interactive=False)
357
+
358
+ # Mask
359
+ with gr.Row():
360
+ with gr.Column(scale=1, min_width=10):
361
+ gr.Markdown(value='Mask', show_label=False)
362
+ with gr.Column(scale=4, min_width=10):
363
+ enable_add_mask = gr.Button('Edit Flexible Area')
364
+ with gr.Row():
365
+ with gr.Column(scale=1, min_width=10):
366
+ form_reset_mask_btn = gr.Button("Reset mask")
367
+ with gr.Column(scale=1, min_width=10):
368
+ show_mask = gr.Checkbox(
369
+ label='Show Mask',
370
+ value=global_state.value['show_mask'],
371
+ show_label=False)
372
+
373
+ with gr.Row():
374
+ form_lambda_number = gr.Number(
375
+ value=global_state.value["params"]
376
+ ["motion_lambda"],
377
+ interactive=True,
378
+ label="Lambda",
379
+ )
380
+
381
+ form_draw_interval_number = gr.Number(
382
+ value=global_state.value["draw_interval"],
383
+ label="Draw Interval (steps)",
384
+ interactive=True,
385
+ visible=False)
386
+
387
+ # Right --> Image
388
+ with gr.Column(scale=8):
389
+ form_image = ImageMask(
390
+ value=global_state.value['images']['image_show'],
391
+ brush_radius=20).style(
392
+ width=768,
393
+ height=768) # NOTE: hard image size code here.
394
+ gr.Markdown("""
395
+ ## Quick Start
396
+
397
+ 1. Select desired `Pretrained Model` and adjust `Seed` to generate an
398
+ initial image.
399
+ 2. Click on image to add control points.
400
+ 3. Click `Start` and enjoy it!
401
+
402
+ ## Advance Usage
403
+
404
+ 1. Change `Step Size` to adjust learning rate in drag optimization.
405
+ 2. Select `w` or `w+` to change latent space to optimize:
406
+ * Optimize on `w` space may cause greater influence to the image.
407
+ * Optimize on `w+` space may work slower than `w`, but usually achieve
408
+ better results.
409
+ * Note that changing the latent space will reset the image, points and
410
+ mask (this has the same effect as `Reset Image` button).
411
+ 3. Click `Edit Flexible Area` to create a mask and constrain the
412
+ unmasked region to remain unchanged.
413
+
414
+
415
+ """)
416
+ gr.HTML("""
417
+ <style>
418
+ .container {
419
+ position: absolute;
420
+ height: 50px;
421
+ text-align: center;
422
+ line-height: 50px;
423
+ width: 100%;
424
+ }
425
+ </style>
426
+ <div class="container">
427
+ Gradio demo supported by
428
+ <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
429
+ <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
430
+ </div>
431
+ """)
432
+ # Network & latents tab listeners
433
+
434
+ def on_change_pretrained_dropdown(pretrained_value, global_state):
435
+ """Function to handle model change.
436
+ 1. Set pretrained value to global_state
437
+ 2. Re-init images and clear all states
438
+ """
439
+
440
+ global_state['pretrained_weight'] = pretrained_value
441
+ init_images(global_state)
442
+ clear_state(global_state)
443
+
444
+ return global_state, global_state["images"]['image_show']
445
+
446
+ form_pretrained_dropdown.change(
447
+ on_change_pretrained_dropdown,
448
+ inputs=[form_pretrained_dropdown, global_state],
449
+ outputs=[global_state, form_image],
450
+ queue=True,
451
+ )
452
+
453
+ def on_click_reset_image(global_state):
454
+ """Reset image to the original one and clear all states
455
+ 1. Re-init images
456
+ 2. Clear all states
457
+ """
458
+
459
+ init_images(global_state)
460
+ clear_state(global_state)
461
+
462
+ return global_state, global_state['images']['image_show']
463
+
464
+ form_reset_image.click(
465
+ on_click_reset_image,
466
+ inputs=[global_state],
467
+ outputs=[global_state, form_image],
468
+ queue=False,
469
+ )
470
+
471
+ # Update parameters
472
+ def on_change_update_image_seed(seed, global_state):
473
+ """Function to handle generation seed change.
474
+ 1. Set seed to global_state
475
+ 2. Re-init images and clear all states
476
+ """
477
+
478
+ global_state["params"]["seed"] = int(seed)
479
+ init_images(global_state)
480
+ clear_state(global_state)
481
+
482
+ return global_state, global_state['images']['image_show']
483
+
484
+ form_seed_number.change(
485
+ on_change_update_image_seed,
486
+ inputs=[form_seed_number, global_state],
487
+ outputs=[global_state, form_image],
488
+ )
489
+
490
+ def on_click_latent_space(latent_space, global_state):
491
+ """Function to reset latent space to optimize.
492
+ NOTE: this function we reset the image and all controls
493
+ 1. Set latent-space to global_state
494
+ 2. Re-init images and clear all state
495
+ """
496
+
497
+ global_state['params']['latent_space'] = latent_space
498
+ init_images(global_state)
499
+ clear_state(global_state)
500
+
501
+ return global_state, global_state['images']['image_show']
502
+
503
+ form_latent_space.change(on_click_latent_space,
504
+ inputs=[form_latent_space, global_state],
505
+ outputs=[global_state, form_image])
506
+
507
+ # ==== Params
508
+ form_lambda_number.change(
509
+ partial(on_change_single_global_state, ["params", "motion_lambda"]),
510
+ inputs=[form_lambda_number, global_state],
511
+ outputs=[global_state],
512
+ )
513
+
514
+ def on_change_lr(lr, global_state):
515
+ if lr == 0:
516
+ print('lr is 0, do nothing.')
517
+ return global_state
518
+ else:
519
+ global_state["params"]["lr"] = lr
520
+ renderer = global_state['renderer']
521
+ renderer.update_lr(lr)
522
+ print('New optimizer: ')
523
+ print(renderer.w_optim)
524
+ return global_state
525
+
526
+ form_lr_number.change(
527
+ on_change_lr,
528
+ inputs=[form_lr_number, global_state],
529
+ outputs=[global_state],
530
+ queue=False,
531
+ )
532
+
533
+ def on_click_start(global_state, image):
534
+ p_in_pixels = []
535
+ t_in_pixels = []
536
+ valid_points = []
537
+
538
+ # handle of start drag in mask editing mode
539
+ global_state = preprocess_mask_info(global_state, image)
540
+
541
+ # Prepare the points for the inference
542
+ if len(global_state["points"]) == 0:
543
+ # yield on_click_start_wo_points(global_state, image)
544
+ image_raw = global_state['images']['image_raw']
545
+ update_image_draw(
546
+ image_raw,
547
+ global_state['points'],
548
+ global_state['mask'],
549
+ global_state['show_mask'],
550
+ global_state,
551
+ )
552
+
553
+ yield (
554
+ global_state,
555
+ 0,
556
+ global_state['images']['image_show'],
557
+ # gr.File.update(visible=False),
558
+ gr.Button.update(interactive=True),
559
+ gr.Button.update(interactive=True),
560
+ gr.Button.update(interactive=True),
561
+ gr.Button.update(interactive=True),
562
+ gr.Button.update(interactive=True),
563
+ # latent space
564
+ gr.Radio.update(interactive=True),
565
+ gr.Button.update(interactive=True),
566
+ # NOTE: disable stop button
567
+ gr.Button.update(interactive=False),
568
+
569
+ # update other comps
570
+ gr.Dropdown.update(interactive=True),
571
+ gr.Number.update(interactive=True),
572
+ gr.Number.update(interactive=True),
573
+ gr.Button.update(interactive=True),
574
+ gr.Button.update(interactive=True),
575
+ gr.Checkbox.update(interactive=True),
576
+ # gr.Number.update(interactive=True),
577
+ gr.Number.update(interactive=True),
578
+ )
579
+ else:
580
+
581
+ # Transform the points into torch tensors
582
+ for key_point, point in global_state["points"].items():
583
+ try:
584
+ p_start = point.get("start_temp", point["start"])
585
+ p_end = point["target"]
586
+
587
+ if p_start is None or p_end is None:
588
+ continue
589
+
590
+ except KeyError:
591
+ continue
592
+
593
+ p_in_pixels.append(p_start)
594
+ t_in_pixels.append(p_end)
595
+ valid_points.append(key_point)
596
+
597
+ mask = torch.tensor(global_state['mask']).float()
598
+ drag_mask = 1 - mask
599
+
600
+ renderer: Renderer = global_state["renderer"]
601
+ global_state['temporal_params']['stop'] = False
602
+ global_state['editing_state'] = 'running'
603
+
604
+ # reverse points order
605
+ p_to_opt = reverse_point_pairs(p_in_pixels)
606
+ t_to_opt = reverse_point_pairs(t_in_pixels)
607
+ print('Running with:')
608
+ print(f' Source: {p_in_pixels}')
609
+ print(f' Target: {t_in_pixels}')
610
+ step_idx = 0
611
+ last_time = time.time()
612
+ while True:
613
+ print_memory_usage()
614
+ # add a TIMEOUT break
615
+ print(f'Running time: {time.time() - last_time}')
616
+ if IS_SPACE and time.time() - last_time > TIMEOUT:
617
+ print('Timeout break!')
618
+ break
619
+ if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
620
+ break
621
+
622
+ # do drage here!
623
+ renderer._render_drag_impl(
624
+ global_state['generator_params'],
625
+ p_to_opt, # point
626
+ t_to_opt, # target
627
+ drag_mask, # mask,
628
+ global_state['params']['motion_lambda'], # lambda_mask
629
+ reg=0,
630
+ feature_idx=5, # NOTE: do not support change for now
631
+ r1=global_state['params']['r1_in_pixels'], # r1
632
+ r2=global_state['params']['r2_in_pixels'], # r2
633
+ # random_seed = 0,
634
+ # noise_mode = 'const',
635
+ trunc_psi=global_state['params']['trunc_psi'],
636
+ # force_fp32 = False,
637
+ # layer_name = None,
638
+ # sel_channels = 3,
639
+ # base_channel = 0,
640
+ # img_scale_db = 0,
641
+ # img_normalize = False,
642
+ # untransform = False,
643
+ is_drag=True,
644
+ to_pil=True)
645
+
646
+ if step_idx % global_state['draw_interval'] == 0:
647
+ print('Current Source:')
648
+ for key_point, p_i, t_i in zip(valid_points, p_to_opt,
649
+ t_to_opt):
650
+ global_state["points"][key_point]["start_temp"] = [
651
+ p_i[1],
652
+ p_i[0],
653
+ ]
654
+ global_state["points"][key_point]["target"] = [
655
+ t_i[1],
656
+ t_i[0],
657
+ ]
658
+ start_temp = global_state["points"][key_point][
659
+ "start_temp"]
660
+ print(f' {start_temp}')
661
+
662
+ image_result = global_state['generator_params']['image']
663
+ image_draw = update_image_draw(
664
+ image_result,
665
+ global_state['points'],
666
+ global_state['mask'],
667
+ global_state['show_mask'],
668
+ global_state,
669
+ )
670
+ global_state['images']['image_raw'] = image_result
671
+
672
+ yield (
673
+ global_state,
674
+ step_idx,
675
+ global_state['images']['image_show'],
676
+ # gr.File.update(visible=False),
677
+ gr.Button.update(interactive=False),
678
+ gr.Button.update(interactive=False),
679
+ gr.Button.update(interactive=False),
680
+ gr.Button.update(interactive=False),
681
+ gr.Button.update(interactive=False),
682
+ # latent space
683
+ gr.Radio.update(interactive=False),
684
+ gr.Button.update(interactive=False),
685
+ # enable stop button in loop
686
+ gr.Button.update(interactive=True),
687
+
688
+ # update other comps
689
+ gr.Dropdown.update(interactive=False),
690
+ gr.Number.update(interactive=False),
691
+ gr.Number.update(interactive=False),
692
+ gr.Button.update(interactive=False),
693
+ gr.Button.update(interactive=False),
694
+ gr.Checkbox.update(interactive=False),
695
+ # gr.Number.update(interactive=False),
696
+ gr.Number.update(interactive=False),
697
+ )
698
+
699
+ # increate step
700
+ step_idx += 1
701
+
702
+ image_result = global_state['generator_params']['image']
703
+ global_state['images']['image_raw'] = image_result
704
+ image_draw = update_image_draw(image_result,
705
+ global_state['points'],
706
+ global_state['mask'],
707
+ global_state['show_mask'],
708
+ global_state)
709
+
710
+ # fp = NamedTemporaryFile(suffix=".png", delete=False)
711
+ # image_result.save(fp, "PNG")
712
+
713
+ global_state['editing_state'] = 'add_points'
714
+
715
+ yield (
716
+ global_state,
717
+ 0, # reset step to 0 after stop.
718
+ global_state['images']['image_show'],
719
+ # gr.File.update(visible=True, value=fp.name),
720
+ gr.Button.update(interactive=True),
721
+ gr.Button.update(interactive=True),
722
+ gr.Button.update(interactive=True),
723
+ gr.Button.update(interactive=True),
724
+ gr.Button.update(interactive=True),
725
+ # latent space
726
+ gr.Radio.update(interactive=True),
727
+ gr.Button.update(interactive=True),
728
+ # NOTE: disable stop button with loop finish
729
+ gr.Button.update(interactive=False),
730
+
731
+ # update other comps
732
+ gr.Dropdown.update(interactive=True),
733
+ gr.Number.update(interactive=True),
734
+ gr.Number.update(interactive=True),
735
+ gr.Checkbox.update(interactive=True),
736
+ gr.Number.update(interactive=True),
737
+ )
738
+
739
+ form_start_btn.click(
740
+ on_click_start,
741
+ inputs=[global_state, form_image],
742
+ outputs=[
743
+ global_state,
744
+ form_steps_number,
745
+ form_image,
746
+ # form_download_result_file,
747
+ # >>> buttons
748
+ form_reset_image,
749
+ enable_add_points,
750
+ enable_add_mask,
751
+ undo_points,
752
+ form_reset_mask_btn,
753
+ form_latent_space,
754
+ form_start_btn,
755
+ form_stop_btn,
756
+ # <<< buttonm
757
+ # >>> inputs comps
758
+ form_pretrained_dropdown,
759
+ form_seed_number,
760
+ form_lr_number,
761
+ show_mask,
762
+ form_lambda_number,
763
+ ],
764
+ )
765
+
766
+ def on_click_stop(global_state):
767
+ """Function to handle stop button is clicked.
768
+ 1. send a stop signal by set global_state["temporal_params"]["stop"] as True
769
+ 2. Disable Stop button
770
+ """
771
+ global_state["temporal_params"]["stop"] = True
772
+
773
+ return global_state, gr.Button.update(interactive=False)
774
+
775
+ form_stop_btn.click(on_click_stop,
776
+ inputs=[global_state],
777
+ outputs=[global_state, form_stop_btn],
778
+ queue=False)
779
+
780
+ form_draw_interval_number.change(
781
+ partial(
782
+ on_change_single_global_state,
783
+ "draw_interval",
784
+ map_transform=lambda x: int(x),
785
+ ),
786
+ inputs=[form_draw_interval_number, global_state],
787
+ outputs=[global_state],
788
+ queue=False,
789
+ )
790
+
791
+ def on_click_remove_point(global_state):
792
+ choice = global_state["curr_point"]
793
+ del global_state["points"][choice]
794
+
795
+ choices = list(global_state["points"].keys())
796
+
797
+ if len(choices) > 0:
798
+ global_state["curr_point"] = choices[0]
799
+
800
+ return (
801
+ gr.Dropdown.update(choices=choices, value=choices[0]),
802
+ global_state,
803
+ )
804
+
805
+ # Mask
806
+ def on_click_reset_mask(global_state):
807
+ global_state['mask'] = np.ones(
808
+ (
809
+ global_state["images"]["image_raw"].size[1],
810
+ global_state["images"]["image_raw"].size[0],
811
+ ),
812
+ dtype=np.uint8,
813
+ )
814
+ image_draw = update_image_draw(global_state['images']['image_raw'],
815
+ global_state['points'],
816
+ global_state['mask'],
817
+ global_state['show_mask'], global_state)
818
+ return global_state, image_draw
819
+
820
+ form_reset_mask_btn.click(
821
+ on_click_reset_mask,
822
+ inputs=[global_state],
823
+ outputs=[global_state, form_image],
824
+ )
825
+
826
+ # Image
827
+ def on_click_enable_draw(global_state, image):
828
+ """Function to start add mask mode.
829
+ 1. Preprocess mask info from last state
830
+ 2. Change editing state to add_mask
831
+ 3. Set curr image with points and mask
832
+ """
833
+ global_state = preprocess_mask_info(global_state, image)
834
+ global_state['editing_state'] = 'add_mask'
835
+ image_raw = global_state['images']['image_raw']
836
+ image_draw = update_image_draw(image_raw, global_state['points'],
837
+ global_state['mask'], True,
838
+ global_state)
839
+ return (global_state,
840
+ gr.Image.update(value=image_draw, interactive=True))
841
+
842
+ def on_click_remove_draw(global_state, image):
843
+ """Function to start remove mask mode.
844
+ 1. Preprocess mask info from last state
845
+ 2. Change editing state to remove_mask
846
+ 3. Set curr image with points and mask
847
+ """
848
+ global_state = preprocess_mask_info(global_state, image)
849
+ global_state['edinting_state'] = 'remove_mask'
850
+ image_raw = global_state['images']['image_raw']
851
+ image_draw = update_image_draw(image_raw, global_state['points'],
852
+ global_state['mask'], True,
853
+ global_state)
854
+ return (global_state,
855
+ gr.Image.update(value=image_draw, interactive=True))
856
+
857
+ enable_add_mask.click(on_click_enable_draw,
858
+ inputs=[global_state, form_image],
859
+ outputs=[
860
+ global_state,
861
+ form_image,
862
+ ],
863
+ queue=False)
864
+
865
+ def on_click_add_point(global_state, image: dict):
866
+ """Function switch from add mask mode to add points mode.
867
+ 1. Updaste mask buffer if need
868
+ 2. Change global_state['editing_state'] to 'add_points'
869
+ 3. Set current image with mask
870
+ """
871
+
872
+ global_state = preprocess_mask_info(global_state, image)
873
+ global_state['editing_state'] = 'add_points'
874
+ mask = global_state['mask']
875
+ image_raw = global_state['images']['image_raw']
876
+ image_draw = update_image_draw(image_raw, global_state['points'], mask,
877
+ global_state['show_mask'], global_state)
878
+
879
+ return (global_state,
880
+ gr.Image.update(value=image_draw, interactive=False))
881
+
882
+ enable_add_points.click(on_click_add_point,
883
+ inputs=[global_state, form_image],
884
+ outputs=[global_state, form_image],
885
+ queue=False)
886
+
887
+ def on_click_image(global_state, evt: gr.SelectData):
888
+ """This function only support click for point selection
889
+ """
890
+ xy = evt.index
891
+ if global_state['editing_state'] != 'add_points':
892
+ print(f'In {global_state["editing_state"]} state. '
893
+ 'Do not add points.')
894
+
895
+ return global_state, global_state['images']['image_show']
896
+
897
+ points = global_state["points"]
898
+
899
+ point_idx = get_latest_points_pair(points)
900
+ if point_idx is None:
901
+ points[0] = {'start': xy, 'target': None}
902
+ print(f'Click Image - Start - {xy}')
903
+ elif points[point_idx].get('target', None) is None:
904
+ points[point_idx]['target'] = xy
905
+ print(f'Click Image - Target - {xy}')
906
+ else:
907
+ points[point_idx + 1] = {'start': xy, 'target': None}
908
+ print(f'Click Image - Start - {xy}')
909
+
910
+ image_raw = global_state['images']['image_raw']
911
+ image_draw = update_image_draw(
912
+ image_raw,
913
+ global_state['points'],
914
+ global_state['mask'],
915
+ global_state['show_mask'],
916
+ global_state,
917
+ )
918
+
919
+ return global_state, image_draw
920
+
921
+ form_image.select(
922
+ on_click_image,
923
+ inputs=[global_state],
924
+ outputs=[global_state, form_image],
925
+ queue=False,
926
+ )
927
+
928
+ def on_click_clear_points(global_state):
929
+ """Function to handle clear all control points
930
+ 1. clear global_state['points'] (clear_state)
931
+ 2. re-init network
932
+ 2. re-draw image
933
+ """
934
+ clear_state(global_state, target='point')
935
+
936
+ renderer: Renderer = global_state["renderer"]
937
+ renderer.feat_refs = None
938
+
939
+ image_raw = global_state['images']['image_raw']
940
+ image_draw = update_image_draw(image_raw, {}, global_state['mask'],
941
+ global_state['show_mask'], global_state)
942
+ return global_state, image_draw
943
+
944
+ undo_points.click(on_click_clear_points,
945
+ inputs=[global_state],
946
+ outputs=[global_state, form_image],
947
+ queue=False)
948
+
949
+ def on_click_show_mask(global_state, show_mask):
950
+ """Function to control whether show mask on image."""
951
+ global_state['show_mask'] = show_mask
952
+
953
+ image_raw = global_state['images']['image_raw']
954
+ image_draw = update_image_draw(
955
+ image_raw,
956
+ global_state['points'],
957
+ global_state['mask'],
958
+ global_state['show_mask'],
959
+ global_state,
960
+ )
961
+ return global_state, image_draw
962
+
963
+ show_mask.change(
964
+ on_click_show_mask,
965
+ inputs=[global_state, show_mask],
966
+ outputs=[global_state, form_image],
967
+ queue=False,
968
+ )
969
+
970
+ #print("SHAReD: Start app", parser.parse_args())
971
+ gr.close_all()
972
+ app.queue(concurrency_count=1, max_size=200, api_open=False)
973
+ ###app.launch(share=args.share, show_api=False)
974
+ #app.launch()
975
+ app.launch(share = True)