TiledAttention

Code: GitHub

Paper at AI on HPC (ISC 2026): arXiv

TiledAttention is a reference scaled dot-product attention (SDPA) forward kernel for NVIDIA GPUs, implemented in cuTile Python (TileIR) and exposed for PyTorch-oriented workflows. It is intended as a cuTile-in-Python reference showing that custom hardware-level kernels can be written and iterated on directly in Python, including modality-specific attention variants beyond text-only LLM settings. The design follows FlashAttention-style online softmax with tiled (K, V) streaming, and keeps the schedule (tile shapes, staging, shared-memory layout) easy to modify for reproducible kernel research.
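
To make the online-softmax idea concrete, here is a minimal pure-Python sketch for a single query row. It is not the cuTile kernel: the function names, the `tile` parameter, and the list-based data layout are illustrative assumptions. It only shows how a running maximum, a running denominator, and a rescaled accumulator let K and V be consumed one tile at a time while matching the ordinary softmax result.

```python
import math
import random


def naive_attention_row(q, keys, values, scale):
    """Reference: softmax(scale * q.K^T) @ V for one query row, computed in one pass."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]


def tiled_attention_row(q, keys, values, scale, tile=4):
    """Same result, but K/V are streamed tile by tile with online softmax."""
    dim = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    acc = [0.0] * dim         # unnormalized output, relative to running_max
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum.
        # On the first tile this is exp(-inf) == 0.0, which discards the empty state.
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [
            a * correction + sum(w * v[d] for w, v in zip(weights, v_tile))
            for d, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]


random.seed(0)
dim, seq = 8, 37  # seq is deliberately not a multiple of the tile size
q = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq)]
scale = 1.0 / math.sqrt(dim)

ref = naive_attention_row(q, keys, values, scale)
out = tiled_attention_row(q, keys, values, scale, tile=5)
print("max abs difference:", max(abs(a - b) for a, b in zip(ref, out)))
```

The tiled version never holds the full score vector, which is the property that lets a GPU kernel keep its working set in fast on-chip memory.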

In the accompanying study, TiledAttention is evaluated against PyTorch SDPA auto-dispatch and explicit baselines across sequence length, head dimension, causal/non-causal masking, and FP16/BF16 precision.

This Hub kernel is packaged as a Python-only CUDA kernel. At runtime it also requires cupy-cuda13x and cuda-tile in the consumer environment.

Benchmark on NVIDIA DGX (GB10)

Full study results dataset using NVIDIA Nsight Compute profiling: DOI

The figures below summarize the benchmark study results.

Figure 6: Fused vs. unfused vs. TiledAttention.

Figure 7: Individual SDPA backends vs. TiledAttention.

Available Functions

  • sdpa(q, k, v, causal=False, scale=None)

sdpa

Computes forward scaled dot-product attention using the TiledAttention kernel.

Input requirements:

  • q, k, v: CUDA tensors of shape [B, H, S, D]
  • matching shape, dtype, and device
  • dtype: torch.float16 or torch.bfloat16

Parameters:

  • causal (bool, default False): enables causal masking
  • scale (float | None, default None): attention scale; if None, uses 1/sqrt(D)

Returns:

  • output tensor of shape [B, H, S, D] on CUDA
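
The `causal` and `scale` parameters can be illustrated with a small pure-Python sketch that computes attention weights for one head. The helper names `resolve_scale` and `attention_weights` are hypothetical and not part of the kernel's API. The causal convention used here (query position i attends to key positions 0..i) is the common one and is an assumption, not something the description above confirms.

```python
import math


def resolve_scale(scale, head_dim):
    """Return the explicit scale, or 1/sqrt(D) when scale is None."""
    return 1.0 / math.sqrt(head_dim) if scale is None else float(scale)


def attention_weights(q, k, causal=False, scale=None):
    """Softmax attention weights for one head; q and k are [S][D] nested lists."""
    s = resolve_scale(scale, len(q[0]))
    rows = []
    for i, qi in enumerate(q):
        # Assumed causal convention: query i sees keys 0..i only.
        visible = k[: i + 1] if causal else k
        scores = [s * sum(a * b for a, b in zip(qi, kj)) for kj in visible]
        m = max(scores)
        exps = [math.exp(x - m) for x in scores]
        total = sum(exps)
        # Masked (future) positions get weight 0.0.
        rows.append([e / total for e in exps] + [0.0] * (len(k) - len(visible)))
    return rows


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

for row in attention_weights(q, k, causal=True):
    print([round(w, 3) for w in row])
```

With `causal=True` the weight matrix is lower triangular, and each row still sums to 1 over the visible keys.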

How to Use

# pip install -U kernels
from kernels import get_kernel
import torch

# Load the kernel from the Hub.
tiledattention = get_kernel("thisistaimur/tiledattention", version=1, trust_remote_code=True)

# Inputs are [B, H, S, D] CUDA tensors with matching shape, dtype, and device.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = tiledattention.sdpa(q, k, v, causal=False)
print(out.shape, out.dtype)

With explicit scale:

out = tiledattention.sdpa(q, k, v, causal=True, scale=1.0 / (q.shape[-1] ** 0.5))
Apache-2.0