"""Transformer Encoder implementation (Pre-LN).
This module implements the encoder component of the Transformer architecture:
- TransformerEncoderLayer: Single encoder block with self-attention + FFN
- TransformerEncoder: Full stack with embeddings and positional encoding
Design notes:
- Pre-LN with RMSNorm for training stability
- Masks are boolean: True = attend, False = mask
- Supports T5-style relative position bias
Author: Oliver Perrin
Date: 2025-10-23
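Illustrative usage sketch (hyperparameters and vocab size below are arbitrary examples,
not project defaults):
    >>> import torch
    >>> encoder = TransformerEncoder(
    ...     vocab_size=32128, d_model=512, num_layers=6, num_heads=8,
    ...     d_ff=2048, pad_token_id=0, use_relative_position_bias=True,
    ... )
    >>> ids = torch.randint(1, 32128, (2, 16))
    >>> hidden = encoder(ids)  # (2, 16, 512)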
"""
from typing import List, Literal, Optional, Tuple, Union, cast
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
# Encoder implementation
from .attention import MultiHeadAttention, T5RelativePositionBias
from .feedforward import FeedForward
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
from .t5_layer_norm import T5LayerNorm
class TransformerEncoderLayer(nn.Module):
"""
Single Transformer encoder layer (Pre-LN).
Args:
d_model: model hidden size
num_heads: number of attention heads
d_ff: hidden dimension of the position-wise feed-forward network
dropout: dropout probability applied to sublayer outputs
quantization: optional quantization mode ("4bit", "8bit")
activation: activation function for FFN ("gelu", "relu", or "swiglu")
scale_attn_scores: whether to divide attention scores by sqrt(d_k); T5 does NOT scale.
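Example (illustrative; the attention-weight shape assumes MultiHeadAttention returns
weights of shape (batch, num_heads, seq_q, seq_k)):
    >>> import torch
    >>> layer = TransformerEncoderLayer(d_model=512, num_heads=8, d_ff=2048)
    >>> x = torch.randn(2, 16, 512)
    >>> out, attn = layer(x, collect_attn=True)  # out: (2, 16, 512); attn: (2, 8, 16, 16)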
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
quantization: Optional[str] = None,
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
scale_attn_scores: bool = True, # T5 uses False
):
super().__init__()
self.self_attn = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=0.0,
quantization=quantization,
scale_scores=scale_attn_scores,
)
# set MHA internal dropout to 0.0 and use dropout1/dropout2 in the layer
self.ffn = FeedForward(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=activation,
quantization=quantization,
)
self.norm1 = T5LayerNorm(d_model)
self.norm2 = T5LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
collect_attn: bool = False,
position_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass for the encoder layer.
Args:
x: (batch, seq_len, d_model) - input embeddings / representations
mask: optional attention mask, shape either (batch, seq_q, seq_k) or (batch, 1, seq_q, seq_k)
collect_attn: whether to return attention weights
position_bias: optional (1, num_heads, seq_q, seq_k) T5-style relative position bias
Returns:
A tuple (x, attn_weights): x is (batch, seq_len, d_model); attn_weights is
(batch, num_heads, seq_q, seq_k) when collect_attn=True, otherwise None.
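Example of a boolean mask (True = attend), assuming `layer` is built as in the class
docstring example above:
    >>> x = torch.randn(2, 5, 512)
    >>> keep = torch.tensor([[True, True, True, True, False]] * 2)  # last position padded
    >>> mask = keep.unsqueeze(1) & keep.unsqueeze(2)                # (2, 5, 5)
    >>> out, _ = layer(x, mask=mask)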
"""
# Self-attention sublayer (Pre-LN)
x_norm = self.norm1(x) # Pre-LN
# self_attn expects query, key, value; for encoder they are the same
attn_out, attn_weights = self.self_attn(
x_norm,
x_norm,
x_norm,
mask,
return_attn_weights=collect_attn,
position_bias=position_bias,
)
x = x + self.dropout1(attn_out)
# Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
# Feed-forward sublayer (Pre-LN)
x_norm = self.norm2(x)
ffn_out = self.ffn(x_norm)
x = x + self.dropout2(ffn_out)
# Clamp inf values for fp16/bf16 training stability
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
# Always return (output, attn_weights); attn_weights is None unless collect_attn=True
return x, attn_weights
class TransformerEncoder(nn.Module):
"""
Full encoder: token embedding + positional encoding + N encoder layers.
Args:
vocab_size: vocabulary size (ignored if you always pass embeddings)
d_model: model hidden size
num_layers: number of encoder layers to stack
num_heads: number of attention heads
d_ff: hidden dimension in FFN
dropout: dropout probability (applied in positional encoding & residuals)
max_len: maximum sequence length for positional encoding
pad_token_id: optional token id for padding; if provided and inputs are token ids,
a padding mask will be constructed automatically
quantization: optional quantization mode ("4bit", "8bit")
use_learned_pos_enc: use learned absolute positional embeddings instead of sinusoidal ones
activation: activation function for the FFN ("gelu", "relu", "swiglu", or "gated-gelu")
use_relative_position_bias: use T5-style relative position bias (disables absolute positional encoding)
gradient_checkpointing: recompute layer activations during backward to save memory
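Example (illustrative; vocabulary size and pad id are arbitrary):
    >>> import torch
    >>> enc = TransformerEncoder(vocab_size=1000, d_model=128, num_layers=2,
    ...                          num_heads=4, d_ff=256, pad_token_id=0)
    >>> ids = torch.randint(1, 1000, (2, 10))
    >>> out = enc(ids)  # (2, 10, 128)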
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_layers: int = 6,
num_heads: int = 8,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 512,
pad_token_id: Optional[int] = None,
quantization: Optional[str] = None,
use_learned_pos_enc: bool = False,
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
use_relative_position_bias: bool = False, # T5-style relative position bias
gradient_checkpointing: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.pad_token_id = pad_token_id
self.use_relative_position_bias = use_relative_position_bias
self.gradient_checkpointing = gradient_checkpointing
# Token embedding (only used if forward receives token ids)
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
# Positional encoding (disabled when using relative position bias for T5)
self.relative_position_bias: Optional[T5RelativePositionBias] = None
if use_relative_position_bias:
# T5 uses relative position bias instead of absolute positional embeddings
self.pos_encoder = None
self.relative_position_bias = T5RelativePositionBias(
num_heads=num_heads,
num_buckets=32,
max_distance=128,
is_decoder=False,
)
elif use_learned_pos_enc:
# T5 uses max_len=512 by default; we add buffer for special tokens
self.pos_encoder = LearnedPositionalEncoding(
d_model=d_model, max_len=max_len + 2, dropout=dropout
)
else:
self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
# T5 does NOT scale attention scores by sqrt(d_k), others do
scale_attn_scores = not use_relative_position_bias
# Encoder layers stack
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
quantization=quantization,
activation=activation,
scale_attn_scores=scale_attn_scores,
)
for _ in range(num_layers)
]
)
# Final T5LayerNorm for Pre-LN stacks
self.final_norm = T5LayerNorm(d_model)
# Dropout applied after embedding + positional encoding (as in the original Transformer paper)
self.input_dropout = nn.Dropout(dropout)
def _build_padding_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
True indicates valid positions; False indicates masked (pad).
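Example of the broadcast (pad_token_id assumed to be 0 for the sketch):
    >>> ids = torch.tensor([[5, 6, 0]])
    >>> keep = ids != 0                            # (1, 3): [True, True, False]
    >>> m = keep.unsqueeze(1) & keep.unsqueeze(2)  # (1, 3, 3); row and column 2 are False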
"""
assert self.pad_token_id is not None, (
"pad_token_id must be set to build padding mask from ids."
)
# mask shape: (batch, seq) where True = token kept (non-pad)
pad_mask = input_ids != self.pad_token_id
# Convert to (batch, seq_q, seq_k) by outer product broadcasting
# We want positions that are valid as both query and key
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)
# attn_mask dtype should be bool
return attn_mask
def forward(
self,
inputs: torch.Tensor,
mask: Optional[torch.Tensor] = None,
collect_attn: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Forward through the encoder.
Args:
inputs: either
- token ids: LongTensor of shape (batch, seq)
- embeddings: FloatTensor of shape (batch, seq, d_model)
mask: optional attention mask. If None and pad_token_id is set and inputs are token ids,
a padding mask will be created automatically with shape (batch, seq, seq).
The mask should be boolean where True indicates allowed attention.
collect_attn: if True, returns (output, [attn_weights_per_layer]) where each entry is (batch, num_heads, seq, seq)
Returns:
output: (batch, seq, d_model)
or (output, attn_list) if collect_attn True
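Example of attention collection, assuming `enc` and `ids` from the class-level example:
    >>> out, attn_list = enc(ids, collect_attn=True)
    >>> # len(attn_list) == num_layers; each entry is (batch, num_heads, seq, seq)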
"""
# If inputs are token ids, embed them; otherwise assume they are embeddings
if inputs.dim() == 2: # token ids
if self.embedding is None:
raise ValueError("Encoder was not constructed with an embedding layer.")
# T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
x = self.embedding(inputs)
seq_len = inputs.size(1)
elif inputs.dim() == 3: # already embeddings
x = inputs
seq_len = inputs.size(1)
else:
raise ValueError(
"inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings"
)
# Positional encoding + dropout (only if not using relative position bias)
if self.pos_encoder is not None:
x = self.pos_encoder(x)
x = self.input_dropout(x)
# Build mask if needed
if mask is None and inputs.dim() == 2 and self.pad_token_id is not None:
mask = self._build_padding_mask(inputs)
# Ensure mask is boolean and on the same device
if mask is not None:
mask = mask.to(dtype=torch.bool, device=x.device)
# Compute relative position bias if using T5-style
position_bias = None
if self.relative_position_bias is not None:
position_bias = self.relative_position_bias(seq_len, seq_len, x.device)
attn_weights_per_layer: List[torch.Tensor] = []
# Pass through each encoder layer (optionally collect attn)
for layer in self.layers:
if self.gradient_checkpointing and self.training:
# Wrap the layer in a closure so keyword arguments (mask, collect_attn,
# position_bias) are captured; torch.utils.checkpoint only forwards
# positional tensor arguments.
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
return custom_forward
x, attn = cast(
Tuple[torch.Tensor, Optional[torch.Tensor]],
checkpoint(
create_custom_forward(layer),
x,
use_reentrant=False,
),
)
else:
x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
if collect_attn:
attn_weights_per_layer.append(attn)
# Final normalization (Pre-LN stack)
x = self.final_norm(x)
if collect_attn:
return x, attn_weights_per_layer
return x