This tiny model is intended for debugging. It is randomly initialized using a configuration adapted from google/gemma-4-E4B-it.

| File path | Size |
|---|---|
| model.safetensors | 9.5MB |

Example usage:

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

# Load the tiny random checkpoint together with its processor.
model_id = "yujiepan/gemma-4-e-tiny-random"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Public sample media: one clip per modality (audio, image, video).
audio_url = "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/journal1.wav"
image_url = "https://raw.githubusercontent.com/google-gemma/cookbook/refs/heads/main/Demos/sample-data/GoldenGate.png"
video_url = "https://github.com/bebechien/gemma/raw/refs/heads/main/videos/ForBiggerBlazes.mp4"

# A multi-turn conversation that exercises every modality; the assistant
# turns are placeholders so the final user turn is the one being answered.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": audio_url},
            {"type": "text", "text": "Transcribe the following speech segment."},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "Dummy response for audio"}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": image_url},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "Dummy response for image"}],
    },
    {
        "role": "user",
        "content": [
            {"type": "video", "video": video_url},
            {"type": "text", "text": "Describe this video."},
        ],
    },
]

batch = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

prompt_length = batch["input_ids"].shape[-1]
print("input_len:", prompt_length)

generated = model.generate(**batch, max_new_tokens=32)
text = processor.decode(generated[0], skip_special_tokens=False)

# Shorten each multimodal placeholder token to a single letter in the
# decoded text (prompt + generation), applied in this fixed order.
for placeholder, letter in (
    ("<|audio|>", "A"),
    ("<|image|>", "I"),
    ("<|video|>", "V"),
):
    text = text.replace(placeholder, letter)
print(text)

Code used to create this repo:

Click to expand
import json
from pathlib import Path

import torch
from huggingface_hub import file_exists, hf_hub_download

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Gemma4ForConditionalGeneration,
    GenerationConfig,
    set_seed,
)

# Build a tiny, randomly initialized Gemma 4 checkpoint from the source
# model's processor and (shrunken) config, then save it to `save_folder`.
source_model_id = "google/gemma-4-E4B-it"
save_folder = "/tmp/yujiepan/gemma-4-e-tiny-random"

# Create the output folder explicitly instead of relying on
# processor.save_pretrained() to create it as a side effect.
Path(save_folder).mkdir(parents=True, exist_ok=True)

processor = AutoProcessor.from_pretrained(source_model_id)
processor.save_pretrained(save_folder)

# Start from the source model's config.json and shrink each sub-config.
with open(
    hf_hub_download(source_model_id, filename="config.json", repo_type="model"),
    "r",
    encoding="utf-8",
) as f:
    config_json = json.load(f)

config_json["audio_config"].update(
    {
        "num_attention_heads": 2,
        "num_hidden_layers": 2,
        "hidden_size": 64,
        "output_proj_dims": 32,
    }
)
config_json["text_config"].update(
    {
        "global_head_dim": 64,
        "head_dim": 32,
        "hidden_size": 8,
        "hidden_size_per_layer_input": 2,
        "intermediate_size": 64,
        "layer_types": [
            "sliding_attention",
            "full_attention",
            "sliding_attention",
            "full_attention",
        ],
        "num_attention_heads": 8,
        "num_hidden_layers": 4,
        "num_key_value_heads": 4,
        "num_kv_shared_layers": 2,
    }
)
config_json["vision_config"].update(
    {
        "num_hidden_layers": 2,
        "hidden_size": 8,
        "intermediate_size": 64,
        "head_dim": 32,
        "global_head_dim": 32,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
    }
)

with open(Path(save_folder) / "config.json", "w", encoding="utf-8") as f:
    json.dump(config_json, f, indent=2)

config = AutoConfig.from_pretrained(
    save_folder,
    trust_remote_code=True,
)
print(config)

# Instantiate the model with bfloat16 parameters, and restore whatever
# default dtype was active before, even if construction raises.
previous_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = Gemma4ForConditionalGeneration(config)
finally:
    torch.set_default_dtype(previous_default_dtype)

if file_exists(
    filename="generation_config.json", repo_id=source_model_id, repo_type="model"
):
    # No trust_remote_code here: GenerationConfig.from_pretrained forwards
    # unrecognized kwargs into the config, where they can end up as extra
    # keys in the saved generation_config.json.
    model.generation_config = GenerationConfig.from_pretrained(source_model_id)

set_seed(42)
model = model.cpu()
all_numels = sum(p.numel() for p in model.parameters())
with torch.no_grad():
    # Iterate in sorted-name order so the seeded RNG is consumed in a fixed
    # order and the resulting random weights are reproducible.
    for name, p in sorted(model.named_parameters()):
        torch.nn.init.normal_(p, 0, 0.2)
        print(name, p.shape, f"{p.numel() / all_numels * 100: .4f}%")
model.save_pretrained(save_folder)

Printing the model:

Click to expand
Gemma4ForConditionalGeneration(
  (model): Gemma4Model(
    (language_model): Gemma4TextModel(
      (embed_tokens): Gemma4TextScaledWordEmbedding(262144, 8, padding_idx=0)
      (layers): ModuleList(
        (0): Gemma4TextDecoderLayer(
          (self_attn): Gemma4TextAttention(
            (q_norm): Gemma4RMSNorm()
            (k_norm): Gemma4RMSNorm()
            (v_norm): Gemma4RMSNorm()
            (k_proj): Linear(in_features=8, out_features=128, bias=False)
            (q_proj): Linear(in_features=8, out_features=256, bias=False)
            (v_proj): Linear(in_features=8, out_features=128, bias=False)
            (o_proj): Linear(in_features=256, out_features=8, bias=False)
          )
          (mlp): Gemma4TextMLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): Gemma4RMSNorm()
          (post_attention_layernorm): Gemma4RMSNorm()
          (pre_feedforward_layernorm): Gemma4RMSNorm()
          (post_feedforward_layernorm): Gemma4RMSNorm()
          (act_fn): GELUTanh()
          (per_layer_input_gate): Linear(in_features=8, out_features=2, bias=False)
          (per_layer_projection): Linear(in_features=2, out_features=8, bias=False)
          (post_per_layer_input_norm): Gemma4RMSNorm()
        )
        (1): Gemma4TextDecoderLayer(
          (self_attn): Gemma4TextAttention(
            (q_norm): Gemma4RMSNorm()
            (k_norm): Gemma4RMSNorm()
            (v_norm): Gemma4RMSNorm()
            (k_proj): Linear(in_features=8, out_features=256, bias=False)
            (q_proj): Linear(in_features=8, out_features=512, bias=False)
            (v_proj): Linear(in_features=8, out_features=256, bias=False)
            (o_proj): Linear(in_features=512, out_features=8, bias=False)
          )
          (mlp): Gemma4TextMLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): Gemma4RMSNorm()
          (post_attention_layernorm): Gemma4RMSNorm()
          (pre_feedforward_layernorm): Gemma4RMSNorm()
          (post_feedforward_layernorm): Gemma4RMSNorm()
          (act_fn): GELUTanh()
          (per_layer_input_gate): Linear(in_features=8, out_features=2, bias=False)
          (per_layer_projection): Linear(in_features=2, out_features=8, bias=False)
          (post_per_layer_input_norm): Gemma4RMSNorm()
        )
        (2): Gemma4TextDecoderLayer(
          (self_attn): Gemma4TextAttention(
            (q_norm): Gemma4RMSNorm()
            (k_norm): Gemma4RMSNorm()
            (v_norm): Gemma4RMSNorm()
            (k_proj): Linear(in_features=8, out_features=128, bias=False)
            (q_proj): Linear(in_features=8, out_features=256, bias=False)
            (v_proj): Linear(in_features=8, out_features=128, bias=False)
            (o_proj): Linear(in_features=256, out_features=8, bias=False)
          )
          (mlp): Gemma4TextMLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): Gemma4RMSNorm()
          (post_attention_layernorm): Gemma4RMSNorm()
          (pre_feedforward_layernorm): Gemma4RMSNorm()
          (post_feedforward_layernorm): Gemma4RMSNorm()
          (act_fn): GELUTanh()
          (per_layer_input_gate): Linear(in_features=8, out_features=2, bias=False)
          (per_layer_projection): Linear(in_features=2, out_features=8, bias=False)
          (post_per_layer_input_norm): Gemma4RMSNorm()
        )
        (3): Gemma4TextDecoderLayer(
          (self_attn): Gemma4TextAttention(
            (q_norm): Gemma4RMSNorm()
            (k_norm): Gemma4RMSNorm()
            (v_norm): Gemma4RMSNorm()
            (k_proj): Linear(in_features=8, out_features=256, bias=False)
            (q_proj): Linear(in_features=8, out_features=512, bias=False)
            (v_proj): Linear(in_features=8, out_features=256, bias=False)
            (o_proj): Linear(in_features=512, out_features=8, bias=False)
          )
          (mlp): Gemma4TextMLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): Gemma4RMSNorm()
          (post_attention_layernorm): Gemma4RMSNorm()
          (pre_feedforward_layernorm): Gemma4RMSNorm()
          (post_feedforward_layernorm): Gemma4RMSNorm()
          (act_fn): GELUTanh()
          (per_layer_input_gate): Linear(in_features=8, out_features=2, bias=False)
          (per_layer_projection): Linear(in_features=2, out_features=8, bias=False)
          (post_per_layer_input_norm): Gemma4RMSNorm()
        )
      )
      (norm): Gemma4RMSNorm()
      (rotary_emb): Gemma4TextRotaryEmbedding()
      (embed_tokens_per_layer): Gemma4TextScaledWordEmbedding(262144, 8, padding_idx=0)
      (per_layer_model_projection): Linear(in_features=8, out_features=8, bias=False)
      (per_layer_projection_norm): Gemma4RMSNorm()
    )
    (vision_tower): Gemma4VisionModel(
      (patch_embedder): Gemma4VisionPatchEmbedder(
        (input_proj): Linear(in_features=768, out_features=8, bias=False)
      )
      (encoder): Gemma4VisionEncoder(
        (rotary_emb): Gemma4VisionRotaryEmbedding()
        (layers): ModuleList(
          (0-1): 2 x Gemma4VisionEncoderLayer(
            (self_attn): Gemma4VisionAttention(
              (q_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=8, out_features=128, bias=False)
              )
              (k_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=8, out_features=128, bias=False)
              )
              (v_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=8, out_features=128, bias=False)
              )
              (o_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=128, out_features=8, bias=False)
              )
              (q_norm): Gemma4RMSNorm()
              (k_norm): Gemma4RMSNorm()
              (v_norm): Gemma4RMSNorm()
            )
            (mlp): Gemma4VisionMLP(
              (gate_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=8, out_features=64, bias=False)
              )
              (up_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=8, out_features=64, bias=False)
              )
              (down_proj): Gemma4ClippableLinear(
                (linear): Linear(in_features=64, out_features=8, bias=False)
              )
              (act_fn): GELUTanh()
            )
            (input_layernorm): Gemma4RMSNorm()
            (post_attention_layernorm): Gemma4RMSNorm()
            (pre_feedforward_layernorm): Gemma4RMSNorm()
            (post_feedforward_layernorm): Gemma4RMSNorm()
          )
        )
      )
      (pooler): Gemma4VisionPooler()
    )
    (embed_vision): Gemma4MultimodalEmbedder(
      (embedding_projection): Linear(in_features=8, out_features=8, bias=False)
      (embedding_pre_projection_norm): Gemma4RMSNorm()
    )
    (audio_tower): Gemma4AudioModel(
      (subsample_conv_projection): Gemma4AudioSubSampleConvProjection(
        (layer0): Gemma4AudioSubSampleConvProjectionLayer(
          (conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (act): ReLU()
        )
        (layer1): Gemma4AudioSubSampleConvProjectionLayer(
          (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
          (act): ReLU()
        )
        (input_proj_linear): Linear(in_features=1024, out_features=64, bias=False)
      )
      (rel_pos_enc): Gemma4AudioRelPositionalEncoding()
      (layers): ModuleList(
        (0-1): 2 x Gemma4AudioLayer(
          (feed_forward1): Gemma4AudioFeedForward(
            (ffw_layer_1): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=256, bias=False)
            )
            (ffw_layer_2): Gemma4ClippableLinear(
              (linear): Linear(in_features=256, out_features=64, bias=False)
            )
            (pre_layer_norm): Gemma4RMSNorm()
            (post_layer_norm): Gemma4RMSNorm()
            (act_fn): SiLUActivation()
          )
          (feed_forward2): Gemma4AudioFeedForward(
            (ffw_layer_1): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=256, bias=False)
            )
            (ffw_layer_2): Gemma4ClippableLinear(
              (linear): Linear(in_features=256, out_features=64, bias=False)
            )
            (pre_layer_norm): Gemma4RMSNorm()
            (post_layer_norm): Gemma4RMSNorm()
            (act_fn): SiLUActivation()
          )
          (self_attn): Gemma4AudioAttention(
            (q_proj): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=64, bias=False)
            )
            (k_proj): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=64, bias=False)
            )
            (v_proj): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=64, bias=False)
            )
            (post): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=64, bias=False)
            )
            (relative_k_proj): Linear(in_features=64, out_features=64, bias=False)
          )
          (lconv1d): Gemma4AudioLightConv1d(
            (linear_start): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=128, bias=False)
            )
            (linear_end): Gemma4ClippableLinear(
              (linear): Linear(in_features=64, out_features=64, bias=False)
            )
            (depthwise_conv1d): Gemma4AudioCausalConv1d(64, 64, kernel_size=(5,), stride=(1,), groups=64, bias=False)
            (pre_layer_norm): Gemma4RMSNorm()
            (conv_norm): Gemma4RMSNorm()
            (act_fn): SiLUActivation()
          )
          (norm_pre_attn): Gemma4RMSNorm()
          (norm_post_attn): Gemma4RMSNorm()
          (norm_out): Gemma4RMSNorm()
        )
      )
      (output_proj): Linear(in_features=64, out_features=32, bias=True)
    )
    (embed_audio): Gemma4MultimodalEmbedder(
      (embedding_projection): Linear(in_features=32, out_features=8, bias=False)
      (embedding_pre_projection_norm): Gemma4RMSNorm()
    )
  )
  (lm_head): Linear(in_features=8, out_features=262144, bias=False)
)
Downloads last month: 50
Safetensors model size: 4.72M params
Tensor type: BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for yujiepan/gemma-4-e-tiny-random

Finetuned
(15)
this model

Collection including yujiepan/gemma-4-e-tiny-random