| """Transformer Encoder implementation (Pre-LN). | |
| This module implements the encoder component of the Transformer architecture: | |
| - TransformerEncoderLayer: Single encoder block with self-attention + FFN | |
| - TransformerEncoder: Full stack with embeddings and positional encoding | |
| Design notes: | |
| - Pre-LN with RMSNorm for training stability | |
| - Masks are boolean: True = attend, False = mask | |
| - Supports T5-style relative position bias | |
| Author: Oliver Perrin | |
| Date: 2025-10-23 | |
| """ | |
| from typing import List, Literal, Optional, Tuple, Union, cast | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.checkpoint import checkpoint | |
| # Encoder implementation | |
| from .attention import MultiHeadAttention, T5RelativePositionBias | |
| from .feedforward import FeedForward | |
| from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding | |
| from .t5_layer_norm import T5LayerNorm | |
class TransformerEncoderLayer(nn.Module):
    """
    Single Transformer encoder layer (Pre-LN).

    Args:
        d_model: model hidden size
        num_heads: number of attention heads
        d_ff: hidden dimension of the position-wise feed-forward network
        dropout: dropout probability applied to sublayer outputs
        quantization: optional quantization mode ("4bit", "8bit")
        activation: activation function for FFN ("gelu", "relu", "swiglu", or "gated-gelu")
        scale_attn_scores: whether to scale attention scores by sqrt(d_k). T5 does NOT scale.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        quantization: Optional[str] = None,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
        scale_attn_scores: bool = True,  # T5 uses False
    ):
        super().__init__()
        # MHA's internal dropout stays at 0.0; dropout1/dropout2 below apply
        # dropout on each sublayer's residual path instead.
        self.self_attn = MultiHeadAttention(
            d_model=d_model,
            num_heads=num_heads,
            dropout=0.0,
            quantization=quantization,
            scale_scores=scale_attn_scores,
        )
        self.ffn = FeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
            quantization=quantization,
        )
        self.norm1 = T5LayerNorm(d_model)
        self.norm2 = T5LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        collect_attn: bool = False,
        position_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass for the encoder layer.

        Args:
            x: (batch, seq_len, d_model) input embeddings / representations
            mask: optional boolean attention mask of shape (batch, seq_q, seq_k)
                or (batch, 1, seq_q, seq_k); True = attend, False = mask
            collect_attn: whether to return attention weights
            position_bias: optional (1, num_heads, seq_q, seq_k) T5-style relative position bias

        Returns:
            Tuple of:
                x: (batch, seq_len, d_model)
                attn_weights: (batch, num_heads, seq_q, seq_k) if collect_attn is True, else None
        """
        # Self-attention sublayer (Pre-LN): normalize first, then attend.
        x_norm = self.norm1(x)
        # self_attn expects query, key, value; for the encoder they are all x_norm.
        attn_out, attn_weights = self.self_attn(
            x_norm,
            x_norm,
            x_norm,
            mask,
            return_attn_weights=collect_attn,
            position_bias=position_bias,
        )
        x = x + self.dropout1(attn_out)

        # Clamp extreme values for fp16/bf16 training stability (as in HuggingFace T5).
        if x.dtype in (torch.float16, torch.bfloat16):
            clamp_value = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)

        # Feed-forward sublayer (Pre-LN)
        x_norm = self.norm2(x)
        ffn_out = self.ffn(x_norm)
        x = x + self.dropout2(ffn_out)

        # Clamp again after the FFN residual.
        if x.dtype in (torch.float16, torch.bfloat16):
            clamp_value = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)

        return x, attn_weights
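

# Usage sketch for a single layer (illustrative only; the attention-weight shape
# assumes the MultiHeadAttention contract documented in this module, i.e.
# (batch, num_heads, seq_q, seq_k)):
#
#   layer = TransformerEncoderLayer(d_model=64, num_heads=4, d_ff=128)
#   x = torch.randn(2, 10, 64)                      # (batch, seq, d_model)
#   mask = torch.ones(2, 10, 10, dtype=torch.bool)  # True = attend
#   out, attn = layer(x, mask=mask, collect_attn=True)
#   # out: (2, 10, 64); attn: (2, 4, 10, 10), or None when collect_attn=False
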
class TransformerEncoder(nn.Module):
    """
    Full encoder: token embedding + positional encoding + N encoder layers.

    Args:
        vocab_size: vocabulary size (ignored if you always pass embeddings)
        d_model: model hidden size
        num_layers: number of encoder layers to stack
        num_heads: number of attention heads
        d_ff: hidden dimension in FFN
        dropout: dropout probability (applied in positional encoding & residuals)
        max_len: maximum sequence length for positional encoding
        pad_token_id: optional token id for padding; if provided and the input is token ids,
            a padding mask will be constructed automatically
        quantization: optional quantization mode ("4bit", "8bit")
        use_learned_pos_enc: use learned absolute positional embeddings instead of sinusoidal
        activation: activation function for FFN ("gelu", "relu", "swiglu", or "gated-gelu")
        use_relative_position_bias: use T5-style relative position bias (disables absolute positions)
        gradient_checkpointing: recompute activations during backward to save memory
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_token_id: Optional[int] = None,
        quantization: Optional[str] = None,
        use_learned_pos_enc: bool = False,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
        use_relative_position_bias: bool = False,  # T5-style relative position bias
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.pad_token_id = pad_token_id
        self.use_relative_position_bias = use_relative_position_bias
        self.gradient_checkpointing = gradient_checkpointing

        # Token embedding (only used if forward receives token ids)
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)

        # Positional encoding (disabled when using relative position bias for T5)
        self.relative_position_bias: Optional[T5RelativePositionBias] = None
        if use_relative_position_bias:
            # T5 uses relative position bias instead of absolute positional embeddings
            self.pos_encoder = None
            self.relative_position_bias = T5RelativePositionBias(
                num_heads=num_heads,
                num_buckets=32,
                max_distance=128,
                is_decoder=False,
            )
        elif use_learned_pos_enc:
            # T5 uses max_len=512 by default; add a small buffer for special tokens
            self.pos_encoder = LearnedPositionalEncoding(
                d_model=d_model, max_len=max_len + 2, dropout=dropout
            )
        else:
            self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)

        # T5 does NOT scale attention scores by sqrt(d_k); other variants do
        scale_attn_scores = not use_relative_position_bias

        # Encoder layer stack
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout=dropout,
                    quantization=quantization,
                    activation=activation,
                    scale_attn_scores=scale_attn_scores,
                )
                for _ in range(num_layers)
            ]
        )

        # Final T5LayerNorm for the Pre-LN stack
        self.final_norm = T5LayerNorm(d_model)

        # Dropout applied after embedding + positional encoding (as in the original paper)
        self.input_dropout = nn.Dropout(dropout)

    def _build_padding_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Build a boolean 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.

        True indicates valid positions; False indicates masked (pad) positions.
        """
        assert self.pad_token_id is not None, (
            "pad_token_id must be set to build padding mask from ids."
        )
        # pad_mask shape: (batch, seq) where True = token kept (non-pad)
        pad_mask = input_ids != self.pad_token_id
        # Convert to (batch, seq_q, seq_k) by outer-product broadcasting:
        # a position is attendable only if both query and key are non-pad.
        attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)
        return attn_mask  # dtype: bool
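
    # Worked example for _build_padding_mask (assuming pad_token_id == 0):
    #   input_ids = [[5, 6, 0, 0]]  ->  pad_mask = [[True, True, False, False]]
    #   attn_mask[b, q, k] = pad_mask[b, q] & pad_mask[b, k]
    #   giving a (1, 4, 4) boolean mask whose top-left 2x2 block is True and
    #   whose entries involving a pad position (as query or key) are False.
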
    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        collect_attn: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Forward through the encoder.

        Args:
            inputs: either
                - token ids: LongTensor of shape (batch, seq)
                - embeddings: FloatTensor of shape (batch, seq, d_model)
            mask: optional boolean attention mask where True indicates allowed attention.
                If None, and pad_token_id is set, and inputs are token ids, a padding
                mask of shape (batch, seq, seq) is created automatically.
            collect_attn: if True, returns (output, [attn_weights_per_layer]) where each
                entry is (batch, num_heads, seq, seq)

        Returns:
            output: (batch, seq, d_model),
            or (output, attn_list) if collect_attn is True
        """
        # If inputs are token ids, embed them; otherwise assume they are embeddings
        if inputs.dim() == 2:  # token ids
            if self.embedding is None:
                raise ValueError("Encoder was not constructed with an embedding layer.")
            # T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
            x = self.embedding(inputs)
            seq_len = inputs.size(1)
        elif inputs.dim() == 3:  # already embeddings
            x = inputs
            seq_len = inputs.size(1)
        else:
            raise ValueError(
                "inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings"
            )

        # Positional encoding + dropout (skipped when using relative position bias)
        if self.pos_encoder is not None:
            x = self.pos_encoder(x)
        x = self.input_dropout(x)

        # Build a padding mask if none was given and inputs are token ids
        if mask is None and inputs.dim() == 2 and self.pad_token_id is not None:
            mask = self._build_padding_mask(inputs)

        # Ensure the mask is boolean and on the same device as x
        if mask is not None:
            mask = mask.to(dtype=torch.bool, device=x.device)

        # Compute the T5-style relative position bias once and share it across layers
        position_bias = None
        if self.relative_position_bias is not None:
            position_bias = self.relative_position_bias(seq_len, seq_len, x.device)

        attn_weights_per_layer: List[torch.Tensor] = []

        # Pass through each encoder layer (optionally collecting attention weights)
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # checkpoint() re-runs the wrapped callable during the backward pass;
                # the closure captures the keyword arguments so only the tensor
                # input needs to be passed positionally.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            *inputs,
                            mask=mask,
                            collect_attn=collect_attn,
                            position_bias=position_bias,
                        )

                    return custom_forward

                x, attn = cast(
                    Tuple[torch.Tensor, Optional[torch.Tensor]],
                    checkpoint(
                        create_custom_forward(layer),
                        x,
                        use_reentrant=False,
                    ),
                )
            else:
                x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)

            if collect_attn and attn is not None:
                attn_weights_per_layer.append(attn)

        # Final normalization (Pre-LN stack)
        x = self.final_norm(x)

        if collect_attn:
            return x, attn_weights_per_layer
        return x
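

# Minimal smoke test of the full stack (illustrative sketch, not part of the
# library API). It assumes the sibling modules (.attention, .feedforward,
# .positional_encoding, .t5_layer_norm) behave as documented above. Because
# this file uses relative imports, run it as a module from the package root,
# e.g. `python -m <your_package>.encoder` (package name is a placeholder).
if __name__ == "__main__":
    torch.manual_seed(0)
    PAD_ID = 0  # hypothetical pad token id chosen for this demo

    enc = TransformerEncoder(
        vocab_size=1000,
        d_model=64,
        num_layers=2,
        num_heads=4,
        d_ff=128,
        pad_token_id=PAD_ID,
        use_relative_position_bias=True,  # T5-style: no absolute positional encoding
    )
    enc.eval()  # disable dropout for a deterministic check

    ids = torch.randint(1, 1000, (2, 10))
    ids[0, 7:] = PAD_ID  # pad the tail of the first sequence; the mask is built automatically

    with torch.no_grad():
        out, attn = enc(ids, collect_attn=True)

    # Expected: out is (2, 10, 64); attn is a list of 2 tensors of shape (2, 4, 10, 10).
    # Outputs at padded positions are not meaningful.
    print(out.shape, len(attn), attn[0].shape)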