K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling

Overview

K-Forcing distills an autoregressive (AR) language model into a push-forward language model (PFLM) that generates k tokens in one forward pass. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving.
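To make the inverse-CDF idea concrete, here is a minimal sketch of how uniform noise can be pushed forward to tokens. It is not the repository's implementation: `toy_conditional` is a hypothetical stand-in for an AR model's next-token distribution over a 3-token vocabulary, and the loop applies one inverse-CDF step per token, whereas the PFLM is trained to produce all k tokens in a single forward pass.

```python
from bisect import bisect_right
from itertools import accumulate


def inverse_cdf_token(probs, u):
    """Return the smallest index whose cumulative probability exceeds u."""
    cdf = list(accumulate(probs))
    # Clamp guards against floating-point round-off when u is close to 1.
    return min(bisect_right(cdf, u), len(probs) - 1)


def toy_conditional(prefix):
    """Hypothetical next-token distribution that depends only on the last token."""
    table = {
        0: [0.5, 0.3, 0.2],
        1: [0.1, 0.6, 0.3],
        2: [0.25, 0.25, 0.5],
    }
    last = prefix[-1] if prefix else 0
    return table[last]


def push_forward(prefix, noise):
    """Map k uniform noise values to k tokens deterministically."""
    tokens = list(prefix)
    for u in noise:
        tokens.append(inverse_cdf_token(toy_conditional(tokens), u))
    return tokens[len(prefix):]


# Four noise values give four tokens; the same noise always gives the same tokens.
print(push_forward([0], [0.1, 0.7, 0.95, 0.4]))  # [0, 1, 2, 1]
```

Once the noise is fixed, the mapping is deterministic, so all randomness in the output comes from the Uniform(0,1) draws.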

Key results: roughly 2.4–3.5× higher batch-serving throughput with modest quality degradation on LM1B and OpenWebText, using ~100M-parameter Transformers.

Checkpoints

This repository contains four checkpoints:

| File | Model | Dataset | Parameters | Description |
|---|---|---|---|---|
| ar_openwebtxt.ckpt | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) |
| ar_best_lm1b.ckpt | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) |
| pflm_owt_k4.ckpt | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
| pflm_lm1b_k4.ckpt | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass |

All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from MDLM (Sahoo et al., 2024).

Download

from huggingface_hub import hf_hub_download

# Download a specific checkpoint
ckpt_path = hf_hub_download(
    repo_id="zwave/K-Forcing",
    filename="pflm_owt_k4.ckpt",  # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt
)

Or download all checkpoints at once:

from huggingface_hub import snapshot_download

snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints")

Or via CLI:

huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints
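After downloading, it can be useful to confirm that all four files arrived. The helper below is a small sketch using only the standard library; the function name `missing_checkpoints` is hypothetical, and the filenames are the ones listed in the Checkpoints table.

```python
from pathlib import Path

# Filenames as listed in the Checkpoints table.
EXPECTED_CHECKPOINTS = [
    "ar_openwebtxt.ckpt",
    "ar_best_lm1b.ckpt",
    "pflm_owt_k4.ckpt",
    "pflm_lm1b_k4.ckpt",
]


def missing_checkpoints(local_dir):
    """Return the expected checkpoint filenames that are not present in local_dir."""
    root = Path(local_dir)
    return [name for name in EXPECTED_CHECKPOINTS if not (root / name).is_file()]


# An empty list means every checkpoint is in place.
print(missing_checkpoints("./checkpoints"))
```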

Usage

Clone the K-Forcing repository and follow the setup instructions there:

git clone https://github.com/alibaba-damo-academy/K-Forcing.git
cd K-Forcing

# Setup environment
mkdir -p wheels
wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
uv sync

AR Inference

python batch_inference_with_prefix.py \
    --model ar --task owt \
    --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1

PFLM Inference (K=2 tokens per forward pass)

python batch_inference_with_prefix.py \
    --model pflm --task owt \
    --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \
    --prefix_file assets/prefix_owt_examples.jsonl \
    --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3

The PFLM checkpoint trained with k=4 supports inference with any K ≤ 4.
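As a rough illustration of what the K setting changes, the sketch below counts forward passes for a given generation length and assembles the PFLM command line for a chosen K. Both helper names are hypothetical, and the flags are copied from the PFLM example above. The pass count is an idealized figure only; it does not model wall-clock throughput.

```python
import math

MAX_TRAINED_K = 4  # the released PFLM checkpoints were trained with k=4


def check_k(K):
    """Reject K values outside the range the k=4 checkpoints support."""
    if not 1 <= K <= MAX_TRAINED_K:
        raise ValueError(f"K must be between 1 and {MAX_TRAINED_K}, got {K}")


def forward_passes(num_new_tokens, K):
    """Number of forward passes needed to decode num_new_tokens, K per pass."""
    check_k(K)
    return math.ceil(num_new_tokens / K)


def build_pflm_command(K, ckpt_path="./checkpoints/pflm_owt_k4.ckpt"):
    """Assemble (but do not run) the PFLM inference command for a given K."""
    check_k(K)
    return [
        "python", "batch_inference_with_prefix.py",
        "--model", "pflm", "--task", "owt",
        "--ckpt_path", ckpt_path,
        "--prefix_file", "assets/prefix_owt_examples.jsonl",
        "--batch_size", "4", "--n_per_prefix", "1",
        "--K", str(K), "--freq_penalty", "0.3",
    ]


for K in (1, 2, 4):
    print(K, forward_passes(128, K), " ".join(build_pflm_command(K)))
```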

Architecture

  • Backbone: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads
  • Noise encoder: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding
  • Fully causal design: noise tokens attend causally, so each zⱼ sees the context plus z₁..zⱼ
  • Shared prediction head: same linear head as AR, applied at each noise-token position
  • Training: progressive self-forcing distillation (AR → k=1 → k=2 → k=4)
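The causal visibility rule for noise tokens can be written down as a boolean mask. This is a sketch under an assumed layout in which the k noise tokens are appended directly after the context; the function name `noise_visibility_mask` and the layout are illustrative, not taken from the repository.

```python
def noise_visibility_mask(context_len, k):
    """Build the attention mask rows for k noise tokens.

    The assumed sequence layout is [context tokens..., z_1, ..., z_k].
    mask[j][p] is True when noise token z_(j+1) may attend to position p.
    """
    total = context_len + k
    return [
        [p <= context_len + j for p in range(total)]
        for j in range(k)
    ]


# With 3 context tokens and k=2: z_1 sees the context and itself,
# z_2 additionally sees z_1.
for row in noise_visibility_mask(context_len=3, k=2):
    print(["x" if visible else "." for visible in row])
```

Each row grows by exactly one visible position, which is the same pattern a standard causal mask gives to ordinary tokens. That is consistent with the overview's statement that the approach is compatible with standard KV-cache serving.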

Citation

@misc{tang2026kforcingjointnextktokendecoding,
      title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling},
      author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang},
      year={2026},
      eprint={2606.10820},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2606.10820},
}

License

This project is licensed under the MIT License.
