"""Transformer Encoder implementation (Pre-LN).
This module implements the encoder component of the Transformer architecture:
- TransformerEncoderLayer: Single encoder block with self-attention + FFN
- TransformerEncoder: Full stack with embeddings and positional encoding
Design notes:
- Pre-LN with RMSNorm for training stability
- Masks are boolean: True = attend, False = mask
- Supports T5-style relative position bias
Author: Oliver Perrin
Date: 2025-10-23
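
Example (illustrative sketch; vocab size and shapes are arbitrary, and show the
boolean mask convention, True = attend):

    import torch
    enc = TransformerEncoder(vocab_size=1000, d_model=512, pad_token_id=0)
    ids = torch.tensor([[7, 8, 9, 0]])   # last position is padding
    out = enc(ids)                        # (1, 4, 512); pad positions masked automatically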
"""
from typing import List, Literal, Optional, Tuple, Union, cast
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
# Local imports
from .attention import MultiHeadAttention, T5RelativePositionBias
from .feedforward import FeedForward
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
from .t5_layer_norm import T5LayerNorm
class TransformerEncoderLayer(nn.Module):
"""
Single Transformer encoder layer (Pre-LN).
Args:
d_model: model hidden size
num_heads: number of attention heads
d_ff: hidden dimension of the position-wise feed-forward network
dropout: dropout probability applied to sublayer outputs
quantization: optional quantization mode ("4bit", "8bit")
        activation: activation function for FFN ("gelu", "relu", "swiglu", or "gated-gelu")
scale_attn_scores: Whether to scale attention scores by sqrt(d_k). T5 does NOT scale.
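
    Example (sketch; sizes are illustrative):
        layer = TransformerEncoderLayer(d_model=512, num_heads=8, d_ff=2048)
        x = torch.randn(2, 16, 512)                       # (batch, seq, d_model)
        mask = torch.ones(2, 16, 16, dtype=torch.bool)    # True = attend
        out, attn = layer(x, mask=mask, collect_attn=True)
        # out: (2, 16, 512); attn: (2, 8, 16, 16)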
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
quantization: Optional[str] = None,
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
scale_attn_scores: bool = True, # T5 uses False
):
super().__init__()
self.self_attn = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=0.0,
quantization=quantization,
scale_scores=scale_attn_scores,
)
        # Note: MHA internal dropout is 0.0; residual dropout is applied via dropout1/dropout2 below
self.ffn = FeedForward(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=activation,
quantization=quantization,
)
self.norm1 = T5LayerNorm(d_model)
self.norm2 = T5LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
collect_attn: bool = False,
position_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass for the encoder layer.
Args:
x: (batch, seq_len, d_model) - input embeddings / representations
mask: optional attention mask, shape either (batch, seq_q, seq_k) or (batch, 1, seq_q, seq_k)
collect_attn: whether to return attention weights
position_bias: optional (1, num_heads, seq_q, seq_k) T5-style relative position bias
        Returns:
            Tuple of (x, attn_weights):
                x: (batch, seq_len, d_model)
                attn_weights: (batch, num_heads, seq_len, seq_len) when collect_attn is True, otherwise None.
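
        Example (sketch; reuses T5RelativePositionBias imported above, with illustrative sizes):
            layer = TransformerEncoderLayer(d_model=512, num_heads=8, d_ff=2048)
            x = torch.randn(2, 16, 512)
            bias_module = T5RelativePositionBias(num_heads=8, num_buckets=32, max_distance=128, is_decoder=False)
            position_bias = bias_module(16, 16, x.device)    # (1, 8, 16, 16)
            out, attn = layer(x, position_bias=position_bias, collect_attn=True)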
"""
# Self-attention sublayer (Pre-LN)
x_norm = self.norm1(x) # Pre-LN
# self_attn expects query, key, value; for encoder they are the same
attn_out, attn_weights = self.self_attn(
x_norm,
x_norm,
x_norm,
mask,
return_attn_weights=collect_attn,
position_bias=position_bias,
)
x = x + self.dropout1(attn_out)
# Clamp inf values for fp16/bf16 training stability (like HuggingFace T5)
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
# Feed-forward sublayer (Pre-LN)
x_norm = self.norm2(x)
ffn_out = self.ffn(x_norm)
x = x + self.dropout2(ffn_out)
# Clamp inf values for fp16/bf16 training stability
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
        # Always return a tuple (x, attn_weights); attn_weights is None when attention is not collected
return x, attn_weights
class TransformerEncoder(nn.Module):
"""
Full encoder: token embedding + positional encoding + N encoder layers.
Args:
vocab_size: vocabulary size (ignored if you always pass embeddings)
d_model: model hidden size
num_layers: number of encoder layers to stack
num_heads: number of attention heads
d_ff: hidden dimension in FFN
dropout: dropout probability (applied in positional encoding & residuals)
max_len: maximum sequence length for positional encoding
        pad_token_id: optional token id for padding; if provided and inputs are token ids,
            a padding mask will be constructed automatically
        quantization: optional quantization mode ("4bit", "8bit") passed to attention/FFN layers
        use_learned_pos_enc: use learned absolute positional embeddings instead of sinusoidal ones
        activation: activation function for the FFN ("gelu", "relu", "swiglu", or "gated-gelu")
        use_relative_position_bias: use a T5-style relative position bias instead of absolute positional encoding
        gradient_checkpointing: recompute layer activations during the backward pass to save memory
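
    Example (sketch of a T5-style configuration; sizes are illustrative):
        encoder = TransformerEncoder(
            vocab_size=32128,
            d_model=512,
            num_heads=8,
            d_ff=2048,
            pad_token_id=0,
            activation="gated-gelu",
            use_relative_position_bias=True,
        )
        ids = torch.randint(1, 32128, (2, 16))
        out = encoder(ids)    # (2, 16, 512); padding mask is built from pad_token_id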
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_layers: int = 6,
num_heads: int = 8,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 512,
pad_token_id: Optional[int] = None,
quantization: Optional[str] = None,
use_learned_pos_enc: bool = False,
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
use_relative_position_bias: bool = False, # T5-style relative position bias
gradient_checkpointing: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.pad_token_id = pad_token_id
self.use_relative_position_bias = use_relative_position_bias
self.gradient_checkpointing = gradient_checkpointing
# Token embedding (only used if forward receives token ids)
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
# Positional encoding (disabled when using relative position bias for T5)
self.relative_position_bias: Optional[T5RelativePositionBias] = None
if use_relative_position_bias:
# T5 uses relative position bias instead of absolute positional embeddings
self.pos_encoder = None
self.relative_position_bias = T5RelativePositionBias(
num_heads=num_heads,
num_buckets=32,
max_distance=128,
is_decoder=False,
)
elif use_learned_pos_enc:
            # Reserve a small buffer beyond max_len for special tokens
self.pos_encoder = LearnedPositionalEncoding(
d_model=d_model, max_len=max_len + 2, dropout=dropout
)
else:
self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
# T5 does NOT scale attention scores by sqrt(d_k), others do
scale_attn_scores = not use_relative_position_bias
# Encoder layers stack
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
quantization=quantization,
activation=activation,
scale_attn_scores=scale_attn_scores,
)
for _ in range(num_layers)
]
)
# Final T5LayerNorm for Pre-LN stacks
self.final_norm = T5LayerNorm(d_model)
# Dropout applied after embedding + positional encoding (paper uses this)
self.input_dropout = nn.Dropout(dropout)
def _build_padding_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
True indicates valid positions; False indicates masked (pad).
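
        Example (sketch, assuming pad_token_id == 0):
            input_ids = torch.tensor([[5, 6, 0]])
            # pad_mask     -> [[True, True, False]]
            # attn_mask[0] -> [[ True,  True, False],
            #                  [ True,  True, False],
            #                  [False, False, False]]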
"""
assert self.pad_token_id is not None, (
"pad_token_id must be set to build padding mask from ids."
)
# mask shape: (batch, seq) where True = token kept (non-pad)
pad_mask = input_ids != self.pad_token_id
# Convert to (batch, seq_q, seq_k) by outer product broadcasting
# We want positions that are valid as both query and key
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)
# attn_mask dtype should be bool
return attn_mask
def forward(
self,
inputs: torch.Tensor,
mask: Optional[torch.Tensor] = None,
collect_attn: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Forward through the encoder.
Args:
inputs: either
- token ids: LongTensor of shape (batch, seq)
- embeddings: FloatTensor of shape (batch, seq, d_model)
mask: optional attention mask. If None and pad_token_id is set and inputs are token ids,
a padding mask will be created automatically with shape (batch, seq, seq).
The mask should be boolean where True indicates allowed attention.
collect_attn: if True, returns (output, [attn_weights_per_layer]) where each entry is (batch, num_heads, seq, seq)
Returns:
output: (batch, seq, d_model)
or (output, attn_list) if collect_attn True
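
        Example (sketch; assumes an encoder built with d_model=512 and pad_token_id set):
            ids = torch.randint(1, 1000, (2, 16))
            out = encoder(ids)                                   # (2, 16, 512)
            out, attn_list = encoder(ids, collect_attn=True)     # len(attn_list) == num_layers
            embeds = torch.randn(2, 16, 512)
            out = encoder(embeds)   # embeddings path; supply mask explicitly if padding matters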
"""
# If inputs are token ids, embed them; otherwise assume they are embeddings
if inputs.dim() == 2: # token ids
if self.embedding is None:
raise ValueError("Encoder was not constructed with an embedding layer.")
# T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
x = self.embedding(inputs)
seq_len = inputs.size(1)
elif inputs.dim() == 3: # already embeddings
x = inputs
seq_len = inputs.size(1)
else:
raise ValueError(
"inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings"
)
# Positional encoding + dropout (only if not using relative position bias)
if self.pos_encoder is not None:
x = self.pos_encoder(x)
x = self.input_dropout(x)
# Build mask if needed
if mask is None and inputs.dim() == 2 and self.pad_token_id is not None:
mask = self._build_padding_mask(inputs)
# Ensure mask is boolean and on the same device
if mask is not None:
mask = mask.to(dtype=torch.bool, device=x.device)
# Compute relative position bias if using T5-style
position_bias = None
if self.relative_position_bias is not None:
position_bias = self.relative_position_bias(seq_len, seq_len, x.device)
attn_weights_per_layer: List[torch.Tensor] = []
# Pass through each encoder layer (optionally collect attn)
for layer in self.layers:
if self.gradient_checkpointing and self.training:
                # Wrap the layer call in a closure so that mask, collect_attn, and
                # position_bias are passed through torch.utils.checkpoint.checkpoint()
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
return custom_forward
x, attn = cast(
Tuple[torch.Tensor, Optional[torch.Tensor]],
checkpoint(
create_custom_forward(layer),
x,
use_reentrant=False,
),
)
else:
x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
            if collect_attn and attn is not None:
                attn_weights_per_layer.append(attn)
# Final normalization (Pre-LN stack)
x = self.final_norm(x)
if collect_attn:
return x, attn_weights_per_layer
return x
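

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): exercise both the absolute
    # positional-encoding path and the T5-style relative-position-bias path
    # and check output shapes.
    for use_rel_bias in (False, True):
        enc = TransformerEncoder(
            vocab_size=100,
            d_model=32,
            num_layers=2,
            num_heads=4,
            d_ff=64,
            max_len=16,
            pad_token_id=0,
            use_relative_position_bias=use_rel_bias,
        )
        ids = torch.randint(1, 100, (2, 8))
        out, attn = enc(ids, collect_attn=True)
        assert out.shape == (2, 8, 32)
        assert len(attn) == 2
        print(f"relative_bias={use_rel_bias}: out {tuple(out.shape)}, {len(attn)} attention maps")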