mishiawan committed on
Commit fa9f8d5 · verified · 1 Parent(s): b04b31c

Create utils.py

Files changed (1)
  1. utils.py +178 -0
utils.py ADDED
@@ -0,0 +1,178 @@
+ import os
+ from typing import List
+
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ from PIL import Image
+ from torchvision import transforms
+
+ from lbm.models.embedders import (
+     ConditionerWrapper,
+     LatentsConcatEmbedder,
+     LatentsConcatEmbedderConfig,
+ )
+ from lbm.models.lbm import LBMConfig, LBMModel
+ from lbm.models.unets import DiffusersUNet2DCondWrapper
+ from lbm.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
+
+
+ def get_model_from_config(
+     backbone_signature: str = "stabilityai/stable-diffusion-xl-base-1.0",
+     vae_num_channels: int = 4,
+     unet_input_channels: int = 4,
+     timestep_sampling: str = "log_normal",
+     selected_timesteps: List[float] = None,
+     prob: List[float] = None,
+     conditioning_images_keys: List[str] = [],
+     conditioning_masks_keys: List[str] = ["mask"],
+     source_key: str = "source_image",
+     target_key: str = "source_image_paste",
+     bridge_noise_sigma: float = 0.0,
+ ):
+
+     conditioners = []
+
+     denoiser = DiffusersUNet2DCondWrapper(
+         in_channels=unet_input_channels,  # Add downsampled_image
+         out_channels=vae_num_channels,
+         center_input_sample=False,
+         flip_sin_to_cos=True,
+         freq_shift=0,
+         down_block_types=[
+             "DownBlock2D",
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+         ],
+         mid_block_type="UNetMidBlock2DCrossAttn",
+         up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
+         only_cross_attention=False,
+         block_out_channels=[320, 640, 1280],
+         layers_per_block=2,
+         downsample_padding=1,
+         mid_block_scale_factor=1,
+         dropout=0.0,
+         act_fn="silu",
+         norm_num_groups=32,
+         norm_eps=1e-05,
+         cross_attention_dim=[320, 640, 1280],
+         transformer_layers_per_block=[1, 2, 10],
+         reverse_transformer_layers_per_block=None,
+         encoder_hid_dim=None,
+         encoder_hid_dim_type=None,
+         attention_head_dim=[5, 10, 20],
+         num_attention_heads=None,
+         dual_cross_attention=False,
+         use_linear_projection=True,
+         class_embed_type=None,
+         addition_embed_type=None,
+         addition_time_embed_dim=None,
+         num_class_embeds=None,
+         upcast_attention=None,
+         resnet_time_scale_shift="default",
+         resnet_skip_time_act=False,
+         resnet_out_scale_factor=1.0,
+         time_embedding_type="positional",
+         time_embedding_dim=None,
+         time_embedding_act_fn=None,
+         timestep_post_act=None,
+         time_cond_proj_dim=None,
+         conv_in_kernel=3,
+         conv_out_kernel=3,
+         projection_class_embeddings_input_dim=None,
+         attention_type="default",
+         class_embeddings_concat=False,
+         mid_block_only_cross_attention=None,
+         cross_attention_norm=None,
+         addition_embed_type_num_heads=64,
+     ).to(torch.bfloat16)
+
+     if conditioning_images_keys != [] or conditioning_masks_keys != []:
+
+         latents_concat_embedder_config = LatentsConcatEmbedderConfig(
+             image_keys=conditioning_images_keys,
+             mask_keys=conditioning_masks_keys,
+         )
+         latent_concat_embedder = LatentsConcatEmbedder(latents_concat_embedder_config)
+         latent_concat_embedder.freeze()
+         conditioners.append(latent_concat_embedder)
+
+     # Wrap conditioners and set to device
+     conditioner = ConditionerWrapper(
+         conditioners=conditioners,
+     )
+
+     ## VAE ##
+     # Get VAE model
+     vae_config = AutoencoderKLDiffusersConfig(
+         version=backbone_signature,
+         subfolder="vae",
+         tiling_size=(128, 128),
+     )
+     vae = AutoencoderKLDiffusers(vae_config).to(torch.bfloat16)
+     vae.freeze()
+     vae.to(torch.bfloat16)
+
+     ## Diffusion Model ##
+     # Get diffusion model
+     config = LBMConfig(
+         source_key=source_key,
+         target_key=target_key,
+         timestep_sampling=timestep_sampling,
+         selected_timesteps=selected_timesteps,
+         prob=prob,
+         bridge_noise_sigma=bridge_noise_sigma,
+     )
+
+     sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         backbone_signature,
+         subfolder="scheduler",
+     )
+
+     model = LBMModel(
+         config,
+         denoiser=denoiser,
+         sampling_noise_scheduler=sampling_noise_scheduler,
+         vae=vae,
+         conditioner=conditioner,
+     ).to(torch.bfloat16)
+
+     return model
+
+
+ def extract_object(birefnet, img):
+     # Data settings
+     image_size = (1024, 1024)
+     transform_image = transforms.Compose(
+         [
+             transforms.Resize(image_size),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     image = img
+     input_images = transform_image(image).unsqueeze(0).cuda()
+
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image.size)
+     image = Image.composite(image, Image.new("RGB", image.size, (127, 127, 127)), mask)
+     return image, mask
+
+
+ def resize_and_center_crop(image, target_width, target_height):
+     original_width, original_height = image.size
+     scale_factor = max(target_width / original_width, target_height / original_height)
+     resized_width = int(round(original_width * scale_factor))
+     resized_height = int(round(original_height * scale_factor))
+     resized_image = image.resize((resized_width, resized_height), Image.LANCZOS)
+     left = (resized_width - target_width) / 2
+     top = (resized_height - target_height) / 2
+     right = (resized_width + target_width) / 2
+     bottom = (resized_height + target_height) / 2
+     cropped_image = resized_image.crop((left, top, right, bottom))
+     return cropped_image
+
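
For reference, a minimal sketch of how these helpers might be chained in an inference script. It assumes a CUDA device, a script living next to utils.py, the ZhengPeng7/BiRefNet segmentation checkpoint loaded via transformers with trust_remote_code, a 1024x1024 working resolution, and that pretrained LBM weights are loaded into the returned model separately; none of these are defined by this commit.

import torch
from PIL import Image
from transformers import AutoModelForImageSegmentation

from utils import extract_object, get_model_from_config, resize_and_center_crop

# Background-segmentation model used to isolate the foreground object
# ("ZhengPeng7/BiRefNet" is an assumed Hub checkpoint id).
birefnet = (
    AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    .cuda()
    .eval()
)

# Assemble the LBM model skeleton; pretrained denoiser weights would still
# need to be loaded into it before it produces useful outputs.
model = get_model_from_config().cuda().eval()

# Crop the input to a square working resolution, then split it into a
# gray-composited foreground image and its mask.
img = resize_and_center_crop(Image.open("input.jpg").convert("RGB"), 1024, 1024)
foreground, mask = extract_object(birefnet, img)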