fix(method): if attn_map is None
- Examples.md +7 -1
- README.md +1 -1
- conf/x/iconography.yaml +2 -2
- svgdreamer/painter/__init__.py +2 -2
- svgdreamer/pipelines/SVGDreamer_pipeline.py +73 -40
- svgdreamer/token2attn/attn_control.py +1 -1
Examples.md
CHANGED
@@ -160,4 +160,10 @@ expressive eyes. <br/>
 
 ````shell
 python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" x.num_paths=256 result_path='./logs/VanGogh-Portrait'
-````
+````
+
+### Case: planet Saturn
+
+```shell
+python svgdreamer.py x=iconography-s1 skip_sive=False "prompt='An icon of the planet Saturn. minimal flat 2D vector icon. plain color background. trending on ArtStation.'" token_ind=6 x.sive.bg.num_iter=50 x.sive.fg.num_iter=50 x.vpsd.t_schedule='randint' result_path='./logs/Saturn' multirun=True state.mprec='fp16'
+```
README.md
CHANGED
@@ -80,7 +80,7 @@ realistic <br/>
 **Script:**
 
 ```shell
-python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.
+python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True
 ```
 
 🔹Parameter:
conf/x/iconography.yaml
CHANGED
@@ -41,7 +41,7 @@ sive:
   mask_tau: 0.3 # the threshold used to convert the attention map into a mask
   bg:
     style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
-    num_iter:
+    num_iter: 50
     num_paths: 256
     path_schedule: 'repeat' # 'repeat', 'list'
     schedule_each: 128
@@ -61,7 +61,7 @@ sive:
   xing_loss_weight: 0.001
   fg:
     style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
-    num_iter:
+    num_iter: 50
     num_paths: 256 # number of strokes
     path_schedule: 'repeat' # 'repeat', 'list'
     schedule_each: 128
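These `num_iter` values are reachable from the command line through Hydra/OmegaConf dotted overrides, which is how the Saturn example above passes `x.sive.bg.num_iter=50` and `x.sive.fg.num_iter=50`. A minimal sketch of that mechanism, assuming only the `omegaconf` package (the key names mirror this repo's config tree; the snippet itself is illustrative, not SVGDreamer's launcher):

```python
# Sketch: how dotted CLI overrides land on YAML keys like sive.bg.num_iter.
# Illustrative only; SVGDreamer wires this up through Hydra's config system.
from omegaconf import OmegaConf

# Defaults, shaped like the conf/x/iconography.yaml fragment above
cfg = OmegaConf.create({"sive": {"bg": {"num_iter": 50}, "fg": {"num_iter": 50}}})

# Dotted overrides, as written on the svgdreamer.py command line
overrides = OmegaConf.from_dotlist(["sive.bg.num_iter=100", "sive.fg.num_iter=100"])

cfg = OmegaConf.merge(cfg, overrides)
assert cfg.sive.bg.num_iter == 100  # the CLI value wins over the YAML default
print(OmegaConf.to_yaml(cfg))
```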
svgdreamer/painter/__init__.py
CHANGED
@@ -2,8 +2,8 @@
 # Copyright (c) XiMing Xing. All rights reserved.
 # Description:
 
-from .painter_params import
-
+from .painter_params import Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, \
+    SparseCoordInit, get_sdf
 from .component_painter_params import CompPainter, CompPainterOptimizer
 from .loss import xing_loss_fn
 from .VPSD_pipeline import VectorizedParticleSDSPipeline
svgdreamer/pipelines/SVGDreamer_pipeline.py
CHANGED
@@ -20,8 +20,8 @@ from torchvision import transforms
 from skimage.color import rgb2gray
 
 from svgdreamer.libs import ModelState, get_optimizer
-from svgdreamer.painter import
-
+from svgdreamer.painter import CompPainter, CompPainterOptimizer, xing_loss_fn, Painter, PainterOptimizer, \
+    CosineWithWarmupLRLambda, VectorizedParticleSDSPipeline, DiffusionPipeline
 from svgdreamer.token2attn.attn_control import EmptyControl, AttentionStore
 from svgdreamer.token2attn.ptp_utils import view_images
 from svgdreamer.utils.plot import plot_img, plot_couple, plot_attn, save_image
@@ -38,8 +38,10 @@ class SVGDreamerPipeline(ModelState):
         # assert
         assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
         args.skip_sive = True if args.x.style in ["pixelart", "low-poly"] else args.skip_sive
-        assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
-        assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle
+        # assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle
+        if args.x.vpsd.vsd_n_particle > args.x.vpsd.n_particle: args.x.vpsd.vsd_n_particle = args.x.vpsd.n_particle
+        # assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle
+        if args.x.vpsd.phi_n_particle > args.x.vpsd.n_particle: args.x.vpsd.phi_n_particle = args.x.vpsd.n_particle
         assert args.x.vpsd.n_phi_sample >= 1
 
         logdir_ = f"sd{args.seed}" \
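The hunk above is the first half of the robustness fix: instead of asserting that the particle pool is at least as large as the VSD and phi subsets, the constructor now clamps the subset sizes down to the pool size. A self-contained sketch of the pattern, with an illustrative config class standing in for `args.x.vpsd`:

```python
# Sketch of the clamp-instead-of-assert pattern used above (illustrative types).
from dataclasses import dataclass

@dataclass
class VPSDConfig:
    n_particle: int = 6      # size of the particle pool
    vsd_n_particle: int = 4  # particles drawn per VSD update
    phi_n_particle: int = 2  # particles used to train the phi estimator

def clamp_particle_counts(cfg: VPSDConfig) -> VPSDConfig:
    # A subset can never be larger than the pool it is drawn from; clamping
    # lets a slightly inconsistent config run instead of aborting on an assert.
    cfg.vsd_n_particle = min(cfg.vsd_n_particle, cfg.n_particle)
    cfg.phi_n_particle = min(cfg.phi_n_particle, cfg.n_particle)
    return cfg

print(clamp_particle_counts(VPSDConfig(n_particle=2, vsd_n_particle=4)))
# -> VPSDConfig(n_particle=2, vsd_n_particle=2, phi_n_particle=2)
```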
@@ -123,15 +125,26 @@
         self.close(msg="painterly rendering complete.")
 
     def SIVE_stage(self, text_prompt: str):
-        #
+        # Init diffusion model
         pipeline = DiffusionPipeline(self.x_cfg.sive_model_cfg, self.args.diffuser, self.device)
 
         merged_svg_paths = []
         merged_images = []
-
-
-
-
+
+        successful_particles = 0
+        cur_idx = 0
+
+        while successful_particles < self.vpsd_cfg.n_particle:
+            if cur_idx >= self.vpsd_cfg.n_particle + 10:  # max attempts
+                self.print(f"Reached maximum attempts ({cur_idx}). "
+                           f"Only processed {successful_particles} particles successfully.")
+                break
+
+            self.print(f"Processing particle {cur_idx} "
+                       f"(successful so far: {successful_particles}/{self.vpsd_cfg.n_particle})")
+            select_sample_path = self.result_path / f'select_sample_{cur_idx}.png'
+            # Generate sample and attention map
+            fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(cur_idx,
                                                                          self.x_cfg.sive_model_cfg,
                                                                          pipeline,
                                                                          text_prompt,
@@ -139,18 +152,18 @@
                                                                          self.sive_cfg.attn_cfg,
                                                                          self.im_size,
                                                                          self.args.token_ind)
-            #
+            # Load selected file
             select_img = self.target_file_preprocess(select_sample_path.as_posix())
             self.print(f"load target file from: {select_sample_path.as_posix()}")
 
-            #
-            fg_img, bg_img, fg_mask, bg_mask = self.extract_object(
+            # Get objects by attention map
+            fg_img, bg_img, fg_mask, bg_mask = self.extract_object(cur_idx, select_img, fg_attn_map, bg_attn_map,
                                                                    tau=self.sive_cfg.mask_tau)
-            self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
+            # self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}")
 
-            #
-            self.print(f"->
-            bg_render_path = self.component_rendering(tag=f'{
+            # Background rendering
+            self.print(f"-> Background rendering: ")
+            bg_render_path = self.component_rendering(tag=f'{cur_idx}_bg',
                                                       prompt=text_prompt,
                                                       target_img=bg_img,
                                                       mask=bg_mask,
@@ -160,9 +173,14 @@
                                                       optim_cfg=self.sive_optim,
                                                       log_png_dir=self.bg_png_logs_dir,
                                                       log_svg_dir=self.bg_svg_logs_dir)
-
-
-
+            if bg_render_path == 0:
+                self.print(f"Background rendering failed for particle {cur_idx}, trying next particle")
+                cur_idx += 1
+                continue
+
+            # Foreground rendering
+            self.print(f"-> Foreground rendering: ")
+            fg_render_path = self.component_rendering(tag=f'{cur_idx}_fg',
                                                       prompt=text_prompt,
                                                       target_img=fg_img,
                                                       mask=fg_mask,
@@ -172,8 +190,16 @@
                                                       optim_cfg=self.sive_optim,
                                                       log_png_dir=self.fg_png_logs_dir,
                                                       log_svg_dir=self.fg_svg_logs_dir)
-
-
+            if fg_render_path == 0:
+                self.print(f"Foreground rendering failed for particle {cur_idx}, trying next particle")
+                cur_idx += 1
+                continue
+
+            successful_particles += 1
+            cur_idx += 1
+
+            # Merge foreground and background
+            merged_svg_path = self.result_path / f'SIVE_render_final_{cur_idx}.svg'
             merge_svg_files(
                 svg_path_1=bg_render_path,
                 svg_path_2=fg_render_path,
@@ -182,11 +208,11 @@
                 out_size=(self.im_size, self.im_size)
             )
 
-            #
+            # Foreground and background refinement
             # Note: you are not allowed to add further paths here
             if self.sive_cfg.tog.reinit:
-                self.print("->
-                merged_svg_path = self.refine_rendering(tag=f'{
+                self.print("-> Enable vector graphic refinement:")
+                merged_svg_path = self.refine_rendering(tag=f'{cur_idx}_refine',
                                                         prompt=text_prompt,
                                                         target_img=select_img,
                                                         canvas_size=(self.im_size, self.im_size),
@@ -194,22 +220,21 @@
                                                         optim_cfg=self.sive_optim,
                                                         init_svg_path=merged_svg_path)
 
-            # svg-to-png
-            merged_png_path = self.result_path / f'SIVE_render_final_{
+            # Postprocess: svg-to-png & to tensor
+            merged_png_path = self.result_path / f'SIVE_render_final_{cur_idx}.png'
             cairosvg.svg2png(url=merged_svg_path.as_posix(), write_to=merged_png_path.as_posix())
-
-            # collect paths
-            merged_svg_paths.append(merged_svg_path)
+            merged_svg_paths.append(merged_svg_path)  # collect paths
             merged_images.append(self.target_file_preprocess(merged_png_path))
-
+
+            # Clear attention recorder
             controller.reset()
 
-            self.print(f"Vector Particle {
+            self.print(f"Vector Particle {cur_idx} Rendering End...\n")
 
-        #
+        # Free the VRAM
         del pipeline
         torch.cuda.empty_cache()
-        #
+        # Update paths
         self.x_cfg.num_paths = self.sive_cfg.bg.num_paths + self.sive_cfg.fg.num_paths
 
         return merged_svg_paths, merged_images
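The rewritten `SIVE_stage` above wraps per-particle rendering in a bounded retry loop: a particle whose background or foreground rendering returns the failure sentinel `0` is skipped, and the loop stops after `n_particle + 10` attempts instead of spinning forever. A standalone sketch of that control flow (`render_particle` is an illustrative stand-in for the real rendering calls):

```python
# Sketch of the bounded retry loop in SIVE_stage (illustrative stand-ins).
import random

def render_particle(idx: int) -> int:
    """Return a nonzero handle on success, 0 on failure (as component_rendering signals)."""
    return 0 if random.random() < 0.3 else idx + 1

def run(n_particle: int, extra_attempts: int = 10) -> list:
    results, successful, cur_idx = [], 0, 0
    while successful < n_particle:
        if cur_idx >= n_particle + extra_attempts:  # max attempts
            print(f"Reached maximum attempts ({cur_idx}); "
                  f"only {successful}/{n_particle} particles succeeded.")
            break
        handle = render_particle(cur_idx)
        cur_idx += 1      # every attempt consumes an index, success or not
        if handle == 0:
            continue      # skip the failed particle, try the next one
        results.append(handle)
        successful += 1
    return results

print(run(n_particle=4))
```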
@@ -257,6 +282,9 @@
         if attention_map is not None:
             # init fist control points by attention_map
             attn_thresh, select_inds = renderer.attn_init_points(num_paths=sum(path_schedule), mask=mask)
+            # Warning: attention map failure
+            if len(select_inds) == 0: return 0
+
             # log attention, just once
             plot_attn(attention_map, attn_thresh, target_img, select_inds,
                       (self.sive_attn_dir / f"attention_{tag}_map.jpg").as_posix())
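This hunk is the crux of the commit title: when attention-guided point initialization selects nothing, `component_rendering` now bails out with the sentinel `0`, which the loop in `SIVE_stage` treats as "skip this particle". A sketch of the guard in isolation, with stand-in functions rather than the project's real API:

```python
# Sketch of the attention-failure guard (stand-in functions, not SVGDreamer's API).
import numpy as np

def attn_init_points(attn_map: np.ndarray, num_points: int, thresh: float = 0.5):
    ys, xs = np.nonzero(attn_map > thresh)      # pixels with strong attention
    if len(xs) == 0:                            # attention map failure
        return []
    order = np.argsort(attn_map[ys, xs])[::-1]  # strongest responses first
    return list(zip(xs[order][:num_points], ys[order][:num_points]))

def component_rendering(attn_map: np.ndarray) -> int:
    select_inds = attn_init_points(attn_map, num_points=8)
    if len(select_inds) == 0:
        return 0  # sentinel: no seed points, let the caller try the next particle
    # ... optimize SVG paths seeded at select_inds ...
    return 1

print(component_rendering(np.zeros((4, 4))))  # -> 0: nothing above threshold
```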
@@ -381,14 +409,16 @@
         plot_img(img, self.refine_dir, fname=f"{tag}_before_refined")
 
         n_iter = render_cfg.num_iter
+        self.print(f"Total iters: {n_iter}")
+
         # build painter optimizer
         optimizer = CompPainterOptimizer(content_renderer, self.style, n_iter, optim_cfg)
         # init optimizer
         optimizer.init_optimizers()
 
-        print(f"=> n_point: {len(content_renderer.get_point_params())}, "
-
-
+        self.print(f"=> n_point: {len(content_renderer.get_point_params())}, "
+                   f"n_width: {len(content_renderer.get_width_params())}, "
+                   f"n_color: {len(content_renderer.get_color_params())}")
 
         step = 0
         with tqdm(initial=step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
@@ -434,7 +464,8 @@
                             text_prompt: AnyStr,
                             init_svg_path: Union[List[AnyPath], AnyPath] = None,
                             init_image: Union[List[torch.Tensor], torch.Tensor] = None):
-
+        # print(f"self.vpsd_cfg.use: {self.vpsd_cfg.use}")
+        if self.vpsd_cfg.use is False:
             return
 
         # for convenience
@@ -784,10 +815,12 @@
                                  generator=self.g_device)
         outputs_np = [np.array(img) for img in outputs.images]
         view_images(outputs_np, save_image=True, fp=gen_sample_path)
-        self.print(f"select_sample shape: {outputs_np[0].shape}")
+        # self.print(f"select_sample shape: {outputs_np[0].shape}")
 
         if attn_init:
-            "
+            self.print(f"\nLDM attn-map logging:")
+
+            # Cross-attention map
             cross_attention_maps, tokens = \
                 pipeline.get_cross_attention([prompts],
                                              controller,
@@ -862,7 +895,7 @@
             view_images(reversed_attn_map_vis, save_image=True,
                         fp=self.sive_attn_dir / f'reversed-fusion-attn-{iter}.png')
 
-            self.print(f"-> fusion attn_map: {attn_map.shape}")
+            self.print(f"-> fusion attn_map: {attn_map.shape} \n")
         else:
             attn_map = None
             inverse_attn = None
svgdreamer/token2attn/attn_control.py
CHANGED
@@ -85,7 +85,7 @@ class AttentionStore(AttentionControl):
         self.step_store = self.get_empty_store()
 
     def get_average_attention(self):
-        print(f"step count: {self.cur_step}")
+        # print(f"step count: {self.cur_step}")
         average_attention = {
             key: [item / self.cur_step for item in self.attention_store[key]]
             for key in self.attention_store
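For context on the silenced `print`: `AttentionStore` sums attention maps into `attention_store` once per diffusion step, so `get_average_attention` divides by `cur_step` to recover the per-step mean; the deleted line merely logged that counter. A toy sketch of the same bookkeeping (stand-in class, not the real `AttentionStore`):

```python
# Toy sketch of step-averaged attention bookkeeping (illustrative class).
class TinyAttnStore:
    def __init__(self):
        self.cur_step = 0
        self.attention_store = {}  # key -> per-layer running sums

    def store(self, key, values):
        sums = self.attention_store.setdefault(key, [0.0] * len(values))
        self.attention_store[key] = [s + v for s, v in zip(sums, values)]

    def end_step(self):
        self.cur_step += 1

    def get_average_attention(self):
        # Each entry holds a sum over cur_step steps, so divide to average.
        return {k: [v / self.cur_step for v in vals]
                for k, vals in self.attention_store.items()}

store = TinyAttnStore()
store.store("down_cross", [1.0, 2.0]); store.end_step()
store.store("down_cross", [3.0, 4.0]); store.end_step()
print(store.get_average_attention())  # {'down_cross': [2.0, 3.0]}
```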