Diffusers
TalHach61 committed
Commit f173bed · verified · 1 Parent(s): ebe67e3

Delete controlnet_flux.py

Files changed (1)
  1. controlnet_flux.py +0 -649
controlnet_flux.py DELETED
@@ -1,649 +0,0 @@
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformer_bria import TimestepProjEmbeddings
from diffusers.models.controlnet import zero_module, BaseOutput
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput

# from transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
from diffusers.models.transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock

from diffusers.models.attention_processor import AttentionProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FluxControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]
    controlnet_single_block_samples: Tuple[torch.Tensor]


class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: int = None,
        rope_theta: int = 10000,
        time_theta: int = 10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=rope_theta, axes_dim=axes_dims_rope)

        # text_time_guidance_cls = (
        #     CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        # )
        # self.time_text_embed = text_time_guidance_cls(
        #     embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        # )
        self.time_embed = TimestepProjEmbeddings(
            embedding_dim=self.inner_dim, time_theta=time_theta
        )
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_single_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.union = num_mode is not None and num_mode > 0
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        config = transformer.config
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            # This model defines `time_embed` (there is no `time_text_embed` attribute), so copy that module.
            controlnet.time_embed.load_state_dict(transformer.time_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[FluxControlNetOutput, Tuple]:
        """
        The [`FluxControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input `hidden_states` (packed latent tokens).
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            controlnet_mode (`torch.Tensor`):
                The mode tensor of shape `(batch_size, 1)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`FluxControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`FluxControlNetOutput`] is returned, otherwise a `tuple` of
            `(controlnet_block_samples, controlnet_single_block_samples)`.
        """
        if guidance is not None:
            print("guidance is not supported in BriaFluxControlNetModel")
        if pooled_projections is not None:
            print("pooled_projections is not supported in BriaFluxControlNetModel")
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        timestep = timestep.to(hidden_states.dtype)  # Original code was * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)  # Original code was * 1000
        else:
            guidance = None
        # temb = (
        #     self.time_text_embed(timestep, pooled_projections)
        #     if guidance is None
        #     else self.time_text_embed(timestep, guidance, pooled_projections)
        # )
        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if self.union:
            # union mode
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
            # union mode emb
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:
                controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, 2048)
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
            txt_ids = torch.cat((txt_ids[:, 0:1, :], txt_ids), dim=1)

        # if txt_ids.ndim == 3:
        #     logger.warning(
        #         "Passing `txt_ids` 3d torch.Tensor is deprecated."
        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
        #     )
        #     txt_ids = txt_ids[0]
        # if img_ids.ndim == 3:
        #     logger.warning(
        #         "Passing `img_ids` 3d torch.Tensor is deprecated."
        #         "Please remove the batch dimension and pass it as a 2d torch Tensor"
        #     )
        #     img_ids = img_ids[0]

        # ids = torch.cat((txt_ids, img_ids), dim=0)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples, controlnet_single_block_samples)

        return FluxControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )


class FluxMultiControlNetModel(ModelMixin):
    r"""
    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel

    This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to
    be compatible with `FluxControlNetModel`.

    Args:
        controlnets (`List[FluxControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `FluxControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        controlnet_mode: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[FluxControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1 and self.nets[0].union:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    control_block_samples = [
                        control_block_sample + block_sample
                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                    ]

                    control_single_block_samples = [
                        control_single_block_sample + block_sample
                        for control_single_block_sample, block_sample in zip(
                            control_single_block_samples, single_block_samples
                        )
                    ]

        # Regular Multi-ControlNets
        # load all ControlNets into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
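For reference, the deleted model consumed packed latent tokens plus a ControlNet conditioning image in the same packed form, and returned one residual per double block and per single block. The snippet below is a minimal smoke-test sketch, not part of this commit: it assumes `controlnet_flux.py` (and its `transformer_bria` dependency) is still importable, uses a deliberately tiny layer count, and the dummy shapes simply follow the defaults in `__init__` (`in_channels=64`, `joint_attention_dim=4096`, three rotary axes).

# Hypothetical smoke test for the deleted module (assumes controlnet_flux.py and
# transformer_bria.py are still on the Python path).
import torch

from controlnet_flux import FluxControlNetModel

# Tiny configuration so the dummy forward pass is cheap; real checkpoints use 19/38 blocks.
controlnet = FluxControlNetModel(num_layers=1, num_single_layers=1)

batch, img_seq, txt_seq = 1, 16, 8
hidden_states = torch.randn(batch, img_seq, 64)            # packed latent tokens (in_channels=64)
controlnet_cond = torch.randn(batch, img_seq, 64)          # conditioning image, packed the same way
encoder_hidden_states = torch.randn(batch, txt_seq, 4096)  # text embeddings (joint_attention_dim=4096)
img_ids = torch.zeros(batch, img_seq, 3)                   # three rotary axes, matching axes_dims_rope=[16, 56, 56]
txt_ids = torch.zeros(batch, txt_seq, 3)
timestep = torch.tensor([1.0])

with torch.no_grad():
    block_samples, single_block_samples = controlnet(
        hidden_states=hidden_states,
        controlnet_cond=controlnet_cond,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
        conditioning_scale=1.0,
        return_dict=False,
    )

# One scaled residual per double block and per single block.
print(len(block_samples), len(single_block_samples))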