feat(build): add install script and docker image
Browse files- README.md +69 -11
- assets/Painting-Elephant/init_p0.svg +0 -0
- assets/Painting-Elephant/init_p1.svg +0 -0
- assets/Painting-Elephant/init_p2.svg +0 -0
- assets/Painting-Elephant/init_p3.svg +0 -0
- assets/Painting-Elephant/init_p4.svg +0 -0
- assets/Painting-Elephant/init_p5.svg +0 -0
- assets/Painting-Elephant/p_0.svg +0 -0
- assets/Painting-Elephant/p_1.svg +0 -0
- assets/Painting-Elephant/p_2.svg +0 -0
- assets/Painting-Elephant/p_3.svg +0 -0
- assets/Painting-Elephant/p_4.svg +0 -0
- assets/Painting-Elephant/p_5.svg +0 -0
- assets/Pixelart-DarthVader/init_p0.svg +0 -0
- assets/Pixelart-DarthVader/init_p1.svg +0 -0
- assets/Pixelart-DarthVader/init_p2.svg +0 -0
- assets/Pixelart-DarthVader/init_p3.svg +0 -0
- assets/Pixelart-DarthVader/init_p4.svg +0 -0
- assets/Pixelart-DarthVader/init_p5.svg +0 -0
- assets/Pixelart-DarthVader/p0.svg +0 -0
- assets/Pixelart-DarthVader/p1.svg +0 -0
- assets/Pixelart-DarthVader/p2.svg +0 -0
- assets/Pixelart-DarthVader/p3.svg +0 -0
- assets/Pixelart-DarthVader/p4.svg +0 -0
- assets/Pixelart-DarthVader/p5.svg +0 -0
- assets/SIVE-astronaut-1/attn.png +0 -0
- assets/SIVE-astronaut-1/final_bg.svg +0 -0
- assets/SIVE-astronaut-1/final_fg.svg +0 -0
- assets/SIVE-astronaut-1/init_bg.svg +134 -0
- assets/SIVE-astronaut-1/init_fg.svg +134 -0
- assets/SIVE-astronaut-1/result.svg +0 -0
- conf/x/{iconography_s1.yaml → iconography-s1.yaml} +2 -1
- conf/x/iconography.yaml +1 -6
- conf/x/ink.yaml +1 -6
- conf/x/lowpoly.yaml +1 -6
- conf/x/painting.yaml +1 -6
- conf/x/pixelart.yaml +2 -7
- conf/x/sketch.yaml +1 -6
- script/install.sh +47 -0
- svgdreamer/painter/painter_params.py +9 -8
- svgdreamer/pipelines/SVGDreamer_pipeline.py +12 -13
README.md
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
[](https://arxiv.org/abs/2312.16476)
|
4 |
[](https://arxiv.org/abs/2312.16476)
|
5 |
[](https://ximinng.github.io/SVGDreamer-project/)
|
6 |
-
[](https://huggingface.co/blog/xingxm/svgdreamer)
|
8 |
|
9 |
This repository contains our official implementation of the CVPR 2024 paper: SVGDreamer: Text-Guided SVG Generation with
|
@@ -20,12 +20,37 @@ Diffusion Model. It can generate high-quality SVGs based on text prompts.
|
|
20 |
a novel text-guided vector graphics synthesis method. This method considers both the editing of vector graphics and
|
21 |
the quality of the synthesis.
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
## 🔥 Quickstart
|
24 |
|
25 |
Before running the code, download the stable diffusion model. Append `diffuser.download=True` to the end of the script.
|
26 |
|
27 |
### SIVE + VPSD
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
**Script:**
|
30 |
|
31 |
```shell
|
@@ -52,18 +77,17 @@ python svgdreamer.py x=iconography_s1 skip_sive=False "prompt='a man in an astro
|
|
52 |
|
53 |
### VPSD
|
54 |
|
55 |
-
####
|
56 |
|
57 |
**Prompt:** Sydney opera house. oil painting. by Van Gogh <br/>
|
58 |
-
**Style:** iconography <br/>
|
59 |
**Preview:**
|
60 |
|
61 |
| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
|
62 |
|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|
|
63 |
-
| init p1
|
64 |
| <img src="./assets/Icon-SydneyOperaHouse/init_p0.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p1.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p2.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p3.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p4.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p5.svg"> |
|
65 |
| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
|
66 |
-
| <img src="./assets/Icon-SydneyOperaHouse/p_0.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_1.svg">
|
67 |
|
68 |
**Script:**
|
69 |
|
@@ -71,19 +95,53 @@ python svgdreamer.py x=iconography_s1 skip_sive=False "prompt='a man in an astro
|
|
71 |
python svgdreamer.py x=iconography "prompt='Sydney opera house. oil painting. by Van Gogh'" result_path='./logs/SydneyOperaHouse-OilPainting'
|
72 |
```
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
```shell
|
77 |
# Style: low-ploy
|
78 |
python svgdreamer.py x=lowpoly "prompt='A picture of a bald eagle. low-ploy. polygon'" result_path='./logs/BaldEagle'
|
79 |
-
# Style: pixel-art
|
80 |
-
python svgdreamer.py x=pixelart "prompt='Darth vader with lightsaber.'" result_path='./log/DarthVader'
|
81 |
-
# Style: painting
|
82 |
-
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
|
83 |
# Style: sketch
|
84 |
python svgdreamer.py x=sketch "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" result_path='./logs/Lamborghini'
|
85 |
# Style: ink and wash
|
86 |
python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" result_path='./logs/BigWildGoosePagoda'
|
|
|
|
|
87 |
```
|
88 |
|
89 |
## 🔑 Tips
|
@@ -94,7 +152,7 @@ python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist
|
|
94 |
## 📋 TODO
|
95 |
|
96 |
- [x] Release the code
|
97 |
-
- [
|
98 |
|
99 |
## :books: Acknowledgement
|
100 |
|
|
|
3 |
[](https://arxiv.org/abs/2312.16476)
|
4 |
[](https://arxiv.org/abs/2312.16476)
|
5 |
[](https://ximinng.github.io/SVGDreamer-project/)
|
6 |
+
[](https://huggingface.co/blog/xingxm/svgdreamer)
|
7 |
[](https://huggingface.co/blog/xingxm/svgdreamer)
|
8 |
|
9 |
This repository contains our official implementation of the CVPR 2024 paper: SVGDreamer: Text-Guided SVG Generation with
|
|
|
20 |
a novel text-guided vector graphics synthesis method. This method considers both the editing of vector graphics and
|
21 |
the quality of the synthesis.
|
22 |
|
23 |
+
## Installation
|
24 |
+
|
25 |
+
You can follow the steps below to quickly get up and running with SVGDreamer.
|
26 |
+
These steps will let you run quick inference locally.
|
27 |
+
|
28 |
+
In the top level directory run,
|
29 |
+
|
30 |
+
```bash
|
31 |
+
sh script/install.sh
|
32 |
+
```
|
33 |
+
|
34 |
+
or using docker images,
|
35 |
+
|
36 |
+
```shell
|
37 |
+
docker run --name svgdreamer --gpus all -it --ipc=host ximingxing/svgrender:v1 /bin/bash
|
38 |
+
```
|
39 |
+
|
40 |
## 🔥 Quickstart
|
41 |
|
42 |
Before running the code, download the stable diffusion model. Append `diffuser.download=True` to the end of the script.
|
43 |
|
44 |
### SIVE + VPSD
|
45 |
|
46 |
+
**Prompt:** An image of Batman. full body action pose, complete detailed body. white background. empty background, high
|
47 |
+
quality, 4K, ultra realistic <br/>
|
48 |
+
**Preview:**
|
49 |
+
|
50 |
+
| attn-map | bg init | fg init | bg final | fg final | final |
|
51 |
+
|------------------------------------------------|---------------------------------------------------|---------------------------------------------------|----------------------------------------------------|----------------------------------------------------|--------------------------------------------------|
|
52 |
+
| <img src="./assets/SIVE-astronaut-1/attn.png"> | <img src="./assets/SIVE-astronaut-1/init_bg.svg"> | <img src="./assets/SIVE-astronaut-1/init_fg.svg"> | <img src="./assets/SIVE-astronaut-1/final_bg.svg"> | <img src="./assets/SIVE-astronaut-1/final_fg.svg"> | <img src="./assets/SIVE-astronaut-1/result.svg"> |
|
53 |
+
|
54 |
**Script:**
|
55 |
|
56 |
```shell
|
|
|
77 |
|
78 |
### VPSD
|
79 |
|
80 |
+
#### Iconography style
|
81 |
|
82 |
**Prompt:** Sydney opera house. oil painting. by Van Gogh <br/>
|
|
|
83 |
**Preview:**
|
84 |
|
85 |
| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
|
86 |
|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|
|
87 |
+
| randomly init p1 | randomly init p2 | randomly init p3 | randomly init p4 | randomly init p5 | randomly init p6 |
|
88 |
| <img src="./assets/Icon-SydneyOperaHouse/init_p0.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p1.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p2.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p3.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p4.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p5.svg"> |
|
89 |
| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
|
90 |
+
| <img src="./assets/Icon-SydneyOperaHouse/p_0.svg"> | <img src="./assets/Icon-SydneyOperaHouse/p_1.svg"> | <img src="./assets/Icon-SydneyOperaHouse/p_2.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_3.svg"> | <img src="./assets/Icon-SydneyOperaHouse/p_4.svg"> | <img src="./assets/Icon-SydneyOperaHouse/p_5.svg"> |
|
91 |
|
92 |
**Script:**
|
93 |
|
|
|
95 |
python svgdreamer.py x=iconography "prompt='Sydney opera house. oil painting. by Van Gogh'" result_path='./logs/SydneyOperaHouse-OilPainting'
|
96 |
```
|
97 |
|
98 |
+
#### Painting style
|
99 |
+
|
100 |
+
**Prompt:** Abstract Vincent van Gogh Oil Painting Elephant, featuring earthy tones of green and brown <br/>
|
101 |
+
**Preview:**
|
102 |
+
|
103 |
+
| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
|
104 |
+
|----------------------------------------------------|----------------------------------------------------|----------------------------------------------------|----------------------------------------------------|----------------------------------------------------|----------------------------------------------------|
|
105 |
+
| randomly init p1 | randomly init p2 | randomly init p3 | randomly init p4 | randomly init p5 | randomly init p6 |
|
106 |
+
| <img src="./assets/Painting-Elephant/init_p0.svg"> | <img src="./assets/Painting-Elephant/init_p1.svg"> | <img src="./assets/Painting-Elephant/init_p2.svg"> | <img src="./assets/Painting-Elephant/init_p3.svg"> | <img src="./assets/Painting-Elephant/init_p4.svg"> | <img src="./assets/Painting-Elephant/init_p5.svg"> |
|
107 |
+
| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
|
108 |
+
| <img src="./assets/Painting-Elephant/p_0.svg"> | <img src="./assets/Painting-Elephant/p_1.svg"> | <img src="./assets/Painting-Elephant/p_2.svg"> | <img src="./assets/Painting-Elephant/p_3.svg"> | <img src="./assets/Painting-Elephant/p_4.svg"> | <img src="./assets/Painting-Elephant/p_5.svg"> |
|
109 |
+
|
110 |
+
**Script:**
|
111 |
+
|
112 |
+
```shell
|
113 |
+
python svgdreamer.py x=painting "prompt='Abstract Vincent van Gogh Oil Painting Elephant, featuring earthy tones of green and brown.'" x.num_paths=500 result_path='./logs/Elephant-OilPainting'
|
114 |
+
```
|
115 |
+
|
116 |
+
#### Pixel-Art style
|
117 |
+
|
118 |
+
**Prompt:** Darth vader with lightsaber <br/>
|
119 |
+
**Preview:**
|
120 |
+
|
121 |
+
| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
|
122 |
+
|------------------------------------------------------|------------------------------------------------------|------------------------------------------------------|------------------------------------------------------|------------------------------------------------------|------------------------------------------------------|
|
123 |
+
| randomly init p1 | randomly init p2 | randomly init p3 | randomly init p4 | randomly init p5 | randomly init p6 |
|
124 |
+
| <img src="./assets/Pixelart-DarthVader/init_p0.svg"> | <img src="./assets/Pixelart-DarthVader/init_p1.svg"> | <img src="./assets/Pixelart-DarthVader/init_p2.svg"> | <img src="./assets/Pixelart-DarthVader/init_p3.svg"> | <img src="./assets/Pixelart-DarthVader/init_p4.svg"> | <img src="./assets/Pixelart-DarthVader/init_p5.svg"> |
|
125 |
+
| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
|
126 |
+
| <img src="./assets/Pixelart-DarthVader/p0.svg"> | <img src="./assets/Pixelart-DarthVader/p1.svg"> | <img src="./assets/Pixelart-DarthVader/p2.svg"> | <img src="./assets/Pixelart-DarthVader/p3.svg"> | <img src="./assets/Pixelart-DarthVader/p4.svg"> | <img src="./assets/Pixelart-DarthVader/p5.svg"> |
|
127 |
+
|
128 |
+
**Script:**
|
129 |
+
|
130 |
+
```shell
|
131 |
+
python svgdreamer.py x=pixelart "prompt='Darth vader with lightsaber.'" result_path='./logs/DarthVader'
|
132 |
+
```
|
133 |
+
|
134 |
+
#### Other Styles
|
135 |
|
136 |
```shell
|
137 |
# Style: low-ploy
|
138 |
python svgdreamer.py x=lowpoly "prompt='A picture of a bald eagle. low-ploy. polygon'" result_path='./logs/BaldEagle'
|
|
|
|
|
|
|
|
|
139 |
# Style: sketch
|
140 |
python svgdreamer.py x=sketch "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" result_path='./logs/Lamborghini'
|
141 |
# Style: ink and wash
|
142 |
python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" result_path='./logs/BigWildGoosePagoda'
|
143 |
+
# Style: painting
|
144 |
+
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
|
145 |
```
|
146 |
|
147 |
## 🔑 Tips
|
|
|
152 |
## 📋 TODO
|
153 |
|
154 |
- [x] Release the code
|
155 |
+
- [x] Add docker image
|
156 |
|
157 |
## :books: Acknowledgement
|
158 |
|
assets/Painting-Elephant/init_p0.svg
ADDED
|
assets/Painting-Elephant/init_p1.svg
ADDED
|
assets/Painting-Elephant/init_p2.svg
ADDED
|
assets/Painting-Elephant/init_p3.svg
ADDED
|
assets/Painting-Elephant/init_p4.svg
ADDED
|
assets/Painting-Elephant/init_p5.svg
ADDED
|
assets/Painting-Elephant/p_0.svg
ADDED
|
assets/Painting-Elephant/p_1.svg
ADDED
|
assets/Painting-Elephant/p_2.svg
ADDED
|
assets/Painting-Elephant/p_3.svg
ADDED
|
assets/Painting-Elephant/p_4.svg
ADDED
|
assets/Painting-Elephant/p_5.svg
ADDED
|
assets/Pixelart-DarthVader/init_p0.svg
ADDED
|
assets/Pixelart-DarthVader/init_p1.svg
ADDED
|
assets/Pixelart-DarthVader/init_p2.svg
ADDED
|
assets/Pixelart-DarthVader/init_p3.svg
ADDED
|
assets/Pixelart-DarthVader/init_p4.svg
ADDED
|
assets/Pixelart-DarthVader/init_p5.svg
ADDED
|
assets/Pixelart-DarthVader/p0.svg
ADDED
|
assets/Pixelart-DarthVader/p1.svg
ADDED
|
assets/Pixelart-DarthVader/p2.svg
ADDED
|
assets/Pixelart-DarthVader/p3.svg
ADDED
|
assets/Pixelart-DarthVader/p4.svg
ADDED
|
assets/Pixelart-DarthVader/p5.svg
ADDED
|
assets/SIVE-astronaut-1/attn.png
ADDED
![]() |
assets/SIVE-astronaut-1/final_bg.svg
ADDED
|
assets/SIVE-astronaut-1/final_fg.svg
ADDED
|
assets/SIVE-astronaut-1/init_bg.svg
ADDED
|
assets/SIVE-astronaut-1/init_fg.svg
ADDED
|
assets/SIVE-astronaut-1/result.svg
ADDED
|
conf/x/{iconography_s1.yaml → iconography-s1.yaml}
RENAMED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 500
|
@@ -131,7 +132,7 @@ vpsd_model_cfg:
|
|
131 |
vpsd:
|
132 |
use: False
|
133 |
type: 'vpsd'
|
134 |
-
n_particle:
|
135 |
vsd_n_particle: 4 # the batch size of particles
|
136 |
particle_aug: False # do data enhancement for the input particles
|
137 |
num_iter: 1 # total iterations
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 500
|
|
|
132 |
vpsd:
|
133 |
use: False
|
134 |
type: 'vpsd'
|
135 |
+
n_particle: 6 # 4, 8, 16
|
136 |
vsd_n_particle: 4 # the batch size of particles
|
137 |
particle_aug: False # do data enhancement for the input particles
|
138 |
num_iter: 1 # total iterations
|
conf/x/iconography.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 512 # number of strokes
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 512 # number of strokes
|
conf/x/ink.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 128 # number of strokes
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 128 # number of strokes
|
conf/x/lowpoly.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 512 # number of strokes
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 512 # number of strokes
|
conf/x/painting.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 1500 # number of strokes
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 1500 # number of strokes
|
conf/x/pixelart.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 512 # number of strokes
|
@@ -110,7 +105,7 @@ vpsd_stage_optim:
|
|
110 |
width: 0.1
|
111 |
color: 0.01
|
112 |
bg: 0.01
|
113 |
-
lr_schedule:
|
114 |
optim:
|
115 |
name: 'adam'
|
116 |
betas: [ 0.9, 0.9 ]
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 512 # number of strokes
|
|
|
105 |
width: 0.1
|
106 |
color: 0.01
|
107 |
bg: 0.01
|
108 |
+
lr_schedule: False
|
109 |
optim:
|
110 |
name: 'adam'
|
111 |
betas: [ 0.9, 0.9 ]
|
conf/x/sketch.yaml
CHANGED
@@ -38,6 +38,7 @@ sive:
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
|
|
41 |
bg:
|
42 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
num_iter: 10
|
@@ -81,12 +82,6 @@ sive:
|
|
81 |
tog: # for refinement
|
82 |
reinit: False # if False, use fg params to init content
|
83 |
num_iter: 10
|
84 |
-
# optim
|
85 |
-
lr_schedule: False # enable lr_scheduler or not
|
86 |
-
# loss
|
87 |
-
bg_lam: 0
|
88 |
-
fg_lam: 1
|
89 |
-
xing_loss_weight: 0
|
90 |
|
91 |
# VPSD primitives
|
92 |
num_paths: 128 # number of strokes
|
|
|
38 |
mean_comp: False
|
39 |
comp_idx: 0
|
40 |
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
mask_tau: 0.3 # the threshold used to convert the attention map into a mask
|
42 |
bg:
|
43 |
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
44 |
num_iter: 10
|
|
|
82 |
tog: # for refinement
|
83 |
reinit: False # if False, use fg params to init content
|
84 |
num_iter: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# VPSD primitives
|
87 |
num_paths: 128 # number of strokes
|
script/install.sh
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
eval "$(conda shell.bash hook)"
|
3 |
+
|
4 |
+
conda create --name svgrender python=3.10
|
5 |
+
conda activate svgrender
|
6 |
+
|
7 |
+
echo "The conda environment was successfully created"
|
8 |
+
|
9 |
+
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
|
10 |
+
|
11 |
+
echo "Pytorch installation is complete. version: 1.12.1"
|
12 |
+
|
13 |
+
pip install hydra-core omegaconf
|
14 |
+
pip install freetype-py shapely svgutils
|
15 |
+
pip install opencv-python scikit-image matplotlib visdom wandb BeautifulSoup4
|
16 |
+
pip install triton numba
|
17 |
+
pip install numpy scipy scikit-fmm einops timm fairscale==0.4.13
|
18 |
+
pip install accelerate transformers safetensors datasets
|
19 |
+
pip install easydict scikit-learn pytorch_lightning==2.1.0 webdataset
|
20 |
+
|
21 |
+
echo "The basic dependency library is installed."
|
22 |
+
|
23 |
+
pip install ftfy regex tqdm
|
24 |
+
pip install git+https://github.com/openai/CLIP.git
|
25 |
+
|
26 |
+
echo "CLIP installation is complete."
|
27 |
+
|
28 |
+
pip install diffusers==0.20.2
|
29 |
+
|
30 |
+
echo "Diffusers installation is complete. version: 0.20.2"
|
31 |
+
# if xformers doesnt install properly with conda try installing with pip using the code below
|
32 |
+
# pip install --pre -U xformers
|
33 |
+
conda install xformers -c xformers
|
34 |
+
|
35 |
+
echo "xformers installation is complete."
|
36 |
+
|
37 |
+
git clone https://github.com/BachiLi/diffvg.git
|
38 |
+
cd diffvg
|
39 |
+
git submodule update --init --recursive
|
40 |
+
conda install -y -c anaconda cmake
|
41 |
+
conda install -y -c conda-forge ffmpeg
|
42 |
+
pip install svgwrite svgpathtools cssutils torch-tools
|
43 |
+
python setup.py install
|
44 |
+
|
45 |
+
echo "DiffVG installation is complete."
|
46 |
+
|
47 |
+
echo "the running environment has been successfully installed!!!"
|
svgdreamer/painter/painter_params.py
CHANGED
@@ -301,7 +301,7 @@ class Painter(DiffVGState):
|
|
301 |
fpath: The path to save the reinitialized SVG.
|
302 |
"""
|
303 |
if self.style not in ['iconography', 'low-poly', 'painting', 'ink']:
|
304 |
-
return
|
305 |
|
306 |
def get_keys_below_threshold(my_dict, threshold):
|
307 |
keys_below_threshold = [key for key, value in my_dict.items() if value < threshold]
|
@@ -360,7 +360,8 @@ class Painter(DiffVGState):
|
|
360 |
if path.id in reinit_union:
|
361 |
coord = [i, i] if self.style == 'low-poly' else None
|
362 |
self.shapes[i] = self.get_path(coord=coord)
|
363 |
-
#
|
|
|
364 |
self.shapes[i].points.requires_grad = True
|
365 |
extra_point_params.append(self.shapes[i].points)
|
366 |
if self.style == 'painting':
|
@@ -377,7 +378,7 @@ class Painter(DiffVGState):
|
|
377 |
shape_ids=torch.tensor(list(shp_ids)),
|
378 |
fill_color=fill_color_init,
|
379 |
stroke_color=None)
|
380 |
-
#
|
381 |
self.shape_groups[i].fill_color.requires_grad = True
|
382 |
extra_color_params.append(self.shape_groups[i].fill_color)
|
383 |
elif self.style in ['painting']:
|
@@ -387,7 +388,7 @@ class Painter(DiffVGState):
|
|
387 |
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
388 |
fill_color=None,
|
389 |
stroke_color=stroke_color_init)
|
390 |
-
#
|
391 |
self.shape_groups[i].stroke_color.requires_grad = True
|
392 |
extra_color_params.append(self.shape_groups[i].stroke_color)
|
393 |
elif self.style in ['ink']:
|
@@ -397,7 +398,7 @@ class Painter(DiffVGState):
|
|
397 |
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
398 |
fill_color=None,
|
399 |
stroke_color=stroke_color_init)
|
400 |
-
#
|
401 |
self.shape_groups[i].stroke_color.requires_grad = True
|
402 |
extra_color_params.append(self.shape_groups[i].stroke_color)
|
403 |
|
@@ -685,11 +686,11 @@ class PainterOptimizer:
|
|
685 |
self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.lr_lambda, last_epoch=-1)
|
686 |
|
687 |
def add_params(self, point_params, color_params, width_params):
|
688 |
-
if len(point_params) > 0:
|
689 |
self.point_optimizer.add_param_group({f'params': point_params})
|
690 |
-
if len(color_params) > 0:
|
691 |
self.color_optimizer.add_param_group({f'params': color_params})
|
692 |
-
if len(width_params) > 0:
|
693 |
self.width_optimizer.add_param_group({f'params': width_params})
|
694 |
|
695 |
def update_lr(self):
|
|
|
301 |
fpath: The path to save the reinitialized SVG.
|
302 |
"""
|
303 |
if self.style not in ['iconography', 'low-poly', 'painting', 'ink']:
|
304 |
+
return None, None, None
|
305 |
|
306 |
def get_keys_below_threshold(my_dict, threshold):
|
307 |
keys_below_threshold = [key for key, value in my_dict.items() if value < threshold]
|
|
|
360 |
if path.id in reinit_union:
|
361 |
coord = [i, i] if self.style == 'low-poly' else None
|
362 |
self.shapes[i] = self.get_path(coord=coord)
|
363 |
+
# new point
|
364 |
+
self.shapes[i].id = path.id
|
365 |
self.shapes[i].points.requires_grad = True
|
366 |
extra_point_params.append(self.shapes[i].points)
|
367 |
if self.style == 'painting':
|
|
|
378 |
shape_ids=torch.tensor(list(shp_ids)),
|
379 |
fill_color=fill_color_init,
|
380 |
stroke_color=None)
|
381 |
+
# new shape
|
382 |
self.shape_groups[i].fill_color.requires_grad = True
|
383 |
extra_color_params.append(self.shape_groups[i].fill_color)
|
384 |
elif self.style in ['painting']:
|
|
|
388 |
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
389 |
fill_color=None,
|
390 |
stroke_color=stroke_color_init)
|
391 |
+
# new shape
|
392 |
self.shape_groups[i].stroke_color.requires_grad = True
|
393 |
extra_color_params.append(self.shape_groups[i].stroke_color)
|
394 |
elif self.style in ['ink']:
|
|
|
398 |
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
399 |
fill_color=None,
|
400 |
stroke_color=stroke_color_init)
|
401 |
+
# new shape
|
402 |
self.shape_groups[i].stroke_color.requires_grad = True
|
403 |
extra_color_params.append(self.shape_groups[i].stroke_color)
|
404 |
|
|
|
686 |
self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.lr_lambda, last_epoch=-1)
|
687 |
|
688 |
def add_params(self, point_params, color_params, width_params):
|
689 |
+
if point_params is not None and len(point_params) > 0:
|
690 |
self.point_optimizer.add_param_group({f'params': point_params})
|
691 |
+
if color_params is not None and len(color_params) > 0:
|
692 |
self.color_optimizer.add_param_group({f'params': color_params})
|
693 |
+
if width_params is not None and len(width_params) > 0:
|
694 |
self.width_optimizer.add_param_group({f'params': width_params})
|
695 |
|
696 |
def update_lr(self):
|
svgdreamer/pipelines/SVGDreamer_pipeline.py
CHANGED
@@ -92,10 +92,6 @@ class SVGDreamerPipeline(ModelState):
|
|
92 |
self.vpsd_cfg = self.x_cfg.vpsd
|
93 |
self.vpsd_optim = self.x_cfg.vpsd_stage_optim
|
94 |
|
95 |
-
if self.style == "pixelart":
|
96 |
-
self.x_cfg.sive_stage_optim.lr_schedule = False
|
97 |
-
self.x_cfg.vpsd_stage_optim.lr_schedule = False
|
98 |
-
|
99 |
def painterly_rendering(self, text_prompt: str, target_file: AnyPath = None):
|
100 |
# log prompts
|
101 |
self.print(f"prompt: {text_prompt}")
|
@@ -132,9 +128,9 @@ class SVGDreamerPipeline(ModelState):
|
|
132 |
merged_images = []
|
133 |
for i in range(self.vpsd_cfg.n_particle):
|
134 |
select_sample_path = self.result_path / f'select_sample_{i}.png'
|
135 |
-
|
136 |
# generate sample and attention map
|
137 |
-
fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(
|
|
|
138 |
pipeline,
|
139 |
text_prompt,
|
140 |
select_sample_path,
|
@@ -146,7 +142,8 @@ class SVGDreamerPipeline(ModelState):
|
|
146 |
self.print(f"load target file from: {select_sample_path.as_posix()}")
|
147 |
|
148 |
# get objects by attention map
|
149 |
-
fg_img, bg_img, fg_mask, bg_mask = self.extract_object(select_img, fg_attn_map, bg_attn_map,
|
|
|
150 |
self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
|
151 |
|
152 |
# background rendering
|
@@ -641,7 +638,7 @@ class SVGDreamerPipeline(ModelState):
|
|
641 |
|
642 |
# save final
|
643 |
for i, r in enumerate(renderers):
|
644 |
-
ft_svg_path = self.result_path / f"
|
645 |
r.pretty_save_svg(ft_svg_path)
|
646 |
# save SVGs
|
647 |
torchvision.utils.save_image(raster_imgs, fp=self.result_path / f'all_particles.png')
|
@@ -683,10 +680,10 @@ class SVGDreamerPipeline(ModelState):
|
|
683 |
return target_img
|
684 |
|
685 |
def extract_object(self,
|
|
|
686 |
select_img: torch.Tensor,
|
687 |
fg_attn_map: np.ndarray,
|
688 |
bg_attn_map: np.ndarray,
|
689 |
-
iter: Union[str, int],
|
690 |
tau: float = 0.2):
|
691 |
# attention to mask
|
692 |
bool_fg_attn_map = fg_attn_map > tau
|
@@ -755,6 +752,7 @@ class SVGDreamerPipeline(ModelState):
|
|
755 |
return fg_img_final, bg_img_final, fg_mask, bg_mask
|
756 |
|
757 |
def extract_ldm_attn(self,
|
|
|
758 |
model_cfg: omegaconf.DictConfig,
|
759 |
pipeline: DiffusionPipeline,
|
760 |
prompts: str,
|
@@ -762,7 +760,7 @@ class SVGDreamerPipeline(ModelState):
|
|
762 |
attn_init_cfg: omegaconf.DictConfig,
|
763 |
image_size: int,
|
764 |
token_ind: int,
|
765 |
-
attn_init: bool = True
|
766 |
if token_ind <= 0:
|
767 |
raise ValueError("The 'token_ind' should be greater than 0")
|
768 |
|
@@ -837,7 +835,7 @@ class SVGDreamerPipeline(ModelState):
|
|
837 |
self_attn_vis = np.copy(self_attn)
|
838 |
self_attn_vis = self_attn_vis * 255
|
839 |
self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8)
|
840 |
-
view_images(self_attn_vis, save_image=True, fp=self.sive_attn_dir / "self-attn-final.png")
|
841 |
|
842 |
"""get final attention map"""
|
843 |
attn_map = attn_init_cfg.attn_coeff * cross_attn_map + (1 - attn_init_cfg.attn_coeff) * self_attn
|
@@ -847,7 +845,7 @@ class SVGDreamerPipeline(ModelState):
|
|
847 |
attn_map_vis = np.copy(attn_map)
|
848 |
attn_map_vis = attn_map_vis * 255
|
849 |
attn_map_vis = np.repeat(np.expand_dims(attn_map_vis, axis=2), 3, axis=2).astype(np.uint8)
|
850 |
-
view_images(attn_map_vis, save_image=True, fp=self.sive_attn_dir / 'fusion-attn.png')
|
851 |
|
852 |
# inverse fusion-attention to [0, 1]
|
853 |
inverse_attn = 1 - attn_map
|
@@ -855,7 +853,8 @@ class SVGDreamerPipeline(ModelState):
|
|
855 |
reversed_attn_map_vis = np.copy(inverse_attn)
|
856 |
reversed_attn_map_vis = reversed_attn_map_vis * 255
|
857 |
reversed_attn_map_vis = np.repeat(np.expand_dims(reversed_attn_map_vis, axis=2), 3, axis=2).astype(np.uint8)
|
858 |
-
view_images(reversed_attn_map_vis, save_image=True,
|
|
|
859 |
|
860 |
self.print(f"-> fusion attn_map: {attn_map.shape}")
|
861 |
else:
|
|
|
92 |
self.vpsd_cfg = self.x_cfg.vpsd
|
93 |
self.vpsd_optim = self.x_cfg.vpsd_stage_optim
|
94 |
|
|
|
|
|
|
|
|
|
95 |
def painterly_rendering(self, text_prompt: str, target_file: AnyPath = None):
|
96 |
# log prompts
|
97 |
self.print(f"prompt: {text_prompt}")
|
|
|
128 |
merged_images = []
|
129 |
for i in range(self.vpsd_cfg.n_particle):
|
130 |
select_sample_path = self.result_path / f'select_sample_{i}.png'
|
|
|
131 |
# generate sample and attention map
|
132 |
+
fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(i,
|
133 |
+
self.x_cfg.sive_model_cfg,
|
134 |
pipeline,
|
135 |
text_prompt,
|
136 |
select_sample_path,
|
|
|
142 |
self.print(f"load target file from: {select_sample_path.as_posix()}")
|
143 |
|
144 |
# get objects by attention map
|
145 |
+
fg_img, bg_img, fg_mask, bg_mask = self.extract_object(i, select_img, fg_attn_map, bg_attn_map,
|
146 |
+
tau=self.sive_cfg.mask_tau)
|
147 |
self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
|
148 |
|
149 |
# background rendering
|
|
|
638 |
|
639 |
# save final
|
640 |
for i, r in enumerate(renderers):
|
641 |
+
ft_svg_path = self.result_path / f"finetune_final_p{i}.svg"
|
642 |
r.pretty_save_svg(ft_svg_path)
|
643 |
# save SVGs
|
644 |
torchvision.utils.save_image(raster_imgs, fp=self.result_path / f'all_particles.png')
|
|
|
680 |
return target_img
|
681 |
|
682 |
def extract_object(self,
|
683 |
+
iter: Union[str, int],
|
684 |
select_img: torch.Tensor,
|
685 |
fg_attn_map: np.ndarray,
|
686 |
bg_attn_map: np.ndarray,
|
|
|
687 |
tau: float = 0.2):
|
688 |
# attention to mask
|
689 |
bool_fg_attn_map = fg_attn_map > tau
|
|
|
752 |
return fg_img_final, bg_img_final, fg_mask, bg_mask
|
753 |
|
754 |
def extract_ldm_attn(self,
|
755 |
+
iter: int,
|
756 |
model_cfg: omegaconf.DictConfig,
|
757 |
pipeline: DiffusionPipeline,
|
758 |
prompts: str,
|
|
|
760 |
attn_init_cfg: omegaconf.DictConfig,
|
761 |
image_size: int,
|
762 |
token_ind: int,
|
763 |
+
attn_init: bool = True):
|
764 |
if token_ind <= 0:
|
765 |
raise ValueError("The 'token_ind' should be greater than 0")
|
766 |
|
|
|
835 |
self_attn_vis = np.copy(self_attn)
|
836 |
self_attn_vis = self_attn_vis * 255
|
837 |
self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8)
|
838 |
+
view_images(self_attn_vis, save_image=True, fp=self.sive_attn_dir / f"self-attn-final-{iter}.png")
|
839 |
|
840 |
"""get final attention map"""
|
841 |
attn_map = attn_init_cfg.attn_coeff * cross_attn_map + (1 - attn_init_cfg.attn_coeff) * self_attn
|
|
|
845 |
attn_map_vis = np.copy(attn_map)
|
846 |
attn_map_vis = attn_map_vis * 255
|
847 |
attn_map_vis = np.repeat(np.expand_dims(attn_map_vis, axis=2), 3, axis=2).astype(np.uint8)
|
848 |
+
view_images(attn_map_vis, save_image=True, fp=self.sive_attn_dir / f'fusion-attn-{iter}.png')
|
849 |
|
850 |
# inverse fusion-attention to [0, 1]
|
851 |
inverse_attn = 1 - attn_map
|
|
|
853 |
reversed_attn_map_vis = np.copy(inverse_attn)
|
854 |
reversed_attn_map_vis = reversed_attn_map_vis * 255
|
855 |
reversed_attn_map_vis = np.repeat(np.expand_dims(reversed_attn_map_vis, axis=2), 3, axis=2).astype(np.uint8)
|
856 |
+
view_images(reversed_attn_map_vis, save_image=True,
|
857 |
+
fp=self.sive_attn_dir / f'reversed-fusion-attn-{iter}.png')
|
858 |
|
859 |
self.print(f"-> fusion attn_map: {attn_map.shape}")
|
860 |
else:
|