xingxm commited on
Commit
55f4606
·
1 Parent(s): 2f6bc1f

style(method): rename variables

Browse files
svgdreamer.py CHANGED
@@ -20,8 +20,8 @@ from svgdreamer.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline
20
  @hydra.main(version_base=None, config_path="conf", config_name='config')
21
  def main(cfg: omegaconf.DictConfig):
22
  """
23
- The project configuration is stored in './conf/config.yaml
24
- And style configurations are stored in './conf/x/iconographic.yaml
25
  """
26
 
27
  # set seed
 
20
  @hydra.main(version_base=None, config_path="conf", config_name='config')
21
  def main(cfg: omegaconf.DictConfig):
22
  """
23
+ The project configuration is stored in './conf/config.yaml'
24
+ And style configurations are stored in './conf/x/iconographic.yaml'
25
  """
26
 
27
  # set seed
svgdreamer/painter/VPSD_pipeline.py CHANGED
@@ -3,12 +3,9 @@
3
  # Author: XiMing Xing
4
  # Description:
5
  import re
6
- import PIL
7
- from PIL import Image
8
  from typing import Any, List, Optional, Union, Dict
9
  from omegaconf import DictConfig
10
 
11
- import numpy as np
12
  import torch
13
  import torch.nn.functional as F
14
  from torchvision import transforms
 
3
  # Author: XiMing Xing
4
  # Description:
5
  import re
 
 
6
  from typing import Any, List, Optional, Union, Dict
7
  from omegaconf import DictConfig
8
 
 
9
  import torch
10
  import torch.nn.functional as F
11
  from torchvision import transforms
svgdreamer/pipelines/SVGDreamer_pipeline.py CHANGED
@@ -439,14 +439,14 @@ class SVGDreamerPipeline(ModelState):
439
 
440
  # for convenience
441
  guidance_cfg = self.x_cfg.vpsd
442
- vpsd_model_cfg = self.x_cfg.vpsd_model_cfg
443
  n_particle = guidance_cfg.n_particle
444
  total_step = guidance_cfg.num_iter
445
  path_reinit = self.x_cfg.path_reinit
446
 
447
  # init VPSD
448
- pipeline = VectorizedParticleSDSPipeline(vpsd_model_cfg, self.args.diffuser, guidance_cfg,
449
- self.device, self.args.state.mprec)
450
  # init reward model
451
  reward_model = None
452
  if guidance_cfg.phi_ReFL:
@@ -529,7 +529,7 @@ class SVGDreamerPipeline(ModelState):
529
  negative_prompt=self.args.neg_prompt,
530
  grad_scale=guidance_cfg.grad_scale,
531
  enhance_particle=guidance_cfg.particle_aug,
532
- im_size=model2res(vpsd_model_cfg.model_id)
533
  )
534
 
535
  # Xing Loss for Self-Interaction Problem
 
439
 
440
  # for convenience
441
  guidance_cfg = self.x_cfg.vpsd
442
+ sd_model_cfg = self.x_cfg.vpsd_model_cfg
443
  n_particle = guidance_cfg.n_particle
444
  total_step = guidance_cfg.num_iter
445
  path_reinit = self.x_cfg.path_reinit
446
 
447
  # init VPSD
448
+ pipeline = VectorizedParticleSDSPipeline(
449
+ sd_model_cfg, self.args.diffuser, guidance_cfg, self.device, self.weight_dtype)
450
  # init reward model
451
  reward_model = None
452
  if guidance_cfg.phi_ReFL:
 
529
  negative_prompt=self.args.neg_prompt,
530
  grad_scale=guidance_cfg.grad_scale,
531
  enhance_particle=guidance_cfg.particle_aug,
532
+ im_size=model2res(sd_model_cfg.model_id)
533
  )
534
 
535
  # Xing Loss for Self-Interaction Problem