Commit
·
10c31ce
1
Parent(s):
d204b62
added option to skip mid block (#5)
Browse files- added option to skip mid block (b84b5c4088d29603d9c765d32549bb39b23231ed)
- pipeline.py +28 -12
pipeline.py
CHANGED
|
@@ -52,6 +52,19 @@ def custom_sort_order(obj):
|
|
| 52 |
return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
|
| 56 |
configurations = FlexibleUnetConfigurations
|
| 57 |
|
|
@@ -105,18 +118,21 @@ class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
|
|
| 105 |
mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
|
| 106 |
mid_num_attentions = self.configurations.get("mid_num_attentions")
|
| 107 |
mid_num_resnets = self.configurations.get("mid_num_resnets")
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
###############
|
| 122 |
# Up blocks #
|
|
|
|
| 52 |
return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
|
| 53 |
|
| 54 |
|
| 55 |
+
class FlexibleIdentityBlock(nn.Module):
|
| 56 |
+
def forward(
|
| 57 |
+
self,
|
| 58 |
+
hidden_states: torch.FloatTensor,
|
| 59 |
+
temb: Optional[torch.FloatTensor] = None,
|
| 60 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 61 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 62 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 63 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 64 |
+
):
|
| 65 |
+
return hidden_states
|
| 66 |
+
|
| 67 |
+
|
| 68 |
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
|
| 69 |
configurations = FlexibleUnetConfigurations
|
| 70 |
|
|
|
|
| 118 |
mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
|
| 119 |
mid_num_attentions = self.configurations.get("mid_num_attentions")
|
| 120 |
mid_num_resnets = self.configurations.get("mid_num_resnets")
|
| 121 |
+
|
| 122 |
+
if mid_num_resnets == mid_num_attentions == 0:
|
| 123 |
+
self.mid_block = FlexibleIdentityBlock()
|
| 124 |
+
else:
|
| 125 |
+
self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
|
| 126 |
+
temb_channels=temb_dim,
|
| 127 |
+
resnet_act_fn=resnet_act_fn,
|
| 128 |
+
resnet_eps=resnet_eps,
|
| 129 |
+
cross_attention_dim=cross_attention_dim,
|
| 130 |
+
num_attention_heads=num_attention_heads,
|
| 131 |
+
num_resnets=mid_num_resnets,
|
| 132 |
+
num_attentions=mid_num_attentions,
|
| 133 |
+
mix_block_in_forward=mix_block_in_forward,
|
| 134 |
+
add_upsample=mid_block_add_upsample
|
| 135 |
+
)
|
| 136 |
|
| 137 |
###############
|
| 138 |
# Up blocks #
|