#ifndef __VAE_HPP__
#define __VAE_HPP__

#include "common.hpp"
#include "ggml_extend.hpp"

/*================================================== AutoEncoderKL ===================================================*/

#define VAE_GRAPH_SIZE 20480

class ResnetBlock : public UnaryBlock {
protected:
    int64_t in_channels;
    int64_t out_channels;

public:
    ResnetBlock(int64_t in_channels,
                int64_t out_channels)
        : in_channels(in_channels),
          out_channels(out_channels) {
        // temb_channels is always 0
        blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
        blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));

        blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
        blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));

        if (out_channels != in_channels) {
            blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
        }
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, in_channels, h, w]
        // t_emb is always None
        auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
        auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
        auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
        auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);

        auto h = x;
        h      = norm1->forward(ctx, h);
        h      = ggml_silu_inplace(ctx, h);  // swish
        h      = conv1->forward(ctx, h);
        // return h;

        h = norm2->forward(ctx, h);
        h = ggml_silu_inplace(ctx, h);  // swish
        // dropout, skip for inference
        h = conv2->forward(ctx, h);

        // skip connection
        if (out_channels != in_channels) {
            auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);

            x = nin_shortcut->forward(ctx, x);  // [N, out_channels, h, w]
        }

        h = ggml_add(ctx, h, x);
        return h;  // [N, out_channels, h, w]
    }
};

class AttnBlock : public UnaryBlock {
protected:
    int64_t in_channels;

public:
    AttnBlock(int64_t in_channels)
        : in_channels(in_channels) {
        blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
        blocks["q"]    = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
        blocks["k"]    = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
        blocks["v"]    = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));

        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, in_channels, h, w]
        auto norm     = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
        auto q_proj   = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
        auto k_proj   = std::dynamic_pointer_cast<Conv2d>(blocks["k"]);
        auto v_proj   = std::dynamic_pointer_cast<Conv2d>(blocks["v"]);
        auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);

        auto h_ = norm->forward(ctx, x);

        const int64_t n = h_->ne[3];
        const int64_t c = h_->ne[2];
        const int64_t h = h_->ne[1];
        const int64_t w = h_->ne[0];

        auto q = q_proj->forward(ctx, h_);                          // [N, in_channels, h, w]
        q      = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3));  // [N, h, w, in_channels]
        q      = ggml_reshape_3d(ctx, q, c, h * w, n);              // [N, h * w, in_channels]

        auto k = k_proj->forward(ctx, h_);                          // [N, in_channels, h, w]
        k      = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
        k      = ggml_reshape_3d(ctx, k, c, h * w, n);              // [N, h * w, in_channels]

        auto v = v_proj->forward(ctx, h_);              // [N, in_channels, h, w]
        v      = ggml_reshape_3d(ctx, v, h * w, c, n);  // [N, in_channels, h * w]

        h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]

        h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
        h_ = ggml_reshape_4d(ctx, h_, w, h, c, n);               // [N, in_channels, h, w]

        h_ = proj_out->forward(ctx, h_);  // [N, in_channels, h, w]

        h_ = ggml_add(ctx, h_, x);
        return h_;
    }
};

class AE3DConv : public Conv2d {
public:
    AE3DConv(int64_t in_channels,
             int64_t out_channels,
             std::pair<int, int> kernel_size,
             int64_t video_kernel_size    = 3,
             std::pair<int, int> stride   = {1, 1},
             std::pair<int, int> padding  = {0, 0},
             std::pair<int, int> dilation = {1, 1},
             bool bias                    = true)
        : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
        int64_t kernel_padding  = video_kernel_size / 2;
        blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels,
                                                                             out_channels,
                                                                             video_kernel_size,
                                                                             1,
                                                                             kernel_padding));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x) {
        // timesteps always None
        // skip_video always False
        // x: [N, IC, IH, IW]
        // result: [N, OC, OH, OW]
        auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks["time_mix_conv"]);

        x = Conv2d::forward(ctx, x);
        // timesteps = x.shape[0]
        // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        // x = conv3d(x)
        // return rearrange(x, "b c t h w -> (b t) c h w")
        int64_t T = x->ne[3];
        int64_t B = x->ne[3] / T;
        int64_t C = x->ne[2];
        int64_t H = x->ne[1];
        int64_t W = x->ne[0];

        x = ggml_reshape_4d(ctx, x, W * H, C, T, B);           // (b t) c h w -> b t c (h w)
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b t c (h w) -> b c t (h w)
        x = time_mix_conv->forward(ctx, x);                    // [B, OC, T, OH * OW]
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b c t (h w) -> b t c (h w)
        x = ggml_reshape_4d(ctx, x, W, H, C, T * B);           // b t c (h w) -> (b t) c h w
        return x;                                              // [B*T, OC, OH, OW]
    }
};

class VideoResnetBlock : public ResnetBlock {
protected:
    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
        enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
        params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
    }

    float get_alpha() {
        float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]);
        return sigmoid(alpha);
    }

public:
    VideoResnetBlock(int64_t in_channels,
                     int64_t out_channels,
                     int video_kernel_size = 3)
        : ResnetBlock(in_channels, out_channels) {
        // merge_strategy is always learned
        blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
    }

    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
        // return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
        // t_emb is always None
        // skip_video is always False
        // timesteps is always None
        auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);

        x = ResnetBlock::forward(ctx, x);  // [N, out_channels, h, w]
        // return x;

        int64_t T = x->ne[3];
        int64_t B = x->ne[3] / T;
        int64_t C = x->ne[2];
        int64_t H = x->ne[1];
        int64_t W = x->ne[0];

        x          = ggml_reshape_4d(ctx, x, W * H, C, T, B);           // (b t) c h w -> b t c (h w)
        x          = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b t c (h w) -> b c t (h w)
        auto x_mix = x;

        x = time_stack->forward(ctx, x);  // b t c (h w)

        float alpha = get_alpha();
        x           = ggml_add(ctx,
                               ggml_scale(ctx, x, alpha),
                               ggml_scale(ctx, x_mix, 1.0f - alpha));

        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b c t (h w) -> b t c (h w)
        x = ggml_reshape_4d(ctx, x, W, H, C, T * B);           // b t c (h w) -> (b t) c h w

        return x;
    }
};

// ldm.modules.diffusionmodules.model.Encoder
class Encoder : public GGMLBlock {
protected:
    int ch                   = 128;
    std::vector<int> ch_mult = {1, 2, 4, 4};
    int num_res_blocks       = 2;
    int in_channels          = 3;
    int z_channels           = 4;
    bool double_z            = true;

public:
    Encoder(int ch,
            std::vector<int> ch_mult,
            int num_res_blocks,
            int in_channels,
            int z_channels,
            bool double_z = true)
        : ch(ch),
          ch_mult(ch_mult),
          num_res_blocks(num_res_blocks),
          in_channels(in_channels),
          z_channels(z_channels),
          double_z(double_z) {
        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));

        size_t num_resolutions = ch_mult.size();

        int block_in = 1;
        for (int i = 0; i < num_resolutions; i++) {
            if (i == 0) {
                block_in = ch;
            } else {
                block_in = ch * ch_mult[i - 1];
            }
            int block_out = ch * ch_mult[i];
            for (int j = 0; j < num_res_blocks; j++) {
                std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
                blocks[name]     = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
                block_in         = block_out;
            }
            if (i != num_resolutions - 1) {
                std::string name = "down." + std::to_string(i) + ".downsample";
                blocks[name]     = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
            }
        }

        blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
        blocks["mid.attn_1"]  = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
        blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));

        blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
        blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
    }

    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, in_channels, h, w]

        auto conv_in     = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
        auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
        auto mid_attn_1  = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
        auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
        auto norm_out    = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
        auto conv_out    = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);

        auto h = conv_in->forward(ctx, x);  // [N, ch, h, w]

        // downsampling
        size_t num_resolutions = ch_mult.size();
        for (int i = 0; i < num_resolutions; i++) {
            for (int j = 0; j < num_res_blocks; j++) {
                std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
                auto down_block  = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);

                h = down_block->forward(ctx, h);
            }
            if (i != num_resolutions - 1) {
                std::string name = "down." + std::to_string(i) + ".downsample";
                auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);

                h = down_sample->forward(ctx, h);
            }
        }

        // middle
        h = mid_block_1->forward(ctx, h);
        h = mid_attn_1->forward(ctx, h);
        h = mid_block_2->forward(ctx, h);  // [N, block_in, h, w]

        // end
        h = norm_out->forward(ctx, h);
        h = ggml_silu_inplace(ctx, h);  // nonlinearity/swish
        h = conv_out->forward(ctx, h);  // [N, z_channels*2, h, w]
        return h;
    }
};

// ldm.modules.diffusionmodules.model.Decoder
class Decoder : public GGMLBlock {
protected:
    int ch                   = 128;
    int out_ch               = 3;
    std::vector<int> ch_mult = {1, 2, 4, 4};
    int num_res_blocks       = 2;
    int z_channels           = 4;
    bool video_decoder       = false;
    int video_kernel_size    = 3;

    virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
                                                    int64_t out_channels,
                                                    std::pair<int, int> kernel_size,
                                                    std::pair<int, int> stride  = {1, 1},
                                                    std::pair<int, int> padding = {0, 0}) {
        if (video_decoder) {
            return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
        } else {
            return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
        }
    }

    virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
                                                        int64_t out_channels) {
        if (video_decoder) {
            return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
        } else {
            return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
        }
    }

public:
    Decoder(int ch,
            int out_ch,
            std::vector<int> ch_mult,
            int num_res_blocks,
            int z_channels,
            bool video_decoder    = false,
            int video_kernel_size = 3)
        : ch(ch),
          out_ch(out_ch),
          ch_mult(ch_mult),
          num_res_blocks(num_res_blocks),
          z_channels(z_channels),
          video_decoder(video_decoder),
          video_kernel_size(video_kernel_size) {
        size_t num_resolutions = ch_mult.size();
        int block_in           = ch * ch_mult[num_resolutions - 1];

        blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));

        blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
        blocks["mid.attn_1"]  = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
        blocks["mid.block_2"] = get_resnet_block(block_in, block_in);

        for (int i = num_resolutions - 1; i >= 0; i--) {
            int mult      = ch_mult[i];
            int block_out = ch * mult;
            for (int j = 0; j < num_res_blocks + 1; j++) {
                std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
                blocks[name]     = get_resnet_block(block_in, block_out);

                block_in = block_out;
            }
            if (i != 0) {
                std::string name = "up." + std::to_string(i) + ".upsample";
                blocks[name]     = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
            }
        }

        blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
        blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
    }

    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
        // z: [N, z_channels, h, w]
        // alpha is always 0
        // merge_strategy is always learned
        // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
        // AttnVideoBlock will not be used
        auto conv_in     = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
        auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
        auto mid_attn_1  = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
        auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
        auto norm_out    = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
        auto conv_out    = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);

        // conv_in
        auto h = conv_in->forward(ctx, z);  // [N, block_in, h, w]

        // middle
        h = mid_block_1->forward(ctx, h);
        // return h;

        h = mid_attn_1->forward(ctx, h);
        h = mid_block_2->forward(ctx, h);  // [N, block_in, h, w]

        // upsampling
        size_t num_resolutions = ch_mult.size();
        for (int i = num_resolutions - 1; i >= 0; i--) {
            for (int j = 0; j < num_res_blocks + 1; j++) {
                std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
                auto up_block    = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);

                h = up_block->forward(ctx, h);
            }
            if (i != 0) {
                std::string name = "up." + std::to_string(i) + ".upsample";
                auto up_sample   = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);

                h = up_sample->forward(ctx, h);
            }
        }

        h = norm_out->forward(ctx, h);
        h = ggml_silu_inplace(ctx, h);  // nonlinearity/swish
        h = conv_out->forward(ctx, h);  // [N, out_ch, h*8, w*8]
        return h;
    }
};

// ldm.models.autoencoder.AutoencoderKL
class AutoencodingEngine : public GGMLBlock {
protected:
    bool decode_only       = true;
    bool use_video_decoder = false;
    bool use_quant         = true;
    int embed_dim          = 4;
    struct {
        int z_channels           = 4;
        int resolution           = 256;
        int in_channels          = 3;
        int out_ch               = 3;
        int ch                   = 128;
        std::vector<int> ch_mult = {1, 2, 4, 4};
        int num_res_blocks       = 2;
        bool double_z            = true;
    } dd_config;

public:
    AutoencodingEngine(bool decode_only       = true,
                       bool use_video_decoder = false,
                       SDVersion version      = VERSION_SD1)
        : decode_only(decode_only), use_video_decoder(use_video_decoder) {
        if (sd_version_is_dit(version)) {
            dd_config.z_channels = 16;
            use_quant            = false;
        }
        if (use_video_decoder) {
            use_quant = false;
        }
        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
                                                                   dd_config.out_ch,
                                                                   dd_config.ch_mult,
                                                                   dd_config.num_res_blocks,
                                                                   dd_config.z_channels,
                                                                   use_video_decoder));
        if (use_quant) {
            blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
                                                                              embed_dim,
                                                                              {1, 1}));
        }
        if (!decode_only) {
            blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
                                                                       dd_config.ch_mult,
                                                                       dd_config.num_res_blocks,
                                                                       dd_config.in_channels,
                                                                       dd_config.z_channels,
                                                                       dd_config.double_z));
            if (use_quant) {
                int factor = dd_config.double_z ? 2 : 1;

                blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
                                                                             dd_config.z_channels * factor,
                                                                             {1, 1}));
            }
        }
    }

    struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
        // z: [N, z_channels, h, w]
        if (use_quant) {
            auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
            z                    = post_quant_conv->forward(ctx, z);  // [N, z_channels, h, w]
        }
        auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);

        ggml_set_name(z, "bench-start");
        auto h = decoder->forward(ctx, z);
        ggml_set_name(h, "bench-end");
        return h;
    }

    struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
        // x: [N, in_channels, h, w]
        auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);

        auto h = encoder->forward(ctx, x);  // [N, 2*z_channels, h/8, w/8]
        if (use_quant) {
            auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
            h               = quant_conv->forward(ctx, h);  // [N, 2*embed_dim, h/8, w/8]
        }
        return h;
    }
};

struct AutoEncoderKL : public GGMLRunner {
    bool decode_only = true;
    AutoencodingEngine ae;

    AutoEncoderKL(ggml_backend_t backend,
                  std::map<std::string, enum ggml_type>& tensor_types,
                  const std::string prefix,
                  bool decode_only       = false,
                  bool use_video_decoder = false,
                  SDVersion version      = VERSION_SD1)
        : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) {
        ae.init(params_ctx, tensor_types, prefix);
    }

    std::string get_desc() {
        return "vae";
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        ae.get_param_tensors(tensors, prefix);
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);

        z = to_backend(z);

        struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);

        ggml_build_forward_expand(gf, out);

        return gf;
    }

    void compute(const int n_threads,
                 struct ggml_tensor* z,
                 bool decode_graph,
                 struct ggml_tensor** output,
                 struct ggml_context* output_ctx = NULL) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(z, decode_graph);
        };
        // ggml_set_f32(z, 0.5f);
        // print_ggml_tensor(z);
        GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }

    void test() {
        struct ggml_init_params params;
        params.mem_size   = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
        params.mem_buffer = NULL;
        params.no_alloc   = false;

        struct ggml_context* work_ctx = ggml_init(params);
        GGML_ASSERT(work_ctx != NULL);

        {
            // CPU, x{1, 3, 64, 64}: Pass
            // CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
            // CPU, x{2, 3, 64, 64}: Wrong result
            // CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
            auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
            ggml_set_f32(x, 0.5f);
            print_ggml_tensor(x);
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            compute(8, x, false, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
            LOG_DEBUG("encode test done in %dms", t1 - t0);
        }

        if (false) {
            // CPU, z{1, 4, 8, 8}: Pass
            // CUDA, z{1, 4, 8, 8}: Pass
            // CPU, z{3, 4, 8, 8}: Wrong result
            // CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
            auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
            ggml_set_f32(z, 0.5f);
            print_ggml_tensor(z);
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            compute(8, z, true, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
            LOG_DEBUG("decode test done in %dms", t1 - t0);
        }
    };
};

#endif