| from dataclasses import dataclass, field |
| from einops import rearrange, repeat |
| import math |
| import torch |
| from torch.amp.autocast_mode import autocast |
| import torch.nn as nn |
| from transformers.activations import ACT2FN |
| from typing import cast |
|
|
| |
# Optional flash-attn acceleration: fall back to the eager implementations
# defined in this file when the package is not installed.
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    print("flash_attn not found, using default implementations")
    # BUG FIX: `FlashCrossAttention` was misspelled (`FlashCrossAttentio`),
    # leaving the real name undefined and crashing MHA.__init__ later.
    pad_input = unpad_input = FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None
|
|
|
|
| class RotaryEmbedding(nn.Module): |
| """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc""" |
|
|
| def __init__( |
| self, |
| d_rotary: int, |
| rotary_base: float = 10000.0, |
| initial_cos_sin_cache_len: int = 2048, |
| device: torch.device | None = None, |
| ) -> None: |
| super().__init__() |
| self.d_rotary = d_rotary |
| self.rotary_base = rotary_base |
| self.device = device |
| self.dtype = torch.float32 |
| self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len) |
|
|
| def _update_cos_sin_cache( |
| self, |
| seqlen: int, |
| device: str | None = None, |
| dtype: torch.dtype | None = None, |
| ) -> None: |
| |
| self._max_seqlen = seqlen |
|
|
| |
| m = torch.arange( |
| seqlen, |
| device=device, |
| dtype=torch.float32, |
| ) |
| theta_i = 1.0 / ( |
| self.rotary_base ** ( |
| torch.arange( |
| start=0, |
| end=self.d_rotary, |
| step=2, |
| device=device, |
| dtype=torch.float32, |
| ) / self.d_rotary |
| ) |
| ) |
| |
| |
| m_theta_i = torch.outer(m, theta_i) |
| self._cos_cached = torch.cos(m_theta_i).to(dtype) |
| self._sin_cached = torch.sin(m_theta_i).to(dtype) |
|
|
| |
| """ |
| if scale_base is not None: |
| scale = ( |
| torch.arange( |
| start=0, |
| end=self.d_rotary, |
| step=2, |
| device=self.device, |
| dtype=torch.float32, |
| ) + 0.4 * self.d_rotary |
| ) / (1.4 * self.d_rotary) |
| power = ( |
| torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2 |
| ) / scale_base |
| scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1") |
| self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype) |
| self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype) |
| """ |
|
|
| def _apply_rotary_emb_qkv( |
| self, |
| x: torch.FloatTensor, |
| cos: torch.FloatTensor, |
| sin: torch.FloatTensor, |
| ) -> torch.FloatTensor: |
| seqlen = x.shape[1] |
| x_to_rotate = x[..., :self.d_rotary] |
| x_to_keep_unrotated = x[..., self.d_rotary:] |
| x1, x2 = x_to_rotate.chunk(2, dim=-1) |
| broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d" |
| c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange) |
| x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] |
| x_rotated = cast( |
| torch.FloatTensor, |
| torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype) |
| ) |
| return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1) |
|
|
| def forward( |
| self, |
| x: torch.FloatTensor, |
| seqlen_offset: int = 0, |
| ) -> torch.FloatTensor: |
| if ( |
| not self._max_seqlen |
| or self._max_seqlen < x.shape[1] + seqlen_offset |
| or self._cos_cached.device != x.device |
| or self._cos_cached.dtype != x.dtype |
| or (self.training and self._cos_cached.is_inference()) |
| ): |
| self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype) |
| return self._apply_rotary_emb_qkv( |
| x, |
| cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]), |
| cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]), |
| ) |
|
|
|
|
| class SelfAttention(nn.Module): |
| def __init__( |
| self, |
| qk_scale: float | None = None, |
| attention_dropout: float = 0.0, |
| ) -> None: |
| super().__init__() |
| self.qk_scale = qk_scale |
| self.dropout = nn.Dropout(attention_dropout) |
|
|
| |
| @autocast("cpu", enabled=False) |
| @autocast("cuda", enabled=False) |
| def forward( |
| self, |
| qkv: torch.FloatTensor, |
| causal: bool = True, |
| key_padding_mask: torch.BoolTensor | None = None, |
| ) -> torch.FloatTensor: |
| batch_size, seqlen = qkv.shape[0], qkv.shape[1] |
| q, k, v = qkv.unbind(dim=2) |
| q = q.to(torch.float32) |
| k = k.to(torch.float32) |
| qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) |
|
|
| scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) |
|
|
| if key_padding_mask: |
| padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) |
| padding_mask.masked_fill_(key_padding_mask, 0.0) |
| scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") |
|
|
| if causal: |
| causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) |
| scores = scores + causal_mask.to(dtype=scores.dtype) |
|
|
| attention = torch.softmax(scores, dim=-1).to(v.dtype) |
| attention = self.dropout(attention) |
|
|
| output = torch.einsum("bhts,bshd->bthd", attention, v) |
| return cast(torch.FloatTensor, output) |
|
|
|
|
| class CrossAttention(nn.Module): |
| def __init__( |
| self, |
| qk_scale: float | None = None, |
| attention_dropout: float = 0.0, |
| ) -> None: |
| super().__init__() |
| self.qk_scale = qk_scale |
| self.dropout = nn.Dropout(attention_dropout) |
|
|
| |
| @autocast("cpu", enabled=False) |
| @autocast("cuda", enabled=False) |
| def forward( |
| self, |
| q: torch.FloatTensor, |
| kv: torch.FloatTensor, |
| causal: bool = True, |
| key_padding_mask: torch.BoolTensor | None = None, |
| ) -> torch.FloatTensor: |
| batch_size, seqlen_q = q.shape[0], q.shape[1] |
| seqlen_k = kv.shape[1] |
| if kv.shape[3] != q.shape[2]: |
| kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) |
| k, v = kv.unbind(dim=2) |
| q = cast(torch.FloatTensor, q.to(torch.float32)) |
| k = k.to(torch.float32) |
| qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) |
|
|
| scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) |
|
|
| if key_padding_mask: |
| padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) |
| padding_mask.masked_fill_(key_padding_mask, 0.0) |
| scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") |
|
|
| if causal: |
| rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") |
| cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) |
| causal_mask = cols > rows + seqlen_k - seqlen_q |
| scores = scores.masked_fill(causal_mask, -10000.0) |
|
|
| attention = torch.softmax(scores, dim=-1).to(v.dtype) |
| attention = self.dropout(attention) |
|
|
| output = torch.einsum("bhts,bshd->bthd", attention, v) |
| return cast(torch.FloatTensor, output) |
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion."""

    def __init__(
        self,
        d_embedding: int,
        act_fn: str = "gelu_new",
    ) -> None:
        """
        Args:
            d_embedding: input/output width; hidden width is 4 * d_embedding.
            act_fn: key into transformers' ACT2FN activation registry.
        """
        super().__init__()
        d_hidden = 4 * d_embedding
        self.fc1 = nn.Linear(d_embedding, d_hidden)
        self.act = ACT2FN[act_fn]
        self.fc2 = nn.Linear(d_hidden, d_embedding)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Project up, apply the activation, project back down."""
        return self.fc2(self.act(self.fc1(x)))
|
|
|
|
@dataclass
class KVCache:
    """Options for model to calculate and store context during inference."""
    # Sequence capacity the per-block cache buffers are allocated with.
    max_seqlen: int
    # Batch capacity the per-block cache buffers are allocated with.
    max_batch_size: int
    # Number of tokens already cached (write offset along the sequence axis).
    seqlen_offset: int
    # Write offset along the batch axis.
    batch_size_offset: int
    # Cached (k, v) tensors per attention block, keyed by block index;
    # each value is (max_batch_size, max_seqlen, 2, n_heads, d_head).
    kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict)
    # Unused in this file — presumably per-sample lengths for variable-length
    # kernels; verify against callers before relying on it.
    lengths_per_sample: torch.Tensor | None = None
|
|
|
|
class MHA(nn.Module):
    """Multi-head attention block.

    Projects the input to packed qkv, applies rotary position embeddings to
    q and k, then dispatches either to flash-attn kernels (when requested and
    importable) or to the eager SelfAttention / CrossAttention modules defined
    in this file. With a KVCache it runs incremental (cross) attention.
    """

    def __init__(
        self,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool,
        use_flash_attn: bool,
        use_fused_dense: bool,
        checkpointing: bool,
    ) -> None:
        super().__init__()

        # Rotary embedding: flash implementation when requested and available.
        rotary_cls = (
            FlashRotaryEmbedding
            if use_flash_rotary and FlashRotaryEmbedding is not None
            else RotaryEmbedding
        )
        self.rotary_emb = rotary_cls(
            # NOTE(review): rotary dim hard-coded to 32 — confirm this matches
            # the checkpoint configuration this block is meant to load.
            d_rotary=32,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
        )

        # Self-attention path (no KV cache, e.g. training / full prefill).
        self_attn_cls = (
            FlashSelfAttention
            if use_flash_attn and FlashSelfAttention is not None
            else SelfAttention
        )
        self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)

        # Cross-attention path (incremental decoding against the KV cache).
        cross_attn_cls = (
            FlashCrossAttention
            if use_flash_attn and FlashCrossAttention is not None
            else CrossAttention
        )
        self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)

        # Packed qkv projection and output projection.
        self.n_attn_heads = n_attn_heads
        self.d_head = d_embedding // n_attn_heads
        linear_cls = (
            FusedDense
            if use_fused_dense and FusedDense is not None
            else nn.Linear
        )
        self.Wqkv = linear_cls(
            d_embedding,
            self.d_head * (3 * self.n_attn_heads),
        )
        self.fc_out = linear_cls(d_embedding, d_embedding)

        self.using_flash_attn = self_attn_cls is FlashSelfAttention
        self.block_n = block_n
        self.checkpointing = checkpointing

    def _forward_self_attn(
        self,
        qkv: torch.FloatTensor,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Self-attention over packed qkv of shape (b, s, 3, h, d)."""
        # Rotate q and k only; v passes through unchanged.
        # BUG FIX: v must be sliced with `2:3` — indexing with the bare int 2
        # drops the qkv axis (4-D vs 5-D) and makes torch.cat fail on rank
        # mismatch.
        qkv = cast(
            torch.FloatTensor,
            torch.cat(
                [
                    self.rotary_emb(qkv[:, :, :2, :, :]),
                    qkv[:, :, 2:3, :, :],
                ],
                dim=2,
            )
        )

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            cu_seqlens, max_seqlen, indices = None, None, None

            # BUG FIX: `if key_padding_mask:` truth-tests a multi-element
            # tensor and raises; compare against None instead.
            if key_padding_mask is not None:
                # Pack variable-length sequences for the flash kernel.
                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                )
            else:
                attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)

            # Restore the padded (b, s, h, d) layout when we unpadded above.
            if key_padding_mask is not None:
                return pad_input(attn_output, indices, batch_size, seqlen)
            return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
        return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)

    def _update_kv_cache(
        self,
        kv: torch.FloatTensor,
        kv_cache: KVCache,
        block_n: int,
    ) -> torch.FloatTensor:
        """Write `kv` into this block's cache and return all cached k/v.

        Returns the (batch, seqlen_so_far, 2, h, d) slice covering every
        position cached so far. BUG FIX: the original annotated the return
        type as None although it returns `kv`.
        """
        # Lazily allocate this block's cache buffer on first use.
        if block_n not in kv_cache.kv_block_map:
            kv_cache.kv_block_map[block_n] = torch.empty(
                kv_cache.max_batch_size,
                kv_cache.max_seqlen,
                2,
                kv.shape[-2],
                kv.shape[-1],
                dtype=kv.dtype,
                device=kv.device,
            )

        batch_start = kv_cache.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        sequence_start = kv_cache.seqlen_offset
        sequence_end = sequence_start + kv.shape[1]

        # Grow the buffer when the new tokens would overflow max_seqlen.
        # NOTE(review): the grow path appends along dim 1 and the slice write
        # below then re-writes [sequence_start:sequence_end]; kept as-is —
        # confirm intended overflow layout before changing.
        if sequence_end >= kv_cache.max_seqlen:
            kv_cache.kv_block_map[block_n] = torch.concatenate(
                (kv_cache.kv_block_map[block_n], kv),
                dim=1,
            )
        kv_cache.kv_block_map[block_n][
            batch_start:batch_end,
            sequence_start:sequence_end,
            ...
        ] = kv
        kv = kv_cache.kv_block_map[block_n][
            batch_start:batch_end,
            :sequence_end,
            ...
        ]
        return kv

    def _forward_cross_attn(
        self,
        qkv: torch.FloatTensor,
        kv_cache: KVCache,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Attend from the new tokens' q to the cached-plus-new k/v."""
        # Rotate q and k, offset by the number of tokens already cached so the
        # new tokens get the correct absolute positions.
        qk = qkv[:, :, :2, :, :]
        qk = self.rotary_emb(
            qk,
            seqlen_offset=0 if kv_cache is None else kv_cache.seqlen_offset,
        )
        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
        q = qk[:, :, 0, :, :]
        kv = torch.cat(
            [
                qk[:, :, 1, :, :].unsqueeze(2),
                v.unsqueeze(2),
            ],
            dim=2,
        )
        kv = self._update_kv_cache(kv, kv_cache, self.block_n)

        # Only the prefill step (offset 0) needs a causal mask; subsequent
        # decode steps attend to everything already cached.
        causal = (kv_cache.seqlen_offset == 0)

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, indices_q = (
                None,
                None,
                None,
                None,
                None,
            )

            # BUG FIX: `if key_padding_mask:` truth-tests a multi-element
            # tensor and raises; compare against None instead.
            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

                # Align the mask with the query slice: a single decode token is
                # always valid; a shorter query window takes the mask's tail.
                if seqlen_q == 1:
                    key_padding_mask = cast(torch.BoolTensor, torch.ones(batch_size, 1, device=q.device))
                elif seqlen_q != seqlen_k:
                    key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])

                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_cross_attn,
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )
            else:
                attn_output = self.inner_cross_attn(
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )

            if key_padding_mask is not None:
                return pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
            return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(
                self.inner_cross_attn,
                q,
                kv,
                key_padding_mask=key_padding_mask,
                causal=causal,
            )
        return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Project, attend, and re-project.

        BUG FIX: the original return annotation claimed a tuple although a
        single tensor is returned (and consumed as one by
        ParallelAttentionBlock).
        """
        # Normalize 0/1 integer masks to bool.
        if key_padding_mask is not None:
            key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool())

        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head)
        # No cache -> full self-attention; cache -> incremental cross-attention.
        if kv_cache is None:
            attn_output = self._forward_self_attn(qkv, key_padding_mask)
        else:
            attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)

        # Merge heads and project back to d_embedding.
        output = rearrange(attn_output, "... h d -> ... (h d)")
        return self.fc_out(output)
|
|
|
|
class ParallelAttentionBlock(nn.Module):
    """Transformer block computing the attention and MLP branches in parallel.

    Both branches read the same pre-normalized input; their outputs are summed,
    passed through residual dropout, and added back to the residual stream.
    """

    def __init__(
        self,
        resid_pdrop: float,
        layer_norm_epsilon: float,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool = True,
        use_flash_attn: bool = True,
        use_fused_dense: bool = True,
        checkpointing: bool = False,
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
        self.block_n = block_n
        self.multi_head_attention = MHA(
            d_embedding=d_embedding,
            n_attn_heads=n_attn_heads,
            block_n=block_n,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            attn_pdrop=attn_pdrop,
            use_flash_rotary=use_flash_rotary,
            use_flash_attn=use_flash_attn,
            use_fused_dense=use_fused_dense,
            checkpointing=checkpointing,
        )
        self.mlp = MLP(d_embedding)
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Run both branches on layer-normed x and add the residual."""
        skip = x
        normed = self.layer_norm(x)
        attn_branch = self.multi_head_attention(
            normed,
            kv_cache=kv_cache,
            key_padding_mask=key_padding_mask,
        )
        mlp_branch = self.mlp(normed)
        return self.dropout(attn_branch + mlp_branch) + skip
|
|