# Page-header residue from the source this file was copied from:
# author JeffreyXiang, commit a1e3f5f ("Finalize").
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .full_attn import scaled_dot_product_attention
from .rope import RotaryPositionEmbedder
class MultiHeadRMSNorm(nn.Module):
    """RMS normalisation with a separate learnable gain per attention head.

    The last dimension is L2-normalised and rescaled by ``sqrt(dim)``, which
    divides each vector by its root-mean-square. The result is then
    multiplied by a learnable gain ``gamma`` of shape ``(heads, dim)``. The
    input's trailing two dimensions must therefore broadcast against
    ``(heads, dim)``.

    Args:
        dim: Size of the normalised (last) dimension, i.e. the per-head width.
        heads: Number of heads; one gain vector is learned per head.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        # sqrt(dim) turns an L2-normalised vector into an RMS-normalised one.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` RMS-normalised over its last dimension, in ``x``'s dtype."""
        input_dtype = x.dtype
        # Compute in float32 so reduced-precision inputs normalise stably.
        unit_vectors = F.normalize(x.float(), dim=-1)
        scaled = unit_vectors * self.gamma * self.scale
        return scaled.to(input_dtype)
class MultiHeadAttention(nn.Module):
    """Multi-head self- or cross-attention over token sequences.

    ``forward`` unpacks ``x`` as ``(B, L, C)``. In ``"self"`` mode, queries,
    keys and values are all projected from ``x``. In ``"cross"`` mode,
    queries come from ``x`` and keys/values from ``context``, whose last
    dimension is projected from ``ctx_channels``.

    Args:
        channels: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads; each head has ``channels // num_heads``
            channels.
        ctx_channels: Width of ``context`` for cross-attention. Defaults to
            ``channels``.
        type: ``"self"`` or ``"cross"``.
        attn_mode: ``"full"`` or ``"windowed"``. ``"windowed"`` is rejected
            with ``NotImplementedError``; cross-attention only allows ``"full"``.
        window_size: Stored on the instance; not used by any implemented path.
        shift_window: Stored on the instance; not used by any implemented path.
        qkv_bias: Whether the q/k/v projections carry a bias term.
        use_rope: Apply rotary position embeddings to q and k. Only honoured
            on the self-attention path; the cross-attention path applies no RoPE.
        rope_freq: Accepted for interface compatibility; not read by this class.
        qk_rms_norm: Apply a per-head ``MultiHeadRMSNorm`` to q and k before
            attention.

    Raises:
        AssertionError: If ``channels`` is not divisible by ``num_heads``, or
            on an invalid ``type`` / ``attn_mode`` combination.
        NotImplementedError: If ``attn_mode == "windowed"``.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0, (
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"
        )
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        # Stored as ``_type`` because the parameter name shadows the builtin.
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            # One fused projection producing q, k and v side by side.
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            # Keys/values come from the context, which may have its own width.
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        phases: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend and project the result back to ``channels``.

        Args:
            x: Query-side tokens, unpacked as ``(B, L, C)``.
            context: Key/value tokens for cross-attention. The code reads
                ``context.shape[1]`` as the key/value length and projects its
                last dimension with ``to_kv``, so presumably
                ``(B, Lkv, ctx_channels)``. Required when ``type == "cross"``;
                ignored for self-attention.
            phases: Rotary phases forwarded to
                ``RotaryPositionEmbedder.apply_rotary_embedding``. Required when
                ``use_rope`` is set (self-attention only).

        Returns:
            The attention output reshaped to ``(B, L, -1)`` and passed through
            ``to_out``, giving ``(B, L, channels)``.

        Raises:
            AssertionError: If ``context`` is missing in cross mode, or
                ``phases`` is missing while ``use_rope`` is set.
        """
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            # Split channels into (q/k/v, head, head_dim).
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.attn_mode == "full":
                if self.qk_rms_norm or self.use_rope:
                    # q and k are transformed separately, so unpack the fused tensor.
                    q, k, v = qkv.unbind(dim=2)
                    if self.qk_rms_norm:
                        q = self.q_rms_norm(q)
                        k = self.k_rms_norm(k)
                    if self.use_rope:
                        assert phases is not None, "Phases must be provided for RoPE"
                        q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
                        k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    # Nothing to modify: pass the fused qkv tensor straight through.
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                # Unreachable today: __init__ already rejects "windowed".
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            # Fail with a clear message rather than AttributeError on ``None.shape``.
            assert context is not None, "Context must be provided for cross-attention"
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                # Keys/values stay fused as (B, Lkv, 2, heads, head_dim).
                h = scaled_dot_product_attention(q, kv)
        # Merge heads back into the channel dimension.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h