#ifndef __VAE_HPP__ #define __VAE_HPP__ #include "common.hpp" #include "ggml_extend.hpp" /*================================================== AutoEncoderKL ===================================================*/ #define VAE_GRAPH_SIZE 20480 class ResnetBlock : public UnaryBlock { protected: int64_t in_channels; int64_t out_channels; public: ResnetBlock(int64_t in_channels, int64_t out_channels) : in_channels(in_channels), out_channels(out_channels) { // temb_channels is always 0 blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels)); blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels)); blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); if (out_channels != in_channels) { blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1})); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] // t_emb is always None auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]); auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]); auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]); auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]); auto h = x; h = norm1->forward(ctx, h); h = ggml_silu_inplace(ctx, h); // swish h = conv1->forward(ctx, h); // return h; h = norm2->forward(ctx, h); h = ggml_silu_inplace(ctx, h); // swish // dropout, skip for inference h = conv2->forward(ctx, h); // skip connection if (out_channels != in_channels) { auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]); x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w] } h = ggml_add(ctx, h, x); return h; // [N, out_channels, h, w] } }; class AttnBlock : public UnaryBlock { protected: int64_t in_channels; public: AttnBlock(int64_t in_channels) : in_channels(in_channels) { blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels)); blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1})); blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1})); blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1})); blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1})); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]); auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]); auto k_proj = std::dynamic_pointer_cast<Conv2d>(blocks["k"]); auto v_proj = std::dynamic_pointer_cast<Conv2d>(blocks["v"]); auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]); auto h_ = norm->forward(ctx, x); const int64_t n = h_->ne[3]; const int64_t c = h_->ne[2]; const int64_t h = h_->ne[1]; const int64_t w = h_->ne[0]; auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels] auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w] k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels] auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w] h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w] h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w] h_ = ggml_add(ctx, h_, x); return h_; } }; class AE3DConv : public Conv2d { public: AE3DConv(int64_t in_channels, int64_t out_channels, std::pair<int, int> kernel_size, int64_t video_kernel_size = 3, std::pair<int, int> stride = {1, 1}, std::pair<int, int> padding = {0, 0}, std::pair<int, int> dilation = {1, 1}, bool bias = true) : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) { int64_t kernel_padding = video_kernel_size / 2; blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels, out_channels, video_kernel_size, 1, kernel_padding)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // timesteps always None // skip_video always False // x: [N, IC, IH, IW] // result: [N, OC, OH, OW] auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks["time_mix_conv"]); x = Conv2d::forward(ctx, x); // timesteps = x.shape[0] // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) // x = conv3d(x) // return rearrange(x, "b c t h w -> (b t) c h w") int64_t T = x->ne[3]; int64_t B = x->ne[3] / T; int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW] x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w return x; // [B*T, OC, OH, OW] } }; class VideoResnetBlock : public ResnetBlock { protected: void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") { enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]); return sigmoid(alpha); } public: VideoResnetBlock(int64_t in_channels, int64_t out_channels, int video_kernel_size = 3) : ResnetBlock(in_channels, out_channels) { // merge_strategy is always learned blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w] // return: [N, out_channels, h, w] aka [b*t, out_channels, h, w] // t_emb is always None // skip_video is always False // timesteps is always None auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]); x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w] // return x; int64_t T = x->ne[3]; int64_t B = x->ne[3] / T; int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) auto x_mix = x; x = time_stack->forward(ctx, x); // b t c (h w) float alpha = get_alpha(); x = ggml_add(ctx, ggml_scale(ctx, x, alpha), ggml_scale(ctx, x_mix, 1.0f - alpha)); x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w return x; } }; // ldm.modules.diffusionmodules.model.Encoder class Encoder : public GGMLBlock { protected: int ch = 128; std::vector<int> ch_mult = {1, 2, 4, 4}; int num_res_blocks = 2; int in_channels = 3; int z_channels = 4; bool double_z = true; public: Encoder(int ch, std::vector<int> ch_mult, int num_res_blocks, int in_channels, int z_channels, bool double_z = true) : ch(ch), ch_mult(ch_mult), num_res_blocks(num_res_blocks), in_channels(in_channels), z_channels(z_channels), double_z(double_z) { blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1})); size_t num_resolutions = ch_mult.size(); int block_in = 1; for (int i = 0; i < num_resolutions; i++) { if (i == 0) { block_in = ch; } else { block_in = ch * ch_mult[i - 1]; } int block_out = ch * ch_mult[i]; for (int j = 0; j < num_res_blocks; j++) { std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out)); block_in = block_out; } if (i != num_resolutions - 1) { std::string name = "down." + std::to_string(i) + ".downsample"; blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true)); } } blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in)); blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in)); blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in)); blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in)); blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1})); } virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]); auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]); auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]); auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]); auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]); auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]); auto h = conv_in->forward(ctx, x); // [N, ch, h, w] // downsampling size_t num_resolutions = ch_mult.size(); for (int i = 0; i < num_resolutions; i++) { for (int j = 0; j < num_res_blocks; j++) { std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]); h = down_block->forward(ctx, h); } if (i != num_resolutions - 1) { std::string name = "down." + std::to_string(i) + ".downsample"; auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]); h = down_sample->forward(ctx, h); } } // middle h = mid_block_1->forward(ctx, h); h = mid_attn_1->forward(ctx, h); h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] // end h = norm_out->forward(ctx, h); h = ggml_silu_inplace(ctx, h); // nonlinearity/swish h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w] return h; } }; // ldm.modules.diffusionmodules.model.Decoder class Decoder : public GGMLBlock { protected: int ch = 128; int out_ch = 3; std::vector<int> ch_mult = {1, 2, 4, 4}; int num_res_blocks = 2; int z_channels = 4; bool video_decoder = false; int video_kernel_size = 3; virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels, int64_t out_channels, std::pair<int, int> kernel_size, std::pair<int, int> stride = {1, 1}, std::pair<int, int> padding = {0, 0}) { if (video_decoder) { return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding)); } else { return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding)); } } virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels, int64_t out_channels) { if (video_decoder) { return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size)); } else { return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels)); } } public: Decoder(int ch, int out_ch, std::vector<int> ch_mult, int num_res_blocks, int z_channels, bool video_decoder = false, int video_kernel_size = 3) : ch(ch), out_ch(out_ch), ch_mult(ch_mult), num_res_blocks(num_res_blocks), z_channels(z_channels), video_decoder(video_decoder), video_kernel_size(video_kernel_size) { size_t num_resolutions = ch_mult.size(); int block_in = ch * ch_mult[num_resolutions - 1]; blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1})); blocks["mid.block_1"] = get_resnet_block(block_in, block_in); blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in)); blocks["mid.block_2"] = get_resnet_block(block_in, block_in); for (int i = num_resolutions - 1; i >= 0; i--) { int mult = ch_mult[i]; int block_out = ch * mult; for (int j = 0; j < num_res_blocks + 1; j++) { std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); blocks[name] = get_resnet_block(block_in, block_out); block_in = block_out; } if (i != 0) { std::string name = "up." + std::to_string(i) + ".upsample"; blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in)); } } blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in)); blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}); } virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] // alpha is always 0 // merge_strategy is always learned // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock // AttnVideoBlock will not be used auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]); auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]); auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]); auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]); auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]); auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]); // conv_in auto h = conv_in->forward(ctx, z); // [N, block_in, h, w] // middle h = mid_block_1->forward(ctx, h); // return h; h = mid_attn_1->forward(ctx, h); h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] // upsampling size_t num_resolutions = ch_mult.size(); for (int i = num_resolutions - 1; i >= 0; i--) { for (int j = 0; j < num_res_blocks + 1; j++) { std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]); h = up_block->forward(ctx, h); } if (i != 0) { std::string name = "up." + std::to_string(i) + ".upsample"; auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]); h = up_sample->forward(ctx, h); } } h = norm_out->forward(ctx, h); h = ggml_silu_inplace(ctx, h); // nonlinearity/swish h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8] return h; } }; // ldm.models.autoencoder.AutoencoderKL class AutoencodingEngine : public GGMLBlock { protected: bool decode_only = true; bool use_video_decoder = false; bool use_quant = true; int embed_dim = 4; struct { int z_channels = 4; int resolution = 256; int in_channels = 3; int out_ch = 3; int ch = 128; std::vector<int> ch_mult = {1, 2, 4, 4}; int num_res_blocks = 2; bool double_z = true; } dd_config; public: AutoencodingEngine(bool decode_only = true, bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { if (sd_version_is_dit(version)) { dd_config.z_channels = 16; use_quant = false; } if (use_video_decoder) { use_quant = false; } blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch, dd_config.out_ch, dd_config.ch_mult, dd_config.num_res_blocks, dd_config.z_channels, use_video_decoder)); if (use_quant) { blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels, embed_dim, {1, 1})); } if (!decode_only) { blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch, dd_config.ch_mult, dd_config.num_res_blocks, dd_config.in_channels, dd_config.z_channels, dd_config.double_z)); if (use_quant) { int factor = dd_config.double_z ? 2 : 1; blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor, dd_config.z_channels * factor, {1, 1})); } } } struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] if (use_quant) { auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]); z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w] } auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]); ggml_set_name(z, "bench-start"); auto h = decoder->forward(ctx, z); ggml_set_name(h, "bench-end"); return h; } struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]); auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8] if (use_quant) { auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]); h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8] } return h; } }; struct AutoEncoderKL : public GGMLRunner { bool decode_only = true; AutoencodingEngine ae; AutoEncoderKL(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix, bool decode_only = false, bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) { ae.init(params_ctx, tensor_types, prefix); } std::string get_desc() { return "vae"; } void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) { ae.get_param_tensors(tensors, prefix); } struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); z = to_backend(z); struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z); ggml_build_forward_expand(gf, out); return gf; } void compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, struct ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; // ggml_set_f32(z, 0.5f); // print_ggml_tensor(z); GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } void test() { struct ggml_init_params params; params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* work_ctx = ggml_init(params); GGML_ASSERT(work_ctx != NULL); { // CPU, x{1, 3, 64, 64}: Pass // CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan // CPU, x{2, 3, 64, 64}: Wrong result // CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2); ggml_set_f32(x, 0.5f); print_ggml_tensor(x); struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); compute(8, x, false, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); LOG_DEBUG("encode test done in %dms", t1 - t0); } if (false) { // CPU, z{1, 4, 8, 8}: Pass // CUDA, z{1, 4, 8, 8}: Pass // CPU, z{3, 4, 8, 8}: Wrong result // CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); ggml_set_f32(z, 0.5f); print_ggml_tensor(z); struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); compute(8, z, true, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); LOG_DEBUG("decode test done in %dms", t1 - t0); } }; }; #endif