#ifndef __DIFFUSION_MODEL_H__ #define __DIFFUSION_MODEL_H__ #include "flux.hpp" #include "mmdit.hpp" #include "unet.hpp" struct DiffusionModel { virtual void compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; virtual int64_t get_adm_in_channels() = 0; }; struct UNetModel : public DiffusionModel { UNetModelRunner unet; UNetModel(ggml_backend_t backend, std::map& tensor_types, SDVersion version = VERSION_SD1, bool flash_attn = false) : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } void alloc_params_buffer() { unet.alloc_params_buffer(); } void free_params_buffer() { unet.free_params_buffer(); } void free_compute_buffer() { unet.free_compute_buffer(); } void get_param_tensors(std::map& tensors) { unet.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() { return unet.get_params_buffer_size(); } int64_t get_adm_in_channels() { return unet.unet.adm_in_channels; } void compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) { (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } }; struct MMDiTModel : public DiffusionModel { MMDiTRunner mmdit; MMDiTModel(ggml_backend_t backend, std::map& tensor_types) : mmdit(backend, tensor_types, "model.diffusion_model") { } void alloc_params_buffer() { mmdit.alloc_params_buffer(); } void free_params_buffer() { mmdit.free_params_buffer(); } void free_compute_buffer() { mmdit.free_compute_buffer(); } void get_param_tensors(std::map& tensors) { mmdit.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() { return mmdit.get_params_buffer_size(); } int64_t get_adm_in_channels() { return 768 + 1280; } void compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) { return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; struct FluxModel : public DiffusionModel { Flux::FluxRunner flux; FluxModel(ggml_backend_t backend, std::map& tensor_types, bool flash_attn = false) : flux(backend, tensor_types, "model.diffusion_model", flash_attn) { } void alloc_params_buffer() { flux.alloc_params_buffer(); } void free_params_buffer() { flux.free_params_buffer(); } void free_compute_buffer() { flux.free_compute_buffer(); } void get_param_tensors(std::map& tensors) { flux.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() { return flux.get_params_buffer_size(); } int64_t get_adm_in_channels() { return 768; } void compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) { return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers); } }; #endif