from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .full_attn import scaled_dot_product_attention
from .rope import RotaryPositionEmbedder


class MultiHeadRMSNorm(nn.Module):
    """Per-head RMS normalization with a learnable per-head gain."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        return (F.normalize(x.float(), dim=-1) * self.gamma * self.scale).to(x.dtype)


class MultiHeadAttention(nn.Module):
    """Multi-head self- or cross-attention with optional QK RMS norm and RoPE."""

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            # Single fused projection for queries, keys, and values.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            # Queries come from x; keys/values come from the context sequence.
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        phases: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.attn_mode == "full":
                if self.qk_rms_norm or self.use_rope:
                    # Unpack q/k/v so they can be normalized and/or rotated individually.
                    q, k, v = qkv.unbind(dim=2)
                    if self.qk_rms_norm:
                        q = self.q_rms_norm(q)
                        k = self.k_rms_norm(k)
                    if self.use_rope:
                        assert phases is not None, "Phases must be provided for RoPE"
                        q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
                        k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    # Fast path: pass the packed qkv tensor straight through.
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            assert context is not None, "Context must be provided for cross-attention"
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        # Merge heads and project back to the model dimension.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
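

# --- Minimal usage sketch (illustrative only, not part of the module API) ---
# Assumes this file sits in a package alongside full_attn.py and rope.py, and that
# scaled_dot_product_attention accepts either a packed (B, L, 3, H, D) qkv tensor or
# separate q/k/v tensors, as exercised by the forward pass above. Run as a module
# (python -m <package>.<this_module>) so the relative imports resolve.
if __name__ == "__main__":
    attn = MultiHeadAttention(channels=512, num_heads=8, qk_rms_norm=True)
    x = torch.randn(2, 16, 512)   # (batch, sequence length, channels)
    out = attn(x)                 # full self-attention over the sequence
    print(out.shape)              # expected: torch.Size([2, 16, 512])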