import pytest
import torch

from src.models.decoder import (
    TransformerDecoder,
    TransformerDecoderLayer,
    create_causal_mask,
)


def test_create_causal_mask_properties():
    mask = create_causal_mask(5)
    assert mask.shape == (5, 5)
    # Diagonal and below should be True, strictly above should be False.
    for i in range(5):
        for j in range(5):
            assert mask[i, j].item() == (j <= i)


def test_decoder_layer_shapes_and_grad():
    torch.manual_seed(0)
    d_model, num_heads, d_ff = 32, 4, 64
    batch_size, tgt_len, src_len = 2, 6, 7
    layer = TransformerDecoderLayer(
        d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0
    )
    tgt = torch.randn(batch_size, tgt_len, d_model, requires_grad=True)
    memory = torch.randn(batch_size, src_len, d_model)

    # No masks
    out, attn = layer(tgt, memory, tgt_mask=None, memory_mask=None, collect_attn=True)
    assert out.shape == (batch_size, tgt_len, d_model)
    assert isinstance(attn, dict)
    assert "self" in attn and "cross" in attn
    assert attn["self"].shape == (batch_size, num_heads, tgt_len, tgt_len)
    assert attn["cross"].shape == (batch_size, num_heads, tgt_len, src_len)

    # Backprop works
    loss = out.sum()
    loss.backward()
    grads = [p.grad for p in layer.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)


def test_decoder_layer_causal_mask_blocks_future():
    torch.manual_seed(1)
    d_model, num_heads, d_ff = 48, 6, 128
    batch_size, tgt_len, src_len = 1, 5, 5
    layer = TransformerDecoderLayer(
        d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0
    )
    # Seeded random inputs keep the attention pattern reproducible.
    tgt = torch.randn(batch_size, tgt_len, d_model)
    memory = torch.randn(batch_size, src_len, d_model)

    causal = create_causal_mask(tgt_len, device=tgt.device)  # (T, T)
    tgt_mask = causal.unsqueeze(0)  # (1, T, T); the layer broadcasts over heads

    out, attn = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=None, collect_attn=True)
    self_attn = attn["self"].detach()

    # Ensure the upper triangle of the attention weights is zero (no future attention):
    # for each head and query position i, keys j > i should receive zero weight.
    B, H, Tq, Tk = self_attn.shape
    for i in range(Tq):
        for j in range(i + 1, Tk):
            assert torch.allclose(self_attn[:, :, i, j], torch.zeros(B, H)), (
                f"Found nonzero attention to future position {j} from query {i}"
            )


def test_decoder_stack_and_greedy_decode_shapes():
    torch.manual_seed(2)
    vocab_size = 30
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 7
    max_tgt = 6
    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )
    # Random memory standing in for encoder output
    memory = torch.randn(batch_size, src_len, d_model)

    # Greedy decode: should produce (B, <= max_tgt)
    generated = decoder.greedy_decode(
        memory, max_len=max_tgt, start_token_id=1, end_token_id=None
    )
    assert generated.shape[0] == batch_size
    assert generated.shape[1] <= max_tgt
    assert (generated[:, 0] == 1).all()  # starts with the start token

    # Also test forward with embeddings and collect_attn
    embeddings = torch.randn(batch_size, max_tgt, d_model)
    logits, attn_list = decoder(embeddings, memory, collect_attn=True)
    assert logits.shape == (batch_size, max_tgt, vocab_size)
    assert isinstance(attn_list, list)
    assert len(attn_list) == num_layers
    for attn in attn_list:
        assert "self" in attn and "cross" in attn
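
# The greedy_decode call above is only checked for shapes. For context, here is
# a minimal sketch of the loop such a method typically implements. This is an
# assumption, not the project's actual implementation: the real method may stop
# early on end_token_id, cache key/value states, etc. The leading underscore
# keeps pytest from collecting it; it is unused by the tests.
def _greedy_decode_sketch(decoder, memory, max_len, start_token_id):
    batch_size = memory.shape[0]
    # Start every sequence with the start token, then extend one token at a time.
    ids = torch.full((batch_size, 1), start_token_id, dtype=torch.long)
    for _ in range(max_len - 1):
        logits = decoder(ids, memory)              # (B, T, vocab_size), per the forward API these tests assume
        next_id = logits[:, -1, :].argmax(dim=-1)  # most probable next token for each sequence
        ids = torch.cat([ids, next_id.unsqueeze(1)], dim=1)
    return ids
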
def test_decoder_train_eval_dropout_behavior():
    torch.manual_seed(3)
    vocab_size = 40
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 6
    tgt_len = 5
    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.4,
        max_len=tgt_len,
        pad_token_id=0,
    )
    # Token ids, with one padded position
    input_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
    input_ids[0, -1] = 0
    memory = torch.randn(batch_size, src_len, d_model)

    decoder.train()
    out1 = decoder(input_ids, memory)
    out2 = decoder(input_ids, memory)
    # With dropout active in train mode, repeated forward passes should usually differ.
    assert not torch.allclose(out1, out2)

    decoder.eval()
    out3 = decoder(input_ids, memory)
    out4 = decoder(input_ids, memory)
    assert torch.allclose(out3, out4)


if __name__ == "__main__":
    pytest.main([__file__, "-q"])
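
# Background for the train/eval test above: PyTorch dropout layers are
# stochastic only in train mode and become the identity in eval mode. A
# self-contained illustration with plain nn.Dropout (no project code involved;
# p=0.4 matches the decoder's dropout above). Illustrative only and unused by
# the tests.
def _dropout_mode_demo():
    drop = torch.nn.Dropout(p=0.4)
    x = torch.ones(4, 8)
    drop.train()
    a, b = drop(x), drop(x)  # independent random masks -> almost surely different
    drop.eval()
    c, d = drop(x), drop(x)  # identity in eval mode -> exactly equal
    return torch.equal(a, b), torch.equal(c, d)  # expected: (False, True)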