xingxm committed · Commit acdfbd8 · 1 Parent(s): fc0e1a8

fix(method): handle the case where attn_map is None

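In brief: when the cross-attention map yields no usable init points, `component_rendering` now returns `0`, and `SIVE_stage` moves on to the next generated sample, capped at `n_particle + 10` attempts. A minimal sketch of that control flow (illustrative names, not verbatim repo code):

```python
# Sketch of the recovery flow this commit adds: the render callback signals an
# attention-map failure by returning 0, and the loop retries with the next
# sample index instead of crashing, up to n_particle + 10 attempts.
def sive_stage_loop(n_particle, render):
    successful, cur_idx, results = 0, 0, []
    while successful < n_particle:
        if cur_idx >= n_particle + 10:  # max attempts, as in the new loop
            break
        out = render(cur_idx)  # hypothetical callback; returns 0 on failure
        cur_idx += 1
        if out == 0:
            continue  # attention map yielded no init points; try next sample
        successful += 1
        results.append(out)
    return results

# e.g. render = lambda i: 0 if i % 2 else f"particle_{i}.svg"
```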
Examples.md CHANGED
@@ -160,4 +160,10 @@ expressive eyes. <br/>
 
 ````shell
 python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.num_paths=256 result_path='./logs/VanGogh-Portrait'
-````
+````
+
+### Case: planet Saturn
+
+```shell
+python svgdreamer.py x=iconography-s1 skip_sive=False "prompt='An icon of the planet Saturn. minimal flat 2D vector icon. plain color background. trending on ArtStation.'" token_ind=6 x.sive.bg.num_iter=50 x.sive.fg.num_iter=50 x.vpsd.t_schedule='randint' result_path='./logs/Saturn' multirun=True state.mprec='fp16'
+```
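The new `state.mprec='fp16'` override requests half precision; since the pipeline is built on `ModelState`/accelerate (note `self.accelerator` in the pipeline diff below), it plausibly maps to something like the following. The exact wiring is an assumption, not confirmed by this diff:

```python
# Hedged sketch: what state.mprec='fp16' likely toggles under the hood.
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # vs. "no" for full precision
```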
README.md CHANGED
@@ -80,7 +80,7 @@ realistic <br/>
 **Script:**
 
 ```shell
-python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.sive.bg.num_iter=10 x.sive.fg.num_iter=10 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True
+python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True
 ```
 
 🔹Parameter:
conf/x/iconography.yaml CHANGED
@@ -41,7 +41,7 @@ sive:
   mask_tau: 0.3 # the threshold used to convert the attention map into a mask
   bg:
     style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
-    num_iter: 10
+    num_iter: 50
     num_paths: 256
     path_schedule: 'repeat' # 'repeat', 'list'
     schedule_each: 128
@@ -61,7 +61,7 @@ sive:
   xing_loss_weight: 0.001
   fg:
     style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
-    num_iter: 10
+    num_iter: 50
     num_paths: 256 # number of strokes
     path_schedule: 'repeat' # 'repeat', 'list'
     schedule_each: 128
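The bumped `num_iter` defaults (10 to 50) are why the README command above no longer passes `x.sive.bg.num_iter` explicitly. They can still be overridden per run; a minimal sketch of how the dotted CLI overrides used in the README and Examples commands merge over these YAML defaults, assuming the Hydra/OmegaConf-style config that the `conf/` layout suggests:

```python
from omegaconf import OmegaConf

# YAML defaults, abbreviated from conf/x/iconography.yaml after this commit
cfg = OmegaConf.create({"x": {"sive": {"bg": {"num_iter": 50}, "fg": {"num_iter": 50}}}})
# Dotted CLI overrides, as in "x.sive.bg.num_iter=10 x.sive.fg.num_iter=10"
cli = OmegaConf.from_dotlist(["x.sive.bg.num_iter=10", "x.sive.fg.num_iter=10"])
cfg = OmegaConf.merge(cfg, cli)
print(cfg.x.sive.bg.num_iter)  # 10: the CLI value wins over the YAML default
```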
svgdreamer/painter/__init__.py CHANGED
@@ -2,8 +2,8 @@
 # Copyright (c) XiMing Xing. All rights reserved.
 # Description:
 
-from .painter_params import (
-    Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, SparseCoordInit, get_sdf)
+from .painter_params import Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, \
+    SparseCoordInit, get_sdf
 from .component_painter_params import CompPainter, CompPainterOptimizer
 from .loss import xing_loss_fn
 from .VPSD_pipeline import VectorizedParticleSDSPipeline
svgdreamer/pipelines/SVGDreamer_pipeline.py CHANGED
@@ -20,8 +20,8 @@ from torchvision import transforms
 from skimage.color import rgb2gray
 
 from svgdreamer.libs import ModelState, get_optimizer
-from svgdreamer.painter import (CompPainter, CompPainterOptimizer, xing_loss_fn, Painter, PainterOptimizer,
-                                CosineWithWarmupLRLambda, VectorizedParticleSDSPipeline, DiffusionPipeline)
+from svgdreamer.painter import CompPainter, CompPainterOptimizer, xing_loss_fn, Painter, PainterOptimizer, \
+    CosineWithWarmupLRLambda, VectorizedParticleSDSPipeline, DiffusionPipeline
 from svgdreamer.token2attn.attn_control import EmptyControl, AttentionStore
 from svgdreamer.token2attn.ptp_utils import view_images
 from svgdreamer.utils.plot import plot_img, plot_couple, plot_attn, save_image
@@ -38,8 +38,10 @@ class SVGDreamerPipeline(ModelState):
         # assert
         assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
         args.skip_sive = True if args.x.style in ["pixelart", "low-poly"] else args.skip_sive
-        assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
-        assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle
+        # assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
+        if args.x.vpsd.vsd_n_particle > args.x.vpsd.n_particle: args.x.vpsd.vsd_n_particle = args.x.vpsd.n_particle
+        # assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle
+        if args.x.vpsd.phi_n_particle > args.x.vpsd.n_particle: args.x.vpsd.phi_n_particle = args.x.vpsd.n_particle
         assert args.x.vpsd.n_phi_sample >= 1
 
         logdir_ = f"sd{args.seed}" \
@@ -123,15 +125,26 @@ class SVGDreamerPipeline(ModelState):
         self.close(msg="painterly rendering complete.")
 
     def SIVE_stage(self, text_prompt: str):
-        # init diffusion model
+        # Init diffusion model
         pipeline = DiffusionPipeline(self.x_cfg.sive_model_cfg, self.args.diffuser, self.device)
 
         merged_svg_paths = []
         merged_images = []
-        for i in range(self.vpsd_cfg.n_particle):
-            select_sample_path = self.result_path / f'select_sample_{i}.png'
-            # generate sample and attention map
-            fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(i,
+
+        successful_particles = 0
+        cur_idx = 0
+
+        while successful_particles < self.vpsd_cfg.n_particle:
+            if cur_idx >= self.vpsd_cfg.n_particle + 10:  # max attempts
+                self.print(f"Reached maximum attempts ({cur_idx}). "
+                           f"Only processed {successful_particles} particles successfully.")
+                break
+
+            self.print(f"Processing particle {cur_idx} "
+                       f"(successful so far: {successful_particles}/{self.vpsd_cfg.n_particle})")
+            select_sample_path = self.result_path / f'select_sample_{cur_idx}.png'
+            # Generate sample and attention map
+            fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(cur_idx,
                                                                          self.x_cfg.sive_model_cfg,
                                                                          pipeline,
                                                                          text_prompt,
@@ -139,18 +152,18 @@
                                                                          self.sive_cfg.attn_cfg,
                                                                          self.im_size,
                                                                          self.args.token_ind)
-            # load selected file
+            # Load selected file
             select_img = self.target_file_preprocess(select_sample_path.as_posix())
             self.print(f"load target file from: {select_sample_path.as_posix()}")
 
-            # get objects by attention map
-            fg_img, bg_img, fg_mask, bg_mask = self.extract_object(i, select_img, fg_attn_map, bg_attn_map,
+            # Get objects by attention map
+            fg_img, bg_img, fg_mask, bg_mask = self.extract_object(cur_idx, select_img, fg_attn_map, bg_attn_map,
                                                                    tau=self.sive_cfg.mask_tau)
-            self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
+            # self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
 
-            # background rendering
-            self.print(f"-> background rendering: ")
-            bg_render_path = self.component_rendering(tag=f'{i}_bg',
+            # Background rendering
+            self.print(f"-> Background rendering: ")
+            bg_render_path = self.component_rendering(tag=f'{cur_idx}_bg',
                                                       prompt=text_prompt,
                                                       target_img=bg_img,
                                                       mask=bg_mask,
@@ -160,9 +173,14 @@
                                                       optim_cfg=self.sive_optim,
                                                       log_png_dir=self.bg_png_logs_dir,
                                                       log_svg_dir=self.bg_svg_logs_dir)
-            # foreground rendering
-            self.print(f"-> foreground rendering: ")
-            fg_render_path = self.component_rendering(tag=f'{i}_fg',
+            if bg_render_path == 0:
+                self.print(f"Background rendering failed for particle {cur_idx}, trying next particle")
+                cur_idx += 1
+                continue
+
+            # Foreground rendering
+            self.print(f"-> Foreground rendering: ")
+            fg_render_path = self.component_rendering(tag=f'{cur_idx}_fg',
                                                       prompt=text_prompt,
                                                       target_img=fg_img,
                                                       mask=fg_mask,
@@ -172,8 +190,16 @@
                                                       optim_cfg=self.sive_optim,
                                                       log_png_dir=self.fg_png_logs_dir,
                                                       log_svg_dir=self.fg_svg_logs_dir)
-            # merge foreground and background
-            merged_svg_path = self.result_path / f'SIVE_render_final_{i}.svg'
+            if fg_render_path == 0:
+                self.print(f"Foreground rendering failed for particle {cur_idx}, trying next particle")
+                cur_idx += 1
+                continue
+
+            successful_particles += 1
+            cur_idx += 1
+
+            # Merge foreground and background
+            merged_svg_path = self.result_path / f'SIVE_render_final_{cur_idx}.svg'
             merge_svg_files(
                 svg_path_1=bg_render_path,
                 svg_path_2=fg_render_path,
@@ -182,11 +208,11 @@
                 out_size=(self.im_size, self.im_size)
             )
 
-            # foreground and background refinement
+            # Foreground and background refinement
             # Note: you are not allowed to add further paths here
             if self.sive_cfg.tog.reinit:
-                self.print("-> enable vector graphic refinement:")
-                merged_svg_path = self.refine_rendering(tag=f'{i}_refine',
+                self.print("-> Enable vector graphic refinement:")
+                merged_svg_path = self.refine_rendering(tag=f'{cur_idx}_refine',
                                                         prompt=text_prompt,
                                                         target_img=select_img,
                                                         canvas_size=(self.im_size, self.im_size),
@@ -194,22 +220,21 @@
                                                         optim_cfg=self.sive_optim,
                                                         init_svg_path=merged_svg_path)
 
-            # svg-to-png, to tensor
-            merged_png_path = self.result_path / f'SIVE_render_final_{i}.png'
+            # Postprocess: svg-to-png & to tensor
+            merged_png_path = self.result_path / f'SIVE_render_final_{cur_idx}.png'
             cairosvg.svg2png(url=merged_svg_path.as_posix(), write_to=merged_png_path.as_posix())
-
-            # collect paths
-            merged_svg_paths.append(merged_svg_path)
+            merged_svg_paths.append(merged_svg_path)  # collect paths
             merged_images.append(self.target_file_preprocess(merged_png_path))
-            # empty attention record
+
+            # Clear attention recorder
             controller.reset()
 
-            self.print(f"Vector Particle {i} Rendering End...\n")
+            self.print(f"Vector Particle {cur_idx} Rendering End...\n")
 
-        # free the VRAM
+        # Free the VRAM
         del pipeline
         torch.cuda.empty_cache()
-        # update paths
+        # Update paths
         self.x_cfg.num_paths = self.sive_cfg.bg.num_paths + self.sive_cfg.fg.num_paths
 
         return merged_svg_paths, merged_images
@@ -257,6 +282,9 @@
         if attention_map is not None:
             # init fist control points by attention_map
             attn_thresh, select_inds = renderer.attn_init_points(num_paths=sum(path_schedule), mask=mask)
+            # Warning: attention map failure
+            if len(select_inds) == 0: return 0
+
             # log attention, just once
             plot_attn(attention_map, attn_thresh, target_img, select_inds,
                       (self.sive_attn_dir / f"attention_{tag}_map.jpg").as_posix())
@@ -381,14 +409,16 @@
         plot_img(img, self.refine_dir, fname=f"{tag}_before_refined")
 
         n_iter = render_cfg.num_iter
+        self.print(f"Total iters: {n_iter}")
+
         # build painter optimizer
         optimizer = CompPainterOptimizer(content_renderer, self.style, n_iter, optim_cfg)
         # init optimizer
         optimizer.init_optimizers()
 
-        print(f"=> n_point: {len(content_renderer.get_point_params())}, "
-              f"n_width: {len(content_renderer.get_width_params())}, "
-              f"n_color: {len(content_renderer.get_color_params())}")
+        self.print(f"=> n_point: {len(content_renderer.get_point_params())}, "
+                   f"n_width: {len(content_renderer.get_width_params())}, "
+                   f"n_color: {len(content_renderer.get_color_params())}")
 
         step = 0
         with tqdm(initial=step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
@@ -434,7 +464,8 @@
                    text_prompt: AnyStr,
                    init_svg_path: Union[List[AnyPath], AnyPath] = None,
                    init_image: Union[List[torch.Tensor], torch.Tensor] = None):
-        if not self.vpsd_cfg.use:
+        # print(f"self.vpsd_cfg.use: {self.vpsd_cfg.use}")
+        if self.vpsd_cfg.use is False:
             return
 
         # for convenience
@@ -784,10 +815,12 @@
                                generator=self.g_device)
         outputs_np = [np.array(img) for img in outputs.images]
         view_images(outputs_np, save_image=True, fp=gen_sample_path)
-        self.print(f"select_sample shape: {outputs_np[0].shape}")
+        # self.print(f"select_sample shape: {outputs_np[0].shape}")
 
         if attn_init:
-            """ldm cross-attention map"""
+            self.print(f"\nLDM attn-map logging:")
+
+            # Cross-attention map
             cross_attention_maps, tokens = \
                 pipeline.get_cross_attention([prompts],
                                              controller,
@@ -862,7 +895,7 @@
             view_images(reversed_attn_map_vis, save_image=True,
                         fp=self.sive_attn_dir / f'reversed-fusion-attn-{iter}.png')
 
-            self.print(f"-> fusion attn_map: {attn_map.shape}")
+            self.print(f"-> fusion attn_map: {attn_map.shape} \n")
         else:
             attn_map = None
             inverse_attn = None
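The hard asserts on particle counts in `__init__` become soft clamps; the two new `if` guards are equivalent to this sketch:

```python
def clamp_particle_counts(vsd_n_particle: int, phi_n_particle: int, n_particle: int):
    """Equivalent of the new guards in SVGDreamerPipeline.__init__: the counts
    used for VSD and the phi model can never exceed the total particle count."""
    return min(vsd_n_particle, n_particle), min(phi_n_particle, n_particle)

print(clamp_particle_counts(6, 4, 4))  # (4, 4): both clamped instead of asserting
```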
 
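The core fix lives in `component_rendering`: an attention map that clears the threshold nowhere produces zero init points, and the method now returns `0` instead of failing downstream. A hypothetical condensed version of that guard (the internals of `attn_init_points` are elided and approximated here):

```python
import numpy as np

def attn_init_points_or_fail(attn_map, tau=0.3):
    # Condensed sketch of the guarded call in component_rendering: return 0
    # when there is no attention map, or when thresholding selects no points,
    # so the caller (SIVE_stage) can skip this particle and retry.
    if attn_map is None:
        return 0
    select_inds = np.argwhere(attn_map > tau)
    if len(select_inds) == 0:  # Warning: attention map failure
        return 0
    return select_inds
```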
svgdreamer/token2attn/attn_control.py CHANGED
@@ -85,7 +85,7 @@ class AttentionStore(AttentionControl):
         self.step_store = self.get_empty_store()
 
     def get_average_attention(self):
-        print(f"step count: {self.cur_step}")
+        # print(f"step count: {self.cur_step}")
         average_attention = {
             key: [item / self.cur_step for item in self.attention_store[key]]
             for key in self.attention_store
  for key in self.attention_store