Papers
arxiv:2603.15619

Mixture-of-Depths Attention

Published on Mar 16 · Submitted by Lianghui Zhu on Mar 17

Abstract

Scaling depth is a key driver for large language models (LLMs). Yet, as LLMs become deeper, they often suffer from signal degradation: informative features formed in shallow layers are gradually diluted by repeated residual updates, making them harder to recover in deeper layers. We introduce mixture-of-depths attention (MoDA), a mechanism that allows each attention head to attend to sequence KV pairs at the current layer and depth KV pairs from preceding layers. We further describe a hardware-efficient algorithm for MoDA that resolves non-contiguous memory-access patterns, achieving 97.3% of FlashAttention-2's efficiency at a sequence length of 64K. Experiments on 1.5B-parameter models demonstrate that MoDA consistently outperforms strong baselines. Notably, it improves average perplexity by 0.2 across 10 validation benchmarks and increases average performance by 2.11% on 10 downstream tasks, with a negligible 3.7% FLOPs computational overhead. We also find that combining MoDA with post-norm yields better performance than using it with pre-norm. These results suggest that MoDA is a promising primitive for depth scaling. Code is released at https://github.com/hustvl/MoDA .

Community

Posted by the paper author and submitter:


The paper introduces Mixture-of-Depths Attention (MoDA), a mechanism designed to address signal degradation in deep Large Language Models (LLMs)—the phenomenon where informative features formed in shallow layers get diluted by repeated residual updates, making them harder to recover in deeper layers.

Core Concept

MoDA extends standard causal attention by allowing each attention head to attend to both:

  1. Sequence KV pairs (standard keys/values from the current layer)
  2. Depth KV pairs (historical keys/values from preceding layers at the same token position)

This creates a unified attention operation where queries jointly attend to sequence and depth information under a single softmax normalization.
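The joint softmax described above can be sketched in a few lines. This is a minimal single-head NumPy illustration of the idea, not the paper's implementation (and nothing like its fused kernel); all names are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moda_attention(q, seq_k, seq_v, depth_k, depth_v):
    """Single-head MoDA sketch.

    q:         (T, d)     queries at the current layer
    seq_k/v:   (T, d)     sequence KV pairs at the current layer
    depth_k/v: (L, T, d)  depth KV pairs from L preceding layers,
                          one per token position per layer
    """
    T, d = q.shape
    scale = 1.0 / np.sqrt(d)

    out = np.empty_like(q)
    for t in range(T):
        # causal sequence keys up to position t, plus this position's
        # depth keys from all preceding layers
        k = np.concatenate([seq_k[: t + 1], depth_k[:, t, :]], axis=0)
        v = np.concatenate([seq_v[: t + 1], depth_v[:, t, :]], axis=0)
        # a single softmax jointly normalizes sequence and depth scores
        w = softmax((k @ q[t]) * scale)
        out[t] = w @ v
    return out
```

Because sequence and depth scores compete inside one softmax, the model can shift probability mass toward a shallow layer's features when they are more informative than the current layer's context.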

Figure 1: Compared to vanilla causal attention, MoDA allows queries to attend to "depth memories" (historical KV pairs from preceding layers).

Main Performance Results

MoDA consistently outperforms a strong OLMo2 baseline at both model scales (700M and 1.5B parameters), each trained on 400B tokens:

Figure 2: MoDA achieves lower C4 validation loss and better downstream performance (HellaSwag, WinoGrande, ARC-Challenge) than OLMo2 at the 1.5B scale.

Key metrics at 1.5B scale:

  • Perplexity: Improves average perplexity by 0.2 across 10 validation benchmarks
  • Downstream tasks: Increases average performance by 2.11% across 10 tasks (including HellaSwag, WinoGrande, ARC-Challenge, MMLU)
  • Cost: Only 3.7% FLOPs overhead and negligible parameter increase

Detailed results:

Model        | Avg Downstream | C4 Val PPL | Wiki PPL
OLMo2 1.5B   | 62.28          | 16.16      | 10.41
MoDA 1.5B    | 64.39 (+2.11)  | 15.97      | 10.16

The gains are consistent across both scales:

  • 700M: +1.76% average downstream performance
  • 1.5B: +2.11% average downstream performance

Mechanism Design Space

Figure 3: Conceptual comparison of depth-stream mechanisms. (a) Depth Residual (standard), (b) Depth Dense (DenseNet-style), (c) Depth Attention, (d) Mixture-of-Depths Attention (MoDA).

MoDA sits at an efficient point in the design space:

  • Depth Residual: Compresses history via addition (information dilution)
  • Depth Dense: Concatenates all layers (lossless but O(L²D²) parameters)
  • MoDA: Data-dependent retrieval via attention with O(LD²/G) parameters (most efficient)
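The parameter-count gap between Depth Dense and MoDA is large at realistic depths. A back-of-envelope comparison (the exact counting below is an assumption for illustration; only the asymptotic terms O(L²D²) and O(LD²/G) come from the summary above):

```python
def depth_dense_params(L, D):
    # DenseNet-style: layer l projects the concatenation of all l earlier
    # layers' width-D features back to width D -> sum_l l * D^2 ~ L^2 D^2 / 2
    return sum(layer * D * D for layer in range(1, L + 1))

def moda_params(L, D, G):
    # one shared depth-KV projection per group of G layers -> ~ L D^2 / G
    return (L // G) * D * D

L, D, G = 48, 2048, 4  # illustrative depth, width, and group size
print(depth_dense_params(L, D) / moda_params(L, D, G))  # → 98.0
```

So under these illustrative settings, the dense design needs roughly two orders of magnitude more depth-stream parameters than MoDA's grouped projections.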

Hardware Efficiency

MoDA achieves 97.3% of FlashAttention-2's efficiency at 64K sequence length through a hardware-aware fused kernel that resolves non-contiguous memory access patterns.

Figure 4: Chunk-aware depth KV layout reduces effective depth span from T×L to (C×L)/G per chunk, improving memory access efficiency.
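To get a feel for the T×L → (C×L)/G reduction, here is the arithmetic with illustrative values (the chunk size C=128 and group size G=4 are assumptions, not numbers from the paper):

```python
def naive_depth_span(T, L):
    return T * L            # depth KVs for every token, every layer

def chunked_depth_span(C, L, G):
    return (C * L) // G     # only the current chunk's tokens, grouped layers

T, L, C, G = 65536, 48, 128, 4  # illustrative values
print(naive_depth_span(T, L))      # 3145728
print(chunked_depth_span(C, L, G)) # 1536
```

Under these assumed values the per-chunk depth span shrinks by over 2000×, which is what makes the depth-KV loads small enough to fuse into a Flash-style kernel.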

Efficiency scaling:

  • At T=64K: Only 2.73% extra time vs FlashAttention-2
  • As sequence length increases, overhead decreases (from 25.86% at T=4K to 2.73% at T=64K)

Kernel optimization ablation (speedups over naive PyTorch):

  1. Flash-compatible layout: 162.5× faster than naive PyTorch
  2. Chunk-aware layout: an additional 52% reduction in kernel time
  3. Group-aware indexing: an additional 4.31× speedup

Combined: ~1458× speedup over the naive implementation
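The reported ablation numbers compose multiplicatively. A quick sanity check, treating the "52% reduction" as time saved, i.e. a ~2.08× speedup (this interpretation is an assumption, not stated explicitly above):

```python
flash_layout = 162.5            # Flash-compatible layout, speedup over naive PyTorch
chunk_aware = 1 / (1 - 0.52)    # 52% time reduction ≈ 2.08x speedup (assumed reading)
group_indexing = 4.31           # group-aware indexing speedup
total = flash_layout * chunk_aware * group_indexing
print(round(total))             # ≈ 1459, in line with the ~1458x reported
```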

Ablation Studies

MoDA Variants (700M models):

  • Baseline OLMo2: 57.11% avg downstream
    • Depth KV only: 58.10% (+0.99%)
    • Depth KV + FFN KV projection: 58.87% (+1.76% total)
    • Extra Attn KV projection: 58.97% (marginal gain, not worth the parameter cost)

Layer Number Analysis:
MoDA remains effective across different depths (24 vs 48 layers). Notably, combining MoDA with post-norm yields better results than pre-norm, especially in deeper models (48 layers):

  • 48-layer post-norm: 3.3484 validation loss (vs 3.3653 pre-norm)

Attention Visualization

Figure 5: Attention heatmaps show substantial attention mass assigned to depth-KV blocks (right of the red dashed line), indicating MoDA effectively retrieves cross-layer information rather than relying solely on sequence context.

The visualization reveals:

  • Persistent attention to depth KV pairs across middle and late layers
  • Reduced "attention sink" behavior (less probability mass collapsing onto fixed positions)
  • Better probability allocation to informative sequence and depth positions

Summary

MoDA provides a practical primitive for depth scaling in LLMs by enabling data-dependent retrieval of historical layer information with minimal overhead. It achieves strong performance gains across multiple scales and benchmarks while maintaining near-FlashAttention hardware efficiency.

Thanks for the code

