Tiny dummy models: randomly initialized tiny models for debugging/testing purposes.
This tiny model is for debugging only. It is randomly initialized, using a config adapted from google/gemma-3n-E4B-it.
| Model ID | Notes |
|---|---|
| yujiepan/gemma-3n-tiny-random | hidden size is 32 |
| yujiepan/gemma-3n-tiny-random-dim4 | hidden size is 4; potentially not supported by paged-attention kernels |
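To confirm which variant you are looking at, the hidden size can be read off the config directly. A minimal sketch (my addition; it assumes a transformers version recent enough to resolve the nested Gemma-3n `text_config`):

```python
from transformers import AutoConfig

# Print the text-backbone hidden size of each variant (sketch, not from the card).
for repo in ["yujiepan/gemma-3n-tiny-random", "yujiepan/gemma-3n-tiny-random-dim4"]:
    config = AutoConfig.from_pretrained(repo)
    print(repo, "->", config.text_config.hidden_size)
```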
Example usage:

```python
import torch
from transformers import pipeline

model_id = "yujiepan/gemma-3n-tiny-random-dim4"
pipe = pipeline(
    task="image-text-to-text",
    model=model_id,
    device=0,
    torch_dtype=torch.bfloat16,
)

# Temporary patch for the audio tower: cast the incoming features to the
# module's dtype (bf16) before each forward pass.
from accelerate.hooks import ModelHook, add_hook_to_module

class EnsureDtype(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        args = list(args)
        args[0] = args[0].to(module.dtype)
        return super().pre_forward(module, *args, **kwargs)

add_hook_to_module(pipe.model.audio_tower, EnsureDtype())

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            # audio is buggy for now: bf16 x fp32
            {"type": "audio", "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"},
            {"type": "text", "text": "Which image is cuter?"},
        ],
    },
]
result = pipe(messages, min_new_tokens=512, max_new_tokens=512, do_sample=True)
print(result)
```
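Since the audio-tower bug is a bf16-vs-fp32 dtype mismatch, an alternative worth trying (my assumption, not verified in the original card) is to run the whole pipeline in float32, which should make the `EnsureDtype` patch above unnecessary:

```python
# Assumption: with a uniform float32 dtype there is no bf16 x fp32 mismatch,
# so the audio-tower patch should not be needed.
pipe_fp32 = pipeline(
    task="image-text-to-text",
    model=model_id,
    device=0,
    torch_dtype=torch.float32,
)
```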
Code used to create this repo:

```python
import json

import torch
from huggingface_hub import file_exists, hf_hub_download
from timm.models.mobilenetv5 import decode_arch_def
from transformers import (
    AutoConfig,
    AutoProcessor,
    Gemma3nForConditionalGeneration,
    GenerationConfig,
    set_seed,
)

source_model_id = "google/gemma-3n-E4B-it"
save_folder = "/tmp/yujiepan/gemma-3n-tiny-random-dim4"

processor = AutoProcessor.from_pretrained(source_model_id)
processor.save_pretrained(save_folder)

with open(hf_hub_download(source_model_id, filename='config.json', repo_type='model'), 'r', encoding='utf-8') as f:
    config_json = json.load(f)
config_json['audio_config'].update({
    "conf_num_attention_heads": 2,
    "conf_num_hidden_layers": 2,
    "hidden_size": 4,
})
config_json['text_config'].update({
    "activation_sparsity_pattern": [0.95, 0.95, 0.0, 0.0],
    "head_dim": 2,
    "hidden_size": 4,
    "hidden_size_per_layer_input": 1,
    "intermediate_size": 8,
    "laurel_rank": 1,
    "layer_types": ['sliding_attention', 'full_attention', 'sliding_attention', 'full_attention'],
    "num_attention_heads": 2,
    "num_hidden_layers": 4,
    "num_key_value_heads": 1,
    "num_kv_shared_layers": 2,
    "sliding_window": 512,
})
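
# Sanity check (my addition, not in the original script): the attention
# projections derive their shapes from these values. q_proj maps hidden_size
# -> num_attention_heads * head_dim and k/v_proj map hidden_size ->
# num_key_value_heads * head_dim, which is why the printed architecture below
# shows 4 -> 4 and 4 -> 2 Linear layers.
_text = config_json['text_config']
assert _text['num_attention_heads'] * _text['head_dim'] == 4  # q_proj: 4 -> 4
assert _text['num_key_value_heads'] * _text['head_dim'] == 2  # k/v_proj: 4 -> 2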

block_args = decode_arch_def(
    [
        # Stage 0: 128x128 in
        [
            'er_r1_k3_s2_e4_c4',
            'er_r1_k3_s1_e4_c4',
        ],
        # Stage 1: 256x256 in
        [
            'uir_r1_a3_k5_s2_e6_c4',
            'uir_r1_a5_k0_s1_e4_c4',
            'uir_r1_a3_k0_s1_e4_c4',
        ],
        # Stage 2: 640x640 in
        [
            "uir_r1_a5_k5_s2_e6_c4",
            "uir_r1_a0_k0_s1_e1_c4",
            "mqa_r1_k3_h2_v2_s1_d8_c4",
            "uir_r1_a0_k0_s1_e2_c4",
        ],
        # Stage 3: 1280x1280 in
        [
            "uir_r1_a5_k5_s2_e6_c4",
            "mqa_r1_k3_h2_s1_d8_c4",
            "uir_r1_a0_k0_s1_e2_c4",
        ],
    ]
)
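
# Reading the arch-def strings (my summary of timm's block-string decoder;
# worth double-checking against timm's _efficientnet_builder): the prefix
# picks the block type ('er' = EdgeResidual, 'uir' = UniversalInvertedResidual,
# 'mqa' = MobileAttention). Then r = repeats, k = kernel size, s = stride,
# e = expansion ratio, c = output channels; 'a' is the leading depthwise
# kernel of uir blocks (a0 means no dw_start), and for mqa blocks h = number
# of heads, d = per-head key/value dim, v = key/value downsampling stride.
# Note that timm rounds channel counts to a divisor (8 by default), which is
# why the c4 blocks appear with 8 channels in the printed architecture below.
print(block_args[2][2])  # inspect the decoded kwargs of the stage-2 mqa block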

config_json['vision_config'].update({
    "hidden_size": 2048,  # hard-coded in timm
    "model_args": {
        "block_args": block_args,
    },
})
config_json['tie_word_embeddings'] = True

with open(f"{save_folder}/config.json", "w", encoding='utf-8') as f:
    json.dump(config_json, f, indent=2)

config = AutoConfig.from_pretrained(
    save_folder,
    trust_remote_code=True,
)
print(config)

torch.set_default_dtype(torch.bfloat16)
model = Gemma3nForConditionalGeneration(config)
torch.set_default_dtype(torch.float32)
if file_exists(filename="generation_config.json", repo_id=source_model_id, repo_type='model'):
    model.generation_config = GenerationConfig.from_pretrained(
        source_model_id, trust_remote_code=True,
    )

set_seed(42)
model = model.cpu()
# Re-initialize every parameter from N(0, 0.2) and report each parameter's
# share of the total parameter count.
all_numels = 0
for name, p in sorted(model.named_parameters()):
    all_numels += p.numel()
with torch.no_grad():
    for name, p in sorted(model.named_parameters()):
        torch.nn.init.normal_(p, 0, 0.2)
        print(name, p.shape, f'{p.numel() / all_numels * 100: .4f}%')
model.save_pretrained(save_folder)
```
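As a quick end-to-end check (a minimal sketch of my own, not part of the original card), the saved folder can be reloaded and driven with a text-only prompt:

```python
# Reload the tiny checkpoint and run a short text-only generation (sketch).
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

reloaded = Gemma3nForConditionalGeneration.from_pretrained(
    save_folder, torch_dtype=torch.bfloat16
)
proc = AutoProcessor.from_pretrained(save_folder)
inputs = proc(text="Hello", return_tensors="pt")
out = reloaded.generate(**inputs, max_new_tokens=8, do_sample=True)
print(proc.batch_decode(out))
```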
The resulting model architecture:

```
Gemma3nForConditionalGeneration(
(model): Gemma3nModel(
(vision_tower): TimmWrapperModel(
(timm_model): MobileNetV5Encoder(
(conv_stem): ConvNormAct(
(conv): Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(blocks): Sequential(
(0): Sequential(
(0): EdgeResidual(
(conv_exp): Conv2dSame(64, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
(bn1): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
(aa): Identity()
(se): Identity()
(conv_pwl): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
(1): EdgeResidual(
(conv_exp): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
(aa): Identity()
(se): Identity()
(conv_pwl): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
(drop_path): Identity()
)
)
(1): Sequential(
(0): UniversalInvertedResidual(
(dw_start): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): ConvNormAct(
(conv): Conv2dSame(48, 48, kernel_size=(5, 5), stride=(2, 2), groups=48, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(48, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(1): UniversalInvertedResidual(
(dw_start): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=8, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(2): UniversalInvertedResidual(
(dw_start): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
)
(2): Sequential(
(0): UniversalInvertedResidual(
(dw_start): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=8, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): ConvNormAct(
(conv): Conv2dSame(48, 48, kernel_size=(5, 5), stride=(2, 2), groups=48, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(48, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(1): UniversalInvertedResidual(
(dw_start): Identity()
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(2): MobileAttention(
(norm): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
(attn): MultiQueryAttention2d(
(query): Sequential(
(proj): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(key): Sequential(
(down_conv): Conv2dSame(8, 8, kernel_size=(3, 3), stride=(2, 2), groups=8, bias=False)
(norm): RmsNorm2d()
(proj): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(value): Sequential(
(down_conv): Conv2dSame(8, 8, kernel_size=(3, 3), stride=(2, 2), groups=8, bias=False)
(norm): RmsNorm2d()
(proj): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(attn_drop): Dropout(p=0.0, inplace=False)
(output): Sequential(
(proj): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(drop): Dropout(p=0.0, inplace=False)
)
)
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(3): UniversalInvertedResidual(
(dw_start): Identity()
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
)
(3): Sequential(
(0): UniversalInvertedResidual(
(dw_start): ConvNormAct(
(conv): Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=8, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): ConvNormAct(
(conv): Conv2dSame(48, 48, kernel_size=(5, 5), stride=(2, 2), groups=48, bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(48, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(1): MobileAttention(
(norm): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
(attn): MultiQueryAttention2d(
(query): Sequential(
(proj): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(key): Sequential(
(proj): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(value): Sequential(
(proj): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(attn_drop): Dropout(p=0.0, inplace=False)
(output): Sequential(
(proj): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(drop): Dropout(p=0.0, inplace=False)
)
)
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
(2): UniversalInvertedResidual(
(dw_start): Identity()
(pw_exp): ConvNormAct(
(conv): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): LayerScale2d()
(drop_path): Identity()
)
)
)
(msfa): MobileNetV5MultiScaleFusionAdapter(
(ffn): UniversalInvertedResidual(
(dw_start): Identity()
(pw_exp): ConvNormAct(
(conv): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): GELU(approximate='none')
)
)
(dw_mid): Identity()
(se): Identity()
(pw_proj): ConvNormAct(
(conv): Conv2d(32, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): RmsNormAct2d(
(drop): Identity()
(act): Identity()
)
)
(dw_end): Identity()
(layer_scale): Identity()
(drop_path): Identity()
)
(norm): RmsNorm2d()
)
)
)
(language_model): Gemma3nTextModel(
(embed_tokens): Gemma3nTextScaledWordEmbedding(262400, 4, padding_idx=0)
(layers): ModuleList(
(0-3): 4 x Gemma3nTextDecoderLayer(
(self_attn): Gemma3nTextAttention(
(q_proj): Linear(in_features=4, out_features=4, bias=False)
(k_proj): Linear(in_features=4, out_features=2, bias=False)
(v_proj): Linear(in_features=4, out_features=2, bias=False)
(o_proj): Linear(in_features=4, out_features=4, bias=False)
(q_norm): Gemma3nRMSNorm((2,), eps=1e-06)
(k_norm): Gemma3nRMSNorm((2,), eps=1e-06)
(v_norm): Gemma3nRMSNorm((), eps=1e-06)
)
(mlp): Gemma3nTextMLP(
(gate_proj): Linear(in_features=4, out_features=8, bias=False)
(up_proj): Linear(in_features=4, out_features=8, bias=False)
(down_proj): Linear(in_features=8, out_features=4, bias=False)
(act_fn): PytorchGELUTanh()
)
(input_layernorm): Gemma3nRMSNorm((4,), eps=1e-06)
(post_attention_layernorm): Gemma3nRMSNorm((4,), eps=1e-06)
(pre_feedforward_layernorm): Gemma3nRMSNorm((4,), eps=1e-06)
(post_feedforward_layernorm): Gemma3nRMSNorm((4,), eps=1e-06)
(act_fn): PytorchGELUTanh()
(altup): Gemma3nTextAltUp(
(correction_coefs): Linear(in_features=4, out_features=4, bias=False)
(prediction_coefs): Linear(in_features=4, out_features=16, bias=False)
(modality_router): Linear(in_features=4, out_features=4, bias=False)
(router_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
(laurel): Gemma3nTextLaurelBlock(
(linear_left): Linear(in_features=4, out_features=1, bias=False)
(linear_right): Linear(in_features=1, out_features=4, bias=False)
(post_laurel_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
(per_layer_input_gate): Linear(in_features=4, out_features=1, bias=False)
(per_layer_projection): Linear(in_features=1, out_features=4, bias=False)
(post_per_layer_input_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
)
(norm): Gemma3nRMSNorm((4,), eps=1e-06)
(rotary_emb): Gemma3nTextRotaryEmbedding()
(rotary_emb_local): Gemma3nTextRotaryEmbedding()
(embed_tokens_per_layer): Gemma3nTextScaledWordEmbedding(262144, 4, padding_idx=0)
(per_layer_model_projection): Linear(in_features=4, out_features=4, bias=False)
(per_layer_projection_norm): Gemma3nRMSNorm((1,), eps=1e-06)
(altup_projections): ModuleList(
(0-2): 3 x Linear(in_features=4, out_features=4, bias=False)
)
(altup_unembed_projections): ModuleList(
(0-2): 3 x Linear(in_features=4, out_features=4, bias=False)
)
)
(audio_tower): Gemma3nAudioEncoder(
(subsample_conv_projection): Gemma3nAudioSubSampleConvProjection(
(conv_0): Gemma3nAudioSSCPConvBlock(
(conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
(norm): Gemma3nAudioCumulativeGroupNorm()
(activation): ReLU()
)
(conv_1): Gemma3nAudioSSCPConvBlock(
(conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
(norm): Gemma3nAudioCumulativeGroupNorm()
(activation): ReLU()
)
(input_proj_linear): Linear(in_features=1024, out_features=4, bias=False)
)
(conformer): ModuleList(
(0-1): 2 x Gemma3nAudioConformerBlock(
(ffw_layer_start): Gemma3nAudioConformerFeedForward(
(pre_layer_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(ffw_layer_1): Linear(in_features=4, out_features=16, bias=False)
(ffw_layer_2): Linear(in_features=16, out_features=4, bias=False)
(post_layer_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
(attention): Gemma3nAudioConformerAttention(
(pre_attn_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(attn): Gemma3nAudioAttention(
(relative_position_embedding): Gemma3nAudioRelativePositionEmbedding(
(pos_proj): Linear(in_features=4, out_features=4, bias=False)
)
(q_proj): Linear(in_features=4, out_features=4, bias=False)
(k_proj): Linear(in_features=4, out_features=4, bias=False)
(v_proj): Linear(in_features=4, out_features=4, bias=False)
)
(post): Linear(in_features=4, out_features=4, bias=False)
(post_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
(lconv1d): Gemma3nAudioConformerLightConv1d(
(pre_layer_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(linear_start): Linear(in_features=4, out_features=8, bias=False)
(depthwise_conv1d): Conv1d(4, 4, kernel_size=(5,), stride=(1,), groups=4, bias=False)
(conv_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(linear_end): Linear(in_features=4, out_features=4, bias=False)
)
(ffw_layer_end): Gemma3nAudioConformerFeedForward(
(pre_layer_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(ffw_layer_1): Linear(in_features=4, out_features=16, bias=False)
(ffw_layer_2): Linear(in_features=16, out_features=4, bias=False)
(post_layer_norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
(norm): Gemma3nRMSNorm((4,), eps=1e-06)
)
)
)
(embed_vision): Gemma3nMultimodalEmbedder(
(embedding): Embedding(128, 2048)
(hard_embedding_norm): Gemma3nRMSNorm((2048,), eps=1e-06)
(soft_embedding_norm): Gemma3nRMSNorm((2048,), eps=1e-06)
(embedding_projection): Linear(in_features=2048, out_features=4, bias=False)
(embedding_post_projection_norm): Gemma3nRMSNorm((), eps=1e-06)
)
(embed_audio): Gemma3nMultimodalEmbedder(
(embedding): Embedding(128, 4)
(hard_embedding_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(soft_embedding_norm): Gemma3nRMSNorm((4,), eps=1e-06)
(embedding_projection): Linear(in_features=4, out_features=4, bias=False)
(embedding_post_projection_norm): Gemma3nRMSNorm((), eps=1e-06)
)
)
(lm_head): Linear(in_features=4, out_features=262400, bias=False)
)
```