TiledAttention
Code:
Paper at AI on HPC (ISC 2026):
TiledAttention is a reference scaled dot-product attention (SDPA) forward kernel for NVIDIA GPUs, implemented in cuTile Python (TileIR) and exposed for PyTorch-oriented workflows. It is intended as a cuTile-in-Python reference showing that custom hardware-level kernels can be written and iterated on directly in Python, including modality-specific attention variants beyond text-only LLM settings. The design follows FlashAttention-style online softmax with tiled (K,V) streaming, while emphasizing schedule-level modifiability (tile shapes, staging, shared-memory layout) for reproducible kernel research.
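To make the online-softmax idea concrete, the following is a minimal pure-Python sketch for a single query vector. It is not the cuTile kernel: the function names, the `tile` parameter, and the list-based layout are hypothetical and chosen only for illustration. It shows how a running maximum and a running denominator allow (K,V) to be consumed tile by tile without materializing the full score row.

```python
import math


def attention_row_online(q, keys, values, tile=2, scale=None):
    """Attention output for one query vector, streaming (K, V) in tiles.

    Illustrative only: assumes non-empty `keys`/`values` given as lists of lists.
    """
    scale = 1.0 / math.sqrt(len(q)) if scale is None else scale
    m = -math.inf                  # running maximum of the scores seen so far
    l = 0.0                        # running softmax denominator, relative to m
    acc = [0.0] * len(values[0])   # running unnormalized output, relative to m

    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]

        # Update the running maximum and rescale what was accumulated so far.
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # equals 0.0 on the first tile
        l *= correction
        acc = [a * correction for a in acc]

        # Fold the current tile into the running sums.
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new

    return [a / l for a in acc]


def attention_row_naive(q, keys, values, scale=None):
    """Reference: softmax over the full score row, then a weighted sum of values."""
    scale = 1.0 / math.sqrt(len(q)) if scale is None else scale
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[j] for w, v in zip(weights, values)) / total
        for j in range(len(values[0]))
    ]
```

Both functions should agree up to floating-point rounding for any tile size, which is the property that lets a tiled kernel keep only one (K,V) tile in fast memory at a time.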
In the accompanying study, TiledAttention is evaluated against PyTorch SDPA auto-dispatch and explicit baselines across sequence length, head dimension, causal/non-causal masking, and FP16/BF16 precision.
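A sweep over these four axes can be enumerated as a simple Cartesian product. The sketch below is only an illustration of that structure: the concrete sequence lengths and head dimensions are hypothetical placeholders, not the values used in the study.

```python
from itertools import product

# Hypothetical axis values, for illustration only.
SEQ_LENS = (512, 1024, 2048, 4096)
HEAD_DIMS = (64, 128)
CAUSAL = (False, True)
DTYPES = ("float16", "bfloat16")


def sweep_configs():
    """Yield one benchmark configuration per combination of the four axes."""
    for seq_len, head_dim, causal, dtype in product(SEQ_LENS, HEAD_DIMS, CAUSAL, DTYPES):
        yield {
            "seq_len": seq_len,
            "head_dim": head_dim,
            "causal": causal,
            "dtype": dtype,
        }


configs = list(sweep_configs())
```

Each resulting dictionary would then drive one timed run of every backend under comparison.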
This Hub kernel is packaged as a Python-only CUDA kernel. At runtime it also requires cupy-cuda13x and cuda-tile in the consumer environment.
Benchmark on NVIDIA DGX (GB10)
Full study results dataset using NVIDIA Nsight Compute profiling:
The figures below summarize the benchmark study results.
Fused vs unfused vs TiledAttention:

Individual SDPA backends vs TiledAttention:

Available Functions
sdpa(q, k, v, causal=False, scale=None)
Computes forward scaled dot-product attention using the TiledAttention kernel.
Input requirements:
- q, k, v: CUDA tensors of shape [B, H, S, D]
- matching shape, dtype, and device
- dtype: torch.float16 or torch.bfloat16
Parameters:
- causal (bool, default False): enables causal masking
- scale (float | None, default None): attention scale; if None, uses 1/sqrt(D)
Returns:
- output tensor of shape [B, H, S, D] on CUDA
How to Use
# pip install -U kernels
from kernels import get_kernel
import torch
k = get_kernel("thisistaimur/tiledattention", version=1, trust_remote_code=True)
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k_t = torch.randn_like(q)
v = torch.randn_like(q)
out = k.sdpa(q, k_t, v, causal=False)
print(out.shape, out.dtype)
With explicit scale:
out = k.sdpa(q, k_t, v, causal=True, scale=1.0 / (q.shape[-1] ** 0.5))
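One way to sanity-check the kernel's output is to compare it with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` on the same random inputs. The helper below is a sketch under stated assumptions: the name `check_against_pytorch`, the default shape, and the tolerances are hypothetical choices for half-precision inputs, not values taken from the study.

```python
def check_against_pytorch(sdpa_fn, causal=False, dtype_name="float16",
                          shape=(1, 8, 1024, 64), atol=2e-2, rtol=2e-2):
    """Compare `sdpa_fn` with PyTorch's built-in SDPA; return the max absolute error.

    `sdpa_fn` is assumed to accept (q, k, v, causal=...) like the sdpa function above.
    """
    # Imported lazily so the helper can be defined on machines without PyTorch.
    import torch
    import torch.nn.functional as F

    dtype = getattr(torch, dtype_name)
    q = torch.randn(*shape, device="cuda", dtype=dtype)
    k_t = torch.randn_like(q)
    v = torch.randn_like(q)

    out = sdpa_fn(q, k_t, v, causal=causal)
    ref = F.scaled_dot_product_attention(q, k_t, v, is_causal=causal)

    # Tolerances are illustrative assumptions for FP16/BF16 inputs.
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
    return (out.float() - ref.float()).abs().max().item()
```

With the kernel loaded as in the example above, this could be called as `check_against_pytorch(k.sdpa, causal=True)` on a machine with a CUDA device.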