File size: 31,210 Bytes
590a604
 
 
 
 
 
 
 
 
 
 
 
 
 
1bdd1c1
ee1a8a3
67c3a83
d18b34d
5a20c96
 
1ec7405
5a20c96
b43ba56
5a20c96
b43ba56
1ec7405
5a20c96
 
 
 
1bdd1c1
 
5a20c96
1bdd1c1
 
 
5a20c96
 
 
 
1bdd1c1
 
 
 
 
5a20c96
 
d18b34d
 
 
 
 
 
 
b43ba56
 
d18b34d
5a20c96
1bdd1c1
d18b34d
b43ba56
 
 
 
 
d18b34d
 
b43ba56
 
 
 
 
d18b34d
 
b43ba56
 
 
 
 
d18b34d
5a20c96
1ec7405
 
 
5a20c96
 
 
 
 
 
 
 
 
 
 
d18b34d
b43ba56
 
d18b34d
5a20c96
 
1bdd1c1
 
 
 
d18b34d
b43ba56
 
5a20c96
 
1bdd1c1
5a20c96
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d18b34d
b43ba56
 
 
 
 
 
d18b34d
1bdd1c1
5a20c96
b43ba56
 
 
 
 
1bdd1c1
 
d18b34d
b43ba56
 
 
 
 
 
d18b34d
1bdd1c1
5a20c96
b43ba56
 
 
 
 
1bdd1c1
 
 
 
5a20c96
b43ba56
 
 
 
 
1bdd1c1
5a20c96
 
 
 
1bdd1c1
5a20c96
1bdd1c1
5a20c96
 
 
 
 
 
 
 
 
 
 
 
d18b34d
b43ba56
 
 
1ec7405
5a20c96
 
 
 
 
b43ba56
 
1ec7405
b43ba56
 
1ec7405
b43ba56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a20c96
b43ba56
 
5a20c96
 
d18b34d
 
 
 
 
 
 
b43ba56
 
d18b34d
 
 
5a20c96
 
1ec7405
 
5a20c96
 
 
 
1bdd1c1
b43ba56
 
 
 
5a20c96
 
d18b34d
1ec7405
 
 
 
 
 
 
 
 
 
 
 
1bdd1c1
5a20c96
 
 
 
 
 
 
 
 
b43ba56
1ec7405
5a20c96
 
1bdd1c1
 
 
 
b43ba56
5a20c96
1bdd1c1
5a20c96
b43ba56
 
5a20c96
 
 
 
 
b43ba56
 
 
 
5a20c96
 
1bdd1c1
 
 
5a20c96
1bdd1c1
b43ba56
 
1bdd1c1
 
 
5a20c96
b43ba56
 
1bdd1c1
 
 
 
1ec7405
 
 
 
 
 
5a20c96
1bdd1c1
5a20c96
 
1bdd1c1
 
 
 
5a20c96
1ec7405
5a20c96
b43ba56
 
 
 
 
 
 
 
 
 
 
1bdd1c1
5a20c96
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a20c96
 
 
1bdd1c1
1ec7405
5a20c96
1bdd1c1
5a20c96
 
 
 
b43ba56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a20c96
 
 
 
 
 
 
a504116
 
ea3248a
 
b43ba56
f0493d8
1bdd1c1
5a20c96
b43ba56
5a20c96
1bdd1c1
 
 
b43ba56
 
1bdd1c1
 
b43ba56
 
 
 
 
a504116
 
b43ba56
 
ea3248a
b43ba56
 
 
d18b34d
b43ba56
 
 
ea3248a
b43ba56
ea3248a
b43ba56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea3248a
a504116
d18b34d
ea3248a
 
 
b43ba56
ea3248a
 
b43ba56
 
ea3248a
 
 
d18b34d
 
ea3248a
d18b34d
ea3248a
 
 
 
 
d18b34d
ea3248a
 
 
b43ba56
a504116
b43ba56
 
1bdd1c1
 
b43ba56
1bdd1c1
b43ba56
 
 
 
1bdd1c1
 
 
 
 
 
5a20c96
 
b43ba56
5a20c96
 
 
 
1bdd1c1
5a20c96
 
1bdd1c1
 
 
5a20c96
 
1bdd1c1
 
5a20c96
1bdd1c1
 
 
 
 
 
 
 
b43ba56
 
 
 
 
 
 
 
67c3a83
b43ba56
 
 
 
 
 
 
67c3a83
b43ba56
67c3a83
b43ba56
 
67c3a83
b43ba56
 
 
 
 
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b43ba56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bdd1c1
 
67c3a83
 
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d18b34d
 
 
 
b43ba56
d18b34d
1bdd1c1
 
 
 
d18b34d
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d18b34d
 
 
 
 
 
1bdd1c1
 
 
 
 
 
 
 
 
d18b34d
 
 
 
 
b43ba56
d18b34d
 
 
 
 
 
1bdd1c1
d18b34d
1bdd1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d18b34d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
"""Transformer Decoder implementation (Pre-LN).

This module implements the decoder component of the Transformer architecture:
- create_causal_mask: Generate causal attention masks
- TransformerDecoderLayer: Single decoder block with self-attn + cross-attn + FFN
- TransformerDecoder: Full stack with embeddings, positional encoding, and generation

Design notes:
- Pre-LN with RMSNorm for training stability
- Masks are boolean: True = attend, False = mask
- Supports T5-style relative position bias

Author: Oliver Perrin
Date: 2025-10-23
"""

from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .attention import MultiHeadAttention, T5RelativePositionBias
from .feedforward import FeedForward
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
from .t5_layer_norm import T5LayerNorm


def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Build a boolean lower-triangular mask of shape (seq_len, seq_len).

    Entry (i, j) is True exactly when j <= i, i.e. the query at position i is
    allowed to look at key positions 0..i (True = attend, False = mask).

    Args:
        seq_len: side length T of the square mask.
        device: device on which the mask is created (default device when None).

    Returns:
        Bool tensor of shape (T, T).
    """
    positions = torch.arange(seq_len, device=device)
    key_pos = positions.unsqueeze(0)  # (1, T): column index j
    query_pos = positions.unsqueeze(1)  # (T, 1): row index i
    # Broadcasting the comparison yields the (T, T) "j <= i" table directly.
    return key_pos <= query_pos


class TransformerDecoderLayer(nn.Module):
    """
    One Pre-LN decoder block.

    Each of the three sub-layers is applied as ``x + dropout(sublayer(norm(x)))``:
      1) Masked self-attention over the target sequence
      2) Cross-attention from the target onto the encoder memory
      3) Feed-forward network

    ``forward`` returns the updated target tensor and a dict with the attention
    weights reported by the two attention modules.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        quantization: Optional[str] = None,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
        scale_attn_scores: bool = True,  # T5 uses False
    ):
        """
        Args:
            d_model: model (hidden) dimension.
            num_heads: number of attention heads.
            d_ff: hidden size handed to the feed-forward network.
            dropout: dropout probability used after each sub-layer and inside the FFN.
            quantization: forwarded unchanged to the attention and FFN modules.
            activation: activation name forwarded to the FFN.
            scale_attn_scores: forwarded to both attention modules as ``scale_scores``.
        """
        super().__init__()

        def _new_attention() -> MultiHeadAttention:
            # Attention-internal dropout is disabled (0.0); this layer applies
            # dropout itself after every sub-layer.
            return MultiHeadAttention(
                d_model=d_model,
                num_heads=num_heads,
                dropout=0.0,
                quantization=quantization,
                scale_scores=scale_attn_scores,
            )

        # Construction order is kept stable: it determines parameter order and
        # the order in which random initialisation is drawn.
        self.self_attn = _new_attention()
        self.cross_attn = _new_attention()
        self.ffn = FeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
            quantization=quantization,
        )

        # One norm / dropout pair per sub-layer (self-attn, cross-attn, FFN).
        self.norm1 = T5LayerNorm(d_model)
        self.norm2 = T5LayerNorm(d_model)
        self.norm3 = T5LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    @staticmethod
    def _clamp_half_precision(hidden: torch.Tensor) -> torch.Tensor:
        """Clamp fp16/bf16 activations just below the dtype maximum; other dtypes pass through.

        Keeps inf values out of the residual stream for fp16/bf16 training
        stability (like HuggingFace T5).
        """
        if hidden.dtype in (torch.float16, torch.bfloat16):
            limit = torch.finfo(hidden.dtype).max - 1000
            hidden = torch.clamp(hidden, min=-limit, max=limit)
        return hidden

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        collect_attn: bool = False,
        self_attn_position_bias: Optional[torch.Tensor] = None,
        cross_attn_position_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
        """
        Run self-attention, cross-attention and the FFN on ``tgt``.

        Args:
            tgt: (B, T, d_model)
            memory: (B, S, d_model)
            tgt_mask: optional mask for self-attn - shape (B, T, T) or (B, 1, T, T)
            memory_mask: optional mask for cross-attn - shape (B, S) or (B, 1, S) or (B, 1, T, S)
            collect_attn: whether to return attention weights
            self_attn_position_bias: optional T5 relative position bias for self-attention
            cross_attn_position_bias: optional T5 relative position bias for cross-attention

        Returns:
            (tgt_out, {"self": self_attn_weights, "cross": cross_attn_weights})
        """
        target_device = tgt.device

        # Masks are cast to bool and moved next to the activations.
        if tgt_mask is not None:
            tgt_mask = tgt_mask.to(dtype=torch.bool, device=target_device)

        if memory_mask is not None:
            memory_mask = memory_mask.to(dtype=torch.bool, device=target_device)
            mask_rank = memory_mask.dim()
            if mask_rank == 2:
                # (B, S) per-key padding mask -> (B, 1, 1, S)
                memory_mask = memory_mask[:, None, None, :]
            elif mask_rank == 3 and memory_mask.shape[1] != 1:
                # treated as (B, T, S) -> (B, 1, T, S); a 3-D mask whose second
                # dim is 1 is handed to the attention module unchanged
                memory_mask = memory_mask[:, None]

        # --- Sub-layer 1: masked self-attention (Pre-LN) ---
        normed = self.norm1(tgt)
        self_out, self_attn = self.self_attn(
            normed,
            normed,
            normed,
            tgt_mask,
            return_attn_weights=collect_attn,
            position_bias=self_attn_position_bias,
        )
        tgt = self._clamp_half_precision(tgt + self.dropout1(self_out))

        # --- Sub-layer 2: cross-attention (Pre-LN); only the query side is normalised ---
        normed = self.norm2(tgt)
        cross_out, cross_attn = self.cross_attn(
            normed,
            memory,
            memory,
            memory_mask,
            return_attn_weights=collect_attn,
            position_bias=cross_attn_position_bias,
        )
        tgt = self._clamp_half_precision(tgt + self.dropout2(cross_out))

        # --- Sub-layer 3: feed-forward (Pre-LN) ---
        ffn_out = self.ffn(self.norm3(tgt))
        tgt = self._clamp_half_precision(tgt + self.dropout3(ffn_out))

        return tgt, {"self": self_attn, "cross": cross_attn}


class TransformerDecoder(nn.Module):
    """
    Decoder stack with token embeddings and positional encoding.

    Forward returns logits (B, T, vocab_size) by default; if collect_attn=True returns (logits, attn_list).
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_token_id: Optional[int] = None,
        quantization: Optional[str] = None,
        use_learned_pos_enc: bool = False,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
        use_relative_position_bias: bool = False,  # T5-style relative position bias
        gradient_checkpointing: bool = False,
    ):
        """
        Assemble the token embedding, the positional scheme, the layer stack
        and the output head.

        Args:
            vocab_size: size of the token vocabulary (embedding rows / logit width).
            d_model: model (hidden) dimension.
            num_layers: number of decoder layers to stack.
            num_heads: attention heads per layer.
            d_ff: feed-forward hidden size passed to every layer.
            dropout: dropout probability for the input dropout, the positional
                encoders and every layer.
            max_len: maximum length handed to the absolute positional encoders.
            pad_token_id: padding token id; also used as the embedding ``padding_idx``.
            quantization: forwarded unchanged to every layer.
            use_learned_pos_enc: use learned instead of sinusoidal positions
                (ignored when ``use_relative_position_bias`` is True).
            activation: FFN activation name passed to every layer.
            use_relative_position_bias: use a T5-style relative bias for
                self-attention and no absolute positional encoder.
            gradient_checkpointing: checkpoint each layer in ``forward`` while training.
        """
        super().__init__()

        # Plain configuration values kept on the instance.
        self.gradient_checkpointing = gradient_checkpointing
        self.use_relative_position_bias = use_relative_position_bias
        self.num_heads = num_heads
        self.pad_token_id = pad_token_id
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Note: T5 does NOT scale logits (scaling factor removed)
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)

        # Positional information comes either from a relative bias in attention
        # (T5-style) or from an absolute positional encoder, never both.
        self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
        # T5 cross-attention does NOT use position bias, so this one stays None.
        self.cross_relative_position_bias: Optional[T5RelativePositionBias] = None
        if use_relative_position_bias:
            self.pos_encoder = None
            # The decoder is causal, hence is_decoder=True for the self-attention bias.
            self.self_relative_position_bias = T5RelativePositionBias(
                num_heads=num_heads,
                num_buckets=32,
                max_distance=128,
                is_decoder=True,
            )
        elif use_learned_pos_enc:
            # The learned table is built with two positions more than max_len.
            self.pos_encoder = LearnedPositionalEncoding(
                d_model=d_model, max_len=max_len + 2, dropout=dropout
            )
        else:
            self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)

        # T5 does NOT scale attention scores by sqrt(d_k), others do
        scale_attn_scores = not use_relative_position_bias

        def _new_layer() -> TransformerDecoderLayer:
            return TransformerDecoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
                quantization=quantization,
                activation=activation,
                scale_attn_scores=scale_attn_scores,
            )

        self.layers = nn.ModuleList([_new_layer() for _ in range(num_layers)])

        self.final_norm = T5LayerNorm(d_model)
        # T5 has no bias on the vocabulary projection
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        self.input_dropout = nn.Dropout(dropout)

    def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Convert input ids to (B, T, T) boolean mask where True = allowed.

        Note: For T5, pad_token_id=0 is also used as decoder_start_token_id.
        During generation, we should NOT mask the start token. The caller should
        provide an explicit mask or set tgt_mask to avoid this issue.
        """
        assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
        pad_mask = input_ids != self.pad_token_id  # (B, T)

        # Always allow attending to the first token (BOS), even if it is pad_token_id
        # Avoid in-place mutation for better torch.compile compatibility
        if pad_mask.size(1) > 0:
            # Create a mask for the first column (B, 1)
            first_col_mask = torch.zeros_like(pad_mask[:, :1], dtype=torch.bool)
            first_col_mask[:] = True
            # Combine: pad_mask OR (column==0)
            # We can do this by creating a column index tensor
            col_indices = torch.arange(pad_mask.size(1), device=pad_mask.device).unsqueeze(0)
            pad_mask = pad_mask | (col_indices == 0)

        attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (B, T, T)
        return attn_mask

    def forward(
        self,
        inputs: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        collect_attn: bool = False,
        skip_padding_mask: bool = False,  # Set True during generation to avoid masking start token
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, Optional[torch.Tensor]]]]]:
        """
        Args:
            inputs: (B, T) token ids or (B, T, d_model) embeddings
            memory: (B, S, d_model)
            tgt_mask: optional; if None, will create (causal [+ padding if ids available])
            memory_mask: optional; if provided as (B, S) will be expanded to (B, 1, 1, S)
            skip_padding_mask: if True, only use causal mask (for generation where start_token=pad_token)
        """
        # Prepare embeddings
        if inputs.dim() == 2:  # token ids
            # T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
            x = self.embedding(inputs)
        elif inputs.dim() == 3:
            x = inputs
        else:
            raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")

        # Apply positional encoding if not using relative position bias
        # (T5 uses relative position bias in attention instead of absolute positional embeddings)
        if self.pos_encoder is not None:
            x = self.pos_encoder(x)
        x = self.input_dropout(x)

        B, T, _ = x.shape

        # Build target mask if not provided: combine causal + padding (if available)
        if tgt_mask is None:
            causal = create_causal_mask(T, device=x.device)  # (T, T)
            if inputs.dim() == 2 and self.pad_token_id is not None and not skip_padding_mask:
                # During training: combine causal mask with padding mask
                pad_pairwise = self._build_padding_mask_from_ids(inputs)  # (B, T, T)
                combined = pad_pairwise & causal.unsqueeze(0)  # (B, T, T)
                tgt_mask = combined.unsqueeze(1)  # (B, 1, T, T) -> broadcast to heads
            else:
                # During generation (skip_padding_mask=True) or no padding info:
                # Use only causal mask - don't mask based on token values
                tgt_mask = causal.unsqueeze(0).unsqueeze(1)  # (1, 1, T, T)
        else:
            # Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
            tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
            # If tgt_mask is just causal (T, T), expand it
            if tgt_mask.dim() == 2:
                tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
            elif tgt_mask.dim() == 3:
                tgt_mask = tgt_mask.unsqueeze(1)


        # Normalize memory_mask dtype/device and expand simple shapes
        if memory_mask is not None:
            memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
            if memory_mask.dim() == 2:  # (B, S) -> (B, 1, 1, S)
                memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
            elif memory_mask.dim() == 3:  # (B, T, S) -> (B, 1, T, S)
                memory_mask = memory_mask.unsqueeze(1)

        attn_list: List[Dict[str, Optional[torch.Tensor]]] = []

        # Compute relative position biases (T5-style)
        # Note: T5 uses relative position bias for self-attention but NOT for cross-attention
        if self.use_relative_position_bias and self.self_relative_position_bias is not None:
            self_position_bias = self.self_relative_position_bias(
                T, T, x.device
            )  # (1, num_heads, T, T)
        else:
            self_position_bias = None
        # Cross-attention position bias is None for T5 (see T5 paper/implementation)
        cross_position_bias = None

        # Pass through decoder layers
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Gradient checkpointing requires the inputs to require grad
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, tgt_mask=tgt_mask, memory_mask=memory_mask, collect_attn=collect_attn, self_attn_position_bias=self_position_bias, cross_attn_position_bias=cross_position_bias)
                    return custom_forward

                x, attn = cast(
                    Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]],
                    checkpoint(
                        create_custom_forward(layer),
                        x,
                        memory,
                        use_reentrant=False,
                    ),
                )
            else:
                x, attn = layer(
                    x,
                    memory,
                    tgt_mask=tgt_mask,
                    memory_mask=memory_mask,
                    collect_attn=collect_attn,
                    self_attn_position_bias=self_position_bias,
                    cross_attn_position_bias=cross_position_bias,
                )
            if collect_attn:
                attn_list.append(attn)

        x = self.final_norm(x)
        # T5 does NOT scale logits - direct projection to vocabulary
        logits = self.output_projection(x)  # (B, T, vocab)

        if collect_attn:
            return logits, attn_list
        return logits

    def greedy_decode_naive(
        self,
        memory: torch.Tensor,
        max_len: int,
        start_token_id: int,
        end_token_id: Optional[int] = None,
        device: Optional[torch.device] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Naive greedy decoding using full forward pass (O(N^2) but simpler).
        Used for debugging to verify step() correctness.
        """
        if device is None:
            device = memory.device
        B = memory.size(0)

        # Initialize with start token
        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            # Full forward pass on entire generated sequence
            # skip_padding_mask=True because start_token=pad_token for T5
            logits = self.forward(
                generated, memory, memory_mask=memory_mask, skip_padding_mask=True
            )
            if isinstance(logits, tuple):
                logits = logits[0]
            # logits: (B, T, vocab)

            # Get logits for last position
            next_logits = logits[:, -1, :]  # (B, vocab)

            # Greedy: pick highest probability token
            next_token = next_logits.argmax(dim=-1, keepdim=True)  # (B, 1)

            # Append to generated
            generated = torch.cat([generated, next_token], dim=1)

            # Check for EOS
            if end_token_id is not None and (next_token == end_token_id).all():
                break

        return generated

    def greedy_decode(
        self,
        memory: torch.Tensor,
        max_len: int,
        start_token_id: int,
        end_token_id: Optional[int] = None,
        device: Optional[torch.device] = None,
        *,
        min_len: Optional[int] = None,
        ban_token_ids: Optional[List[int]] = None,
        no_repeat_ngram_size: int = 0,
        repetition_penalty: float = 1.0,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Greedy decoding with KV caching for O(N) complexity.
        """
        if device is None:
            device = memory.device
        B = memory.size(0)

        # Initialize generated sequence with start token
        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)

        # Initialize cache
        cache: Dict[str, Any] = {"past_length": 0}
        if memory_mask is not None:
            cache["memory_mask"] = memory_mask

        min_len = 0 if min_len is None else max(0, min_len)

        # Keep track of finished sequences
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            # Use the last generated token for the next step
            last_token = generated[:, -1:]  # (B, 1)

            # Run one step of the decoder
            logits, cache = self.step(last_token, memory, cache)
            # logits: (B, vocab_size)

            next_step_logits = logits.clone()

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for b in range(B):
                    if finished[b]:
                        continue
                    gen_seq = generated[b]
                    unique_tokens = torch.unique(gen_seq)
                    current_logits = next_step_logits[b, unique_tokens]
                    next_step_logits[b, unique_tokens] = torch.where(
                        current_logits < 0,
                        current_logits * repetition_penalty,
                        current_logits / repetition_penalty,
                    )

            # Apply constraints
            if end_token_id is not None and generated.size(1) < max(1, min_len):
                next_step_logits[:, end_token_id] = float("-inf")

            if ban_token_ids:
                next_step_logits[:, ban_token_ids] = float("-inf")

            # N-gram repetition blocking
            if no_repeat_ngram_size > 0:
                for b in range(B):
                    if finished[b]:
                        continue
                    gen_seq = generated[b].tolist()
                    if len(gen_seq) < no_repeat_ngram_size - 1:
                        continue

                    prefix = tuple(gen_seq[-(no_repeat_ngram_size - 1) :])
                    banned_for_this_batch = set()

                    for i in range(len(gen_seq) - no_repeat_ngram_size + 1):
                        window = tuple(gen_seq[i : i + no_repeat_ngram_size - 1])
                        if window == prefix:
                            if i + no_repeat_ngram_size - 1 < len(gen_seq):
                                banned_for_this_batch.add(gen_seq[i + no_repeat_ngram_size - 1])

                    if banned_for_this_batch:
                        next_step_logits[b, list(banned_for_this_batch)] = float("-inf")

            # Greedy selection
            next_token = next_step_logits.argmax(dim=-1, keepdim=True)  # (B, 1)

            # Update generated sequence
            generated = torch.cat([generated, next_token], dim=1)

            # Check for completion
            if end_token_id is not None:
                is_end = next_token.squeeze(-1) == end_token_id
                finished = finished | is_end
                if finished.all() and generated.size(1) >= max(1, min_len):
                    break

        return generated

    # -----------------------------
    # Incremental single-step API
    # -----------------------------
    def step(
        self,
        last_token_ids: torch.Tensor,
        memory: torch.Tensor,
        cache: Optional[Dict] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Run one autoregressive step.

        Args:
            last_token_ids: (B, 1) last generated token ids
            memory: encoder outputs (B, S, d_model)
            cache: optional dict with previous cached keys/values and 'past_length'.

        Returns:
            logits: (B, vocab_size) logits for the next-token prediction
            new_cache: updated cache dictionary
        """
        device = memory.device
        B = last_token_ids.size(0)

        if cache is None:
            cache = {}
        past_len = int(cache.get("past_length", 0))

        # 1) Embed last token and add positional encoding for position `past_len`
        # T5/FLAN-T5 does NOT scale embeddings by sqrt(d_model)
        x = self.embedding(last_token_ids)  # (B,1,d)

        # Handle positional encoding for single step
        # Note: When using relative position bias (T5-style), pos_encoder is None
        if self.pos_encoder is not None:
            if hasattr(self.pos_encoder, "pe"):
                # Sinusoidal: use buffer directly
                pe: torch.Tensor = self.pos_encoder.pe  # type: ignore[union-attr]
                pos_idx = past_len
                if pos_idx >= pe.size(1):
                    raise RuntimeError(f"pos_idx {pos_idx} exceeds max_len {pe.size(1)}")
                x = x + pe[:, pos_idx : pos_idx + 1, :].to(device)
            elif hasattr(self.pos_encoder, "embeddings"):
                # Learned: lookup specific position
                # Create position ids: [past_len]
                pos_idx_t = torch.tensor([past_len], dtype=torch.long, device=device)
                # Lookup embedding: (1, d_model)
                pos_emb = self.pos_encoder.embeddings(pos_idx_t)  # type: ignore[union-attr]
                # Add to input: (B, 1, d_model) + (1, 1, d_model) broadcast
                x = x + pos_emb.unsqueeze(0)
                x = self.pos_encoder.dropout(x)  # type: ignore[union-attr]
            else:
                # fallback: call pos_encoder (likely incorrect for step-by-step if it assumes pos 0)
                x = self.pos_encoder(x)
        # When pos_encoder is None (relative position bias mode), we skip positional encoding
        # The position information is provided via relative_position_bias in attention

        # We will update new_cache incrementally
        new_cache = dict(cache)  # shallow copy
        new_cache["past_length"] = past_len + 1

        # Optional: memory_mask could be supplied in cache under 'memory_mask'
        memory_mask = new_cache.get("memory_mask", None)
        if memory_mask is not None:
            memory_mask = memory_mask.to(dtype=torch.bool, device=device)
            # expand (B, S) -> (B,1,1,S) if necessary
            if memory_mask.dim() == 2:
                memory_mask = memory_mask.unsqueeze(1).unsqueeze(1)
            elif memory_mask.dim() == 3:
                memory_mask = memory_mask.unsqueeze(1)

        # Compute position biases for incremental step (T5-style)
        # For step mode: query_length=1, but actual position is past_len
        # Self-attention: query at position past_len attends to keys at positions 0..past_len
        # Note: T5 uses relative position bias for self-attention but NOT for cross-attention
        if self.use_relative_position_bias and self.self_relative_position_bias is not None:
            # Self-attention bias: query_length=1, key_length=past_len+1, offset=past_len
            self_position_bias = self.self_relative_position_bias(
                query_length=1,
                key_length=past_len + 1,
                device=device,
                query_position_offset=past_len,
            )  # (1, num_heads, 1, past_len+1)
        else:
            self_position_bias = None
        # Cross-attention position bias is None for T5 (see T5 paper/implementation)
        cross_position_bias = None

        # Iterate layers, updating caches and computing output for current token only
        layer_input = x  # (B,1,d_model)
        for i, layer_module in enumerate(self.layers):
            layer = cast(TransformerDecoderLayer, layer_module)
            # -------------------
            # 1) Self-attention (incremental)
            # -------------------
            # Normalize input for pre-LN
            x_norm = layer.norm1(layer_input)  # (B,1,d)

            # Project Q,K,V for the new token
            Q_new = layer.self_attn.W_Q(x_norm)  # (B,1,d_model)
            K_new = layer.self_attn.W_K(x_norm)
            V_new = layer.self_attn.W_V(x_norm)

            # Reshape into heads: (B, num_heads, 1, d_k)
            B_, Lq, _ = Q_new.shape
            num_heads = layer.self_attn.num_heads
            d_k = layer.self_attn.d_k
            Qh = Q_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)  # (B, num_heads, 1, d_k)
            Kh = K_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)
            Vh = V_new.view(B_, Lq, num_heads, d_k).transpose(1, 2)

            # Retrieve cached keys/values for self-attn (if exist)
            cache_k = cache.get(f"self_k_{i}", None)
            cache_v = cache.get(f"self_v_{i}", None)
            if cache_k is None or cache_v is None:
                K_all = Kh  # (B, H, 1, d_k)
                V_all = Vh
            else:
                # concat along sequence dim (dim=2)
                K_all = torch.cat([cache_k.to(device), Kh], dim=2)
                V_all = torch.cat([cache_v.to(device), Vh], dim=2)

            # Store updated caches
            new_cache[f"self_k_{i}"] = K_all
            new_cache[f"self_v_{i}"] = V_all

            # Compute attention for the new token: Query length = 1, Key length = K_all.size(2)
            # Explicitly create mask for consistency with forward pass (though None should work)
            # mask=True means attend.
            step_mask = torch.ones(B_, 1, 1, K_all.size(2), dtype=torch.bool, device=device)
            attn_out_heads, self_attn_w = layer.self_attn.attention(
                Qh, K_all, V_all, mask=step_mask, position_bias=self_position_bias
            )
            # attn_out_heads: (B, H, 1, d_k)
            # concat heads, project out
            attn_out = attn_out_heads.transpose(1, 2).contiguous().view(B_, 1, num_heads * d_k)
            attn_out = layer.self_attn.W_O(attn_out)  # (B,1,d_model)
            attn_out = layer.self_attn.dropout(attn_out)
            layer_output = layer_input + layer.dropout1(attn_out)

            # -------------------
            # 2) Cross-attention (use cached memory projections if available)
            # -------------------
            x_norm2 = layer.norm2(layer_output)  # (B,1,d)
            # Ensure memory K/V are cached per layer
            mem_k = cache.get(f"mem_k_{i}", None)
            mem_v = cache.get(f"mem_v_{i}", None)
            if mem_k is None or mem_v is None:
                # project memory once for this layer and cache it
                # memory: (B, S, d_model)
                MK = layer.cross_attn.W_K(memory)  # (B, S, d_model)
                MV = layer.cross_attn.W_V(memory)
                Bm, S, _ = MK.shape
                MKh = MK.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(
                    1, 2
                )  # (B,H,S,d_k)
                MVh = MV.view(Bm, S, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(
                    1, 2
                )
                mem_k = MKh
                mem_v = MVh
                new_cache[f"mem_k_{i}"] = mem_k
                new_cache[f"mem_v_{i}"] = mem_v
            else:
                mem_k = mem_k.to(device)
                mem_v = mem_v.to(device)

            Qc = layer.cross_attn.W_Q(x_norm2)  # (B,1,d_model)
            Qch = Qc.view(B, 1, layer.cross_attn.num_heads, layer.cross_attn.d_k).transpose(
                1, 2
            )  # (B,H,1,d_k)

            cross_out_heads, cross_attn_w = layer.cross_attn.attention(
                Qch, mem_k, mem_v, mask=memory_mask, position_bias=cross_position_bias
            )
            cross_out = (
                cross_out_heads.transpose(1, 2)
                .contiguous()
                .view(B, 1, layer.cross_attn.num_heads * layer.cross_attn.d_k)
            )
            cross_out = layer.cross_attn.W_O(cross_out)  # (B,1,d_model)
            cross_out = layer.cross_attn.dropout(cross_out)
            layer_output = layer_output + layer.dropout2(cross_out)

            # -------------------
            # 3) Feed-forward (incremental)
            # -------------------
            x_norm3 = layer.norm3(layer_output)
            ffn_out = layer.ffn(x_norm3)  # (B,1,d_model)
            layer_output = layer_output + layer.dropout3(ffn_out)

            # Prepare for next layer
            layer_input = layer_output

        # Final norm + output projection (for this single time step)
        out_norm = self.final_norm(layer_input)  # (B,1,d_model)
        logits = self.output_projection(out_norm)  # (B,1,vocab)
        logits = logits.squeeze(1)  # (B, vocab)

        return logits, new_cache