LexiMind / src /models /heads.py
OliverPerrin
Full training run, code cleanup, mypy/ruff fixes
590a604
"""Prediction heads for Transformer models.
This module provides task-specific output heads:
- ClassificationHead: Sequence-level classification with pooling (mean/cls/max)
- TokenClassificationHead: Per-token classification (NER, POS tagging)
- LMHead: Language modeling logits with optional weight tying
- ProjectionHead: MLP for representation learning / contrastive tasks
Author: Oliver Perrin
Date: 2025-10-23
"""
from typing import Literal, Optional
import torch
import torch.nn as nn
class ClassificationHead(nn.Module):
    """
    Sequence-level classification head.

    Pools the sequence dimension into a single vector per example, applies
    dropout, and projects the result to ``num_labels`` logits.

    Args:
        d_model: hidden size from encoder/decoder
        num_labels: number of output classes
        pooler: one of 'mean', 'cls', 'max' - how to pool the sequence
        dropout: dropout probability before final linear layer

    Raises:
        ValueError: if ``pooler`` is not one of 'mean', 'cls', 'max'.
    """

    def __init__(
        self,
        d_model: int,
        num_labels: int,
        pooler: Literal["mean", "cls", "max"] = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        # Explicit raise rather than `assert` so the check survives `python -O`.
        if pooler not in ("mean", "cls", "max"):
            raise ValueError("pooler must be 'mean'|'cls'|'max'")
        self.pooler = pooler
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, num_labels)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        mask: (batch, seq_len) - True (or nonzero) for valid tokens, False (or 0)
            for padding. Ignored by the 'cls' pooler.
        returns: (batch, num_labels)

        With a mask, rows that contain no valid tokens pool to a zero vector
        for both the 'mean' and the 'max' pooler.
        """
        if self.pooler == "mean":
            pooled = self._mean_pool(x, mask)
        elif self.pooler == "cls":
            pooled = x[:, 0, :]
        else:  # max
            pooled = self._max_pool(x, mask)
        pooled = self.dropout(pooled)
        return self.out_proj(pooled)

    @staticmethod
    def _mean_pool(x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average over valid positions; returns (batch, d_model) in x's dtype."""
        if mask is None:
            return x.mean(dim=1)
        # (B, S) -> (B, S, 1). Multiplying by a float32 mask promotes the
        # arithmetic to at least float32, which keeps the sum accurate.
        mask_expanded = mask.unsqueeze(-1).float()
        sum_embeddings = (x * mask_expanded).sum(dim=1)
        # Number of valid tokens per row; the clamp avoids division by zero.
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        # Cast back so half-precision inputs reach out_proj in their own dtype.
        return (sum_embeddings / sum_mask).to(x.dtype)

    @staticmethod
    def _max_pool(x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Elementwise max over valid positions; returns (batch, d_model)."""
        if mask is None:
            pooled, _ = x.max(dim=1)
            return pooled
        # masked_fill needs a bool mask; `~` on an integer mask would be a
        # bitwise NOT rather than a logical one.
        valid = mask.bool()
        pooled, _ = x.masked_fill(~valid.unsqueeze(-1), float("-inf")).max(dim=1)
        # A row with no valid tokens would otherwise pool to -inf; zero it to
        # match what the mean pooler produces for such rows.
        has_valid = valid.any(dim=1, keepdim=True)
        return pooled.masked_fill(~has_valid, 0.0)
class TokenClassificationHead(nn.Module):
    """
    Token-level classification head, e.g. for NER or POS tagging.

    Every position's hidden state goes through dropout and then one shared
    linear layer, producing an independent logit vector per token.

    Args:
        d_model: width of the incoming hidden states
        num_labels: number of classes predicted at every position
        dropout: dropout probability used ahead of the linear layer
    """

    def __init__(self, d_model: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.out_proj = nn.Linear(in_features=d_model, out_features=num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        returns: (batch, seq_len, num_labels)
        """
        return self.out_proj(self.dropout(x))
class LMHead(nn.Module):
    """
    Language modeling head: maps hidden states to logits over vocabulary.

    Args:
        d_model: hidden size
        vocab_size: vocabulary size
        tie_embedding: optional nn.Embedding instance to tie weights with

    Raises:
        ValueError: if ``tie_embedding`` is given and its shape does not match
            ``(vocab_size, d_model)``.
    """

    def __init__(self, d_model: int, vocab_size: int, tie_embedding: Optional[nn.Embedding] = None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.proj = nn.Linear(d_model, vocab_size, bias=True)
        if tie_embedding is not None:
            # Explicit raises rather than `assert`: under `python -O` a mismatched
            # embedding would otherwise be tied silently and change the output width.
            if tie_embedding.num_embeddings != vocab_size:
                raise ValueError("vocab size mismatch for weight tying")
            if tie_embedding.embedding_dim != d_model:
                raise ValueError("embedding dim must match d_model for weight tying")
            # nn.Embedding.weight is (num_embeddings, embedding_dim) and
            # nn.Linear.weight is (out_features, in_features); both are
            # (vocab_size, d_model) here, so the Parameter can be shared as-is.
            # Assigning replaces the Linear's own freshly initialized weight with
            # the embedding's Parameter object, so updates affect both modules.
            # The bias is not tied.
            self.proj.weight = tie_embedding.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        hidden_states: (batch, seq_len, d_model)
        returns logits: (batch, seq_len, vocab_size)
        """
        return self.proj(hidden_states)
class ProjectionHead(nn.Module):
    """
    Simple projection head for representation learning.

    A two-layer MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: input dimension
        proj_dim: output projection dimension
        hidden_dim: intermediate dimension (optional; defaults to
            ``max(d_model, proj_dim)``)
        dropout: dropout probability
    """

    def __init__(
        self,
        d_model: int,
        proj_dim: int = 128,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(d_model, proj_dim)
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, d_model) or (batch, seq_len, d_model) - both supported.
        Returns:
            If input is 3D: (batch, seq_len, proj_dim)
            If input is 2D: (batch, proj_dim)
        Raises:
            ValueError: if x is neither 2D nor 3D.
        """
        if x.dim() not in (2, 3):
            raise ValueError("Input must be 2D or 3D tensor")
        # nn.Linear acts on the last dimension for any number of leading dims,
        # and GELU/Dropout are elementwise, so 3D input needs no flattening.
        # This also avoids `view`, which raises on non-contiguous tensors
        # (e.g. a transposed input).
        return self.net(x)