CASM – CLaSp Adaptive Skip Mask

CASM is a lightweight GRPO-trained skip policy for self-speculative decoding with Meta-Llama-3-8B. It learns which transformer layers to bypass during the draft stage of speculative decoding, replacing the CLaSp dynamic-programming optimizer with a small neural policy that runs in microseconds.

How it works

Self-speculative decoding runs the same frozen model in two modes, with a policy update between cycles:

  1. Draft – selected decoder layers are skipped, producing K candidate tokens cheaply.
  2. Verify – the full model validates the draft block and accepts the longest matching prefix.
  3. Policy update – CASM observes the verify-pass hidden states and chooses a new skip mask for the next cycle.

CASM replaces step 3's DP solver with a 2-layer Transformer encoder (~200 K parameters) that maps per-layer hidden states to a skip-mask distribution. It is trained end-to-end with GRPO against a reward combining token-acceptance rate, decoding speed, and mismatch regularization.
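The draft/verify/update cycle above can be sketched with toy stand-ins for the model (the `draft_step` and `full_step` callables below are hypothetical placeholders for the skip-masked and full forward passes; the real implementation operates on Llama-3 logits):

```python
def longest_accepted_prefix(draft, verified):
    """Verify step: accept draft tokens up to the first mismatch."""
    tau = 0
    for d, v in zip(draft, verified):
        if d != v:
            break
        tau += 1
    return tau

def decode_cycle(context, draft_step, full_step, k):
    # Draft: the skip-masked model proposes k tokens autoregressively.
    draft = []
    for _ in range(k):
        draft.append(draft_step(context + draft))
    # Verify: the full model recomputes the same k positions.
    verified = [full_step(context + draft[:i]) for i in range(k)]
    tau = longest_accepted_prefix(draft, verified)
    # Accept tau draft tokens, plus the verifier's correction if one exists.
    accepted = draft[:tau] + ([verified[tau]] if tau < k else [])
    return context + accepted, tau

# Toy demo: the "full model" emits the sequence length as the next token;
# the "draft model" agrees except at length 3, so tau = 2.
full_step = lambda seq: len(seq)
draft_step = lambda seq: -1 if len(seq) == 3 else len(seq)
ctx, tau = decode_cycle([0], draft_step, full_step, 4)
# tau == 2, ctx == [0, 1, 2, 3]
```

Even on a mismatch the cycle makes progress: the verifier's token at the first disagreement is kept, so each cycle yields at least one token.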

Architecture

| Component | Description |
| --- | --- |
| HiddenStateProjector | Projects per-layer hidden states `[L, d_model]` → `[L, 128]` |
| ScalarFeatureEmbedder | Embeds 5 scalar context features (acceptance rate, latency, position, mask age, temperature) |
| PolicyEncoder | 2-layer Transformer encoder over layer positions |
| logit_head | Per-layer skip logits → top-M selection |
| AcceptanceRateHead | Predicts E[τ/K] → optimal draft length K* |

Parameters: ~200 K
Base model: meta-llama/Meta-Llama-3-8B (32 layers, hidden_dim=4096)
Skip budget: 8 layers per draft cycle
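The final selection step is simple to state: keep the M layers with the highest skip logits. A minimal pure-Python sketch (the real `logit_head` operates on tensors; `topm_skip_mask` is an illustrative name, not the repo's API):

```python
def topm_skip_mask(logits, m):
    """Turn per-layer skip logits into a hard mask: 1 = skip that layer.

    Picks the m layers with the highest logits, matching the skip budget
    (m = 8 of 32 layers for Meta-Llama-3-8B).
    """
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    skip = set(order[:m])
    return [1 if i in skip else 0 for i in range(len(logits))]

mask = topm_skip_mask([0.1, 2.0, -1.0, 0.5], m=2)
# -> [0, 1, 0, 1]: layers 1 and 3 have the highest logits
```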

Usage

```python
import torch
from src.grpo.policy import SkipPolicy

# Load policy
ckpt = torch.load("policy_best.pt", map_location="cpu")
policy = SkipPolicy(
    hidden_dim=4096,
    n_layers=32,
    n_skip=8,
    policy_dim=128,
    context_tokens=1,
)
policy.load_state_dict(ckpt["policy_state_dict"])
policy.eval()

# During self-speculative decoding, call after each verify pass:
# hidden_states: tuple of (L+1) tensors from model output_hidden_states=True
mask, draft_len = policy.greedy_mask(
    hidden_states,
    last_tau=accepted_tokens,
    draft_len=current_draft_len,
    position=current_position,
    max_len=max_new_tokens,
)
# mask: list of 0/1 per layer (1 = skip this layer during draft)
# draft_len: recommended tokens to draft next cycle
```
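For context, here is a hedged sketch of where that call sits inside a generation loop. `run_draft` and `run_verify` are hypothetical stand-ins for the skip-masked and full forward passes, and a toy policy replaces `policy.greedy_mask` so the loop is runnable:

```python
N_LAYERS, N_SKIP = 32, 8  # Meta-Llama-3-8B depth and skip budget

def toy_policy(hidden_states, last_tau, draft_len, position, max_len):
    # Stand-in for policy.greedy_mask: skip the last N_SKIP layers
    # and keep the draft length fixed.
    return [0] * (N_LAYERS - N_SKIP) + [1] * N_SKIP, draft_len

def generate(policy, run_draft, run_verify, max_new_tokens):
    mask, draft_len = [0] * N_LAYERS, 4   # first cycle: no layers skipped
    out, position = [], 0
    while position < max_new_tokens:
        draft = run_draft(mask, draft_len)        # cheap skip-masked pass
        accepted, hidden = run_verify(draft)      # full-model verification
        out += accepted
        position += len(accepted)
        # Policy update: choose the mask and draft length for the next cycle.
        mask, draft_len = policy(hidden, len(accepted), draft_len,
                                 position, max_new_tokens)
    return out[:max_new_tokens]
```

With toy draft/verify functions that accept three tokens per cycle, the loop terminates after three cycles and truncates to the requested length.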

See grpo-clasp for the full training and evaluation codebase.

Training

Trained with GRPO on SpecBench-style prompts using meta-llama/Meta-Llama-3-8B on a single A100 80 GB for 10 000 steps. Imitation warm-start from CLaSp DP masks was used for the first ~1000 steps.
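The core of GRPO is group-relative advantage estimation: sample a group of candidate skip masks for the same state, score each with the reward above, and standardize the rewards within the group, so no learned value function is needed. A sketch (not the repo's exact code):

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-8):
    """GRPO advantage: each reward standardized within its sampled group."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

adv = group_advantages([1.0, 2.0, 3.0])
# Above-average masks get positive advantage, below-average negative,
# and the group mean advantage is zero.
```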

| Metric | Value |
| --- | --- |
| Training steps | 10 000 |
| Eval reward | 99.8 |
| Test reward | 100.4 |
| GPU | NVIDIA A100 80 GB |

Citation

If you use CASM, please cite the CLaSp paper and this repository:

```bibtex
@misc{casm2026,
  author = {Dayne Guy},
  title  = {CASM: CLaSp Adaptive Skip Mask},
  year   = {2026},
  url    = {https://huggingface.co/dayngerous/CASM}
}
```