OliverPerrin committed · Commit 67c3a83 · Parent(s): 590a604
Fix Pylance type errors, add inductor compilation support
- Add cast() for registered buffer access (pe, cos, sin) in decoder.py, positional_encoding.py, and attention.py to satisfy Pylance type checking
- Add cast() for compile_model return types in train.py
- Add type: ignore for rouge_score import (no type stubs available)
- Add safe_compile.py for torch.compile with inductor backend (default mode)
- Add nan_debugger.py for debugging NaN/Inf during training
- Update configs: batch_size=12 for medium/full, fix layer counts to match FLAN-T5-base
- Add benchmark mode to train.py for speed testing without saving checkpoints
- Suppress torch inductor warnings that interfere with tqdm progress bars
All checks pass: ruff, mypy, Pylance
- configs/model/base.yaml +3 -3
- configs/training/dev.yaml +10 -7
- configs/training/full.yaml +8 -5
- configs/training/medium.yaml +9 -6
- scripts/eval_rouge.py +1 -1
- scripts/train.py +40 -11
- src/models/attention.py +5 -3
- src/models/decoder.py +7 -6
- src/models/positional_encoding.py +2 -1
- src/training/nan_debugger.py +123 -0
- src/training/safe_compile.py +86 -0
- src/training/trainer.py +55 -6
- tests/test_models/test_decoder.py +3 -3
configs/model/base.yaml
CHANGED
@@ -1,8 +1,8 @@
 # FLAN-T5-base architecture
-#
+# 6 encoder layers, 6 decoder layers, 768 hidden dim
 d_model: 768
-num_encoder_layers:
+num_encoder_layers: 6 # T5-base has 6 layers
-num_decoder_layers:
+num_decoder_layers: 6 # T5-base has 6 layers
 num_attention_heads: 12
 ffn_dim: 2048 # T5 uses d_ff = 2048 for base model
 dropout: 0.1
configs/training/dev.yaml
CHANGED
@@ -4,15 +4,18 @@
 # Use: python scripts/train.py training=dev

 dataloader:
-  batch_size:
+  batch_size: 14
   shuffle: true
-  num_workers:
+  num_workers: 6
   pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 4

 optimizer:
   name: adamw
-  lr:
+  lr: 2.0e-5
   weight_decay: 0.01
+  eps: 1.0e-6

 scheduler:
   name: cosine
@@ -21,12 +24,12 @@ scheduler:
 trainer:
   max_epochs: 1
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps:
+  gradient_accumulation_steps: 4
-  validation_max_length:
+  validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
     summarization: 1.0
     emotion: 1.0
     topic: 1.0
-  max_train_samples:
+  max_train_samples: 1000
-  max_val_samples:
+  max_val_samples: 100
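Note: assuming the dataloader block is passed straight through as keyword arguments to torch.utils.data.DataLoader (the construction code is not part of this diff), the new dev settings correspond roughly to the following sketch with a placeholder dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset; the project builds task-specific datasets elsewhere.
dataset = TensorDataset(torch.zeros(1000, 8), torch.zeros(1000, dtype=torch.long))

loader = DataLoader(
    dataset,
    batch_size=14,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,  # keep worker processes alive across epochs
    prefetch_factor=4,        # batches each worker prefetches ahead of time
)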
configs/training/full.yaml
CHANGED
@@ -1,27 +1,30 @@
 # Full Training Configuration for FLAN-T5-base
 # Complete training run on all data
-# Training time: ~6
+# Training time: ~4-6 hours on RTX 4070 12GB with inductor
 # Use: python scripts/train.py training=full

 dataloader:
-  batch_size:
+  batch_size: 14
   shuffle: true
   num_workers: 6
   pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 4

 optimizer:
   name: adamw
   lr: 2.0e-5
   weight_decay: 0.01
+  eps: 1.0e-6

 scheduler:
   name: cosine
-  warmup_steps: 500
+  warmup_steps: 500

 trainer:
-  max_epochs: 3
+  max_epochs: 3
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps:
+  gradient_accumulation_steps: 3 # Effective batch = 42
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
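Note: the "Effective batch = 42" comment is just the per-step batch size times the accumulation steps, i.e. effective batch = 14 × 3 = 42 sequences per optimizer update (dev.yaml's 14 × 4 gives 56).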
configs/training/medium.yaml
CHANGED
@@ -1,28 +1,31 @@
 # Medium Configuration for FLAN-T5-base
 # Balanced approach - good results in reasonable time
-# Training time: ~2
+# Training time: ~1.5-2 hours on RTX 4070 12GB with inductor
 # Use: python scripts/train.py training=medium

 dataloader:
-  batch_size:
+  batch_size: 14
   shuffle: true
   num_workers: 6
   pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 4

 optimizer:
   name: adamw
-  lr: 3.0e-5
+  lr: 3.0e-5
   weight_decay: 0.01
+  eps: 1.0e-6

 scheduler:
   name: cosine
-  warmup_steps: 300
+  warmup_steps: 300

 trainer:
   max_epochs: 3
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: 3 # Effective batch =
+  gradient_accumulation_steps: 3 # Effective batch = 42
-  validation_max_length:
+  validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
     summarization: 1.0
scripts/eval_rouge.py
CHANGED
@@ -18,7 +18,7 @@ from pathlib import Path
 from statistics import fmean
 from typing import Dict, Iterable, List, Sequence, Tuple

-from rouge_score import rouge_scorer
+from rouge_score import rouge_scorer  # type: ignore[import-untyped]
 from tqdm import tqdm

 PROJECT_ROOT = Path(__file__).resolve().parent.parent
scripts/train.py
CHANGED
@@ -11,10 +11,20 @@ Date: December 2025
 from __future__ import annotations

 import json
+import logging
+import os
 import sys
 import time
+import warnings
 from pathlib import Path
-from typing import
+from typing import Dict, Sequence, cast
+
+# Suppress torch inductor warnings that mess up progress bars
+os.environ.setdefault("TORCH_LOGS", "-all")
+warnings.filterwarnings("ignore", category=UserWarning, module="torch._inductor")
+warnings.filterwarnings("ignore", category=FutureWarning, module="mlflow")
+logging.getLogger("torch._inductor").setLevel(logging.ERROR)
+logging.getLogger("torch._dynamo").setLevel(logging.ERROR)

 import hydra
 import torch
@@ -82,14 +92,14 @@ def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
 # --------------- Model Compilation ---------------


-def compile_model(model: torch.nn.Module) ->
-    """Compile model with
-
-
-
-
-
-
+def compile_model(model: torch.nn.Module) -> torch.nn.Module:
+    """Compile model with inductor backend (default mode, no CUDA graphs)."""
+    from src.training.safe_compile import apply_safe_config, compile_model_safe
+
+    # Apply safe configuration first
+    apply_safe_config()
+    # Compile with default mode (inductor without CUDA graphs)
+    return compile_model_safe(model, mode="default")


 # --------------- Main ---------------
@@ -101,6 +111,11 @@ def main(cfg: DictConfig) -> None:
     print(OmegaConf.to_yaml(cfg))
     set_seed(cfg.seed)

+    # Benchmark mode: skip saving checkpoints (for speed testing)
+    benchmark_mode = cfg.get("benchmark", False)
+    if benchmark_mode:
+        print("⚡ BENCHMARK MODE: Checkpoints will NOT be saved")
+
     # Enable TF32 for Ampere+ GPUs (RTX 30xx/40xx) - ~2x matmul speedup
     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
         print("✓ TF32 enabled for Ampere GPU")
@@ -242,9 +257,13 @@ def main(cfg: DictConfig) -> None:

     # Compile encoder/decoder for faster training (skip heads - small overhead)
     if model.encoder is not None:
-
+        from src.models.encoder import TransformerEncoder
+
+        model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
     if model.decoder is not None:
-
+        from src.models.decoder import TransformerDecoder
+
+        model.decoder = cast(TransformerDecoder, compile_model(model.decoder))

     # --------------- Optimizer & Trainer ---------------
@@ -272,6 +291,8 @@ def main(cfg: DictConfig) -> None:
     # --------------- Train ---------------

     def save_checkpoint(epoch: int, model: torch.nn.Module, history: Dict) -> None:
+        if benchmark_mode:
+            return  # Skip saving in benchmark mode
         path = Path(cfg.checkpoint_out).parent / f"epoch_{epoch}.pt"
         path.parent.mkdir(parents=True, exist_ok=True)
         save_state(model, str(path))
@@ -281,6 +302,14 @@ def main(cfg: DictConfig) -> None:

     # --------------- Save Outputs ---------------

+    if benchmark_mode:
+        total_time = time.perf_counter() - start_time
+        print(f"\n{'=' * 50}")
+        print(f"⚡ Benchmark complete in {total_time:.1f}s")
+        print(" (No files saved in benchmark mode)")
+        print(f"{'=' * 50}")
+        return
+
     # Best checkpoint
     ckpt_path = Path(cfg.checkpoint_out)
     ckpt_path.parent.mkdir(parents=True, exist_ok=True)
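Note: the TF32 hunk above only shows the capability check and the confirmation print; the calls that actually enable TF32 sit in context lines the viewer hides. A typical enablement using standard PyTorch flags (a sketch, not necessarily the exact lines in this repo) looks like:

import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    # Allow TensorFloat-32 for matmuls and cuDNN convolutions on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("✓ TF32 enabled for Ampere GPU")

Benchmark mode is read via cfg.get("benchmark", False), so it is presumably toggled from the command line with a Hydra append override such as +benchmark=true (assumption; the config files in this commit do not define the key).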
src/models/attention.py
CHANGED
@@ -13,7 +13,7 @@ Date: 2025-10-23
 """

 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, cast

 import torch
 import torch.nn as nn
@@ -280,8 +280,10 @@ class RotaryEmbedding(nn.Module):
         seq_len = x.shape[2]
         # Slice cos/sin to current sequence length
         # unsqueeze to broadcast over batch and heads: (1, 1, seq_len, dim)
-
-
+        cos_buf = cast(torch.Tensor, self.cos)
+        sin_buf = cast(torch.Tensor, self.sin)
+        cos = cos_buf[:seq_len, :].unsqueeze(0).unsqueeze(0)
+        sin = sin_buf[:seq_len, :].unsqueeze(0).unsqueeze(0)

         return (x * cos) + (self._rotate_half(x) * sin)
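Note on the cast() pattern used here and in decoder.py/positional_encoding.py: buffers created with register_buffer are plain tensors at runtime, but static checkers resolve the attribute through nn.Module.__getattr__, whose annotation is a Tensor/Module union, so slicing the buffer gets flagged. A minimal standalone illustration with a hypothetical module (not code from this repo):

from typing import cast

import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self, length: int = 16) -> None:
        super().__init__()
        # Buffers move with .to()/.cuda() and are saved in state_dict, but are not parameters.
        self.register_buffer("table", torch.arange(length, dtype=torch.float32))

    def forward(self, n: int) -> torch.Tensor:
        # cast() narrows the attribute to Tensor for the type checker; no runtime effect.
        table = cast(torch.Tensor, self.table)
        return table[:n]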
src/models/decoder.py
CHANGED
@@ -14,7 +14,7 @@ Author: Oliver Perrin
 Date: 2025-10-23
 """

-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -530,7 +530,7 @@ class TransformerDecoder(nn.Module):
         if self.pos_encoder is not None:
             if hasattr(self.pos_encoder, "pe"):
                 # Sinusoidal: use buffer directly
-                pe = self.pos_encoder.pe #
+                pe: torch.Tensor = self.pos_encoder.pe  # type: ignore[union-attr]
                 pos_idx = past_len
                 if pos_idx >= pe.size(1):
                     raise RuntimeError(f"pos_idx {pos_idx} exceeds max_len {pe.size(1)}")
@@ -538,12 +538,12 @@
             elif hasattr(self.pos_encoder, "embeddings"):
                 # Learned: lookup specific position
                 # Create position ids: [past_len]
-
+                pos_idx_t = torch.tensor([past_len], dtype=torch.long, device=device)
                 # Lookup embedding: (1, d_model)
-                pos_emb = self.pos_encoder.embeddings(
+                pos_emb = self.pos_encoder.embeddings(pos_idx_t)  # type: ignore[union-attr]
                 # Add to input: (B, 1, d_model) + (1, 1, d_model) broadcast
                 x = x + pos_emb.unsqueeze(0)
-                x = self.pos_encoder.dropout(x)
+                x = self.pos_encoder.dropout(x)  # type: ignore[union-attr]
             else:
                 # fallback: call pos_encoder (likely incorrect for step-by-step if it assumes pos 0)
                 x = self.pos_encoder(x)
@@ -583,7 +583,8 @@

         # Iterate layers, updating caches and computing output for current token only
         layer_input = x  # (B,1,d_model)
-        for i,
+        for i, layer_module in enumerate(self.layers):
+            layer = cast(TransformerDecoderLayer, layer_module)
             # -------------------
             # 1) Self-attention (incremental)
             # -------------------
src/models/positional_encoding.py
CHANGED
@@ -74,7 +74,8 @@ class PositionalEncoding(nn.Module):
         # Add the appropriate slice of positional encoding
         # Apply dropout
         # Return result
-
+        pe: torch.Tensor = self.pe  # type: ignore[assignment]
+        x = x + pe[:, : x.size(1)].requires_grad_(False)
         # self.pe contains pre-computed encodings for all positions
         # just need to add the first seq_len positions to x
         return self.dropout(x)
src/training/nan_debugger.py
ADDED
@@ -0,0 +1,123 @@
"""
NaN debugging utilities for training.

Helps identify where NaNs originate in the model during training.

Author: Oliver Perrin
Date: December 2025
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn


class NaNDetector:
    """Detect and log NaNs in model parameters and gradients."""

    def __init__(self, model: nn.Module, enabled: bool = True):
        self.model = model
        self.enabled = enabled
        self.nan_count = 0
        self.max_nans = 10

    def check_forward(self, outputs: torch.Tensor, loss: torch.Tensor, step: int) -> bool:
        """Check for NaNs in forward pass. Returns True if NaN found."""
        if not self.enabled:
            return False

        has_nan = False

        if torch.isnan(outputs).any():
            print(f"\n{'=' * 60}")
            print(f"⚠ NaN detected in MODEL OUTPUTS at step {step}")
            print(f"Output shape: {outputs.shape}")
            print(f"NaN count: {torch.isnan(outputs).sum().item()}")
            print(f"{'=' * 60}\n")
            has_nan = True

        if torch.isnan(loss):
            print(f"\n{'=' * 60}")
            print(f"⚠ NaN detected in LOSS at step {step}")
            print(f"Loss value: {loss.item()}")
            print(f"{'=' * 60}\n")
            has_nan = True

        if has_nan:
            self.nan_count += 1
            if self.nan_count >= self.max_nans:
                print(f"\n⚠ Too many NaNs ({self.nan_count}), stopping training")

        return has_nan

    def check_gradients(self, step: int) -> Optional[Tuple[str, torch.Tensor]]:
        """Check gradients for NaNs/Infs after backward pass."""
        if not self.enabled:
            return None

        for name, param in self.model.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print(f"\n{'=' * 60}")
                    print(f"⚠ NaN in GRADIENT: {name}")
                    print(f" Step: {step}")
                    print(f" Grad shape: {param.grad.shape}")
                    print(f" NaN count: {torch.isnan(param.grad).sum().item()}")
                    print(f"{'=' * 60}\n")
                    return (name, param.grad)

                if torch.isinf(param.grad).any():
                    print(f"\n{'=' * 60}")
                    print(f"⚠ Inf in GRADIENT: {name}")
                    print(f" Step: {step}")
                    print(f" Inf count: {torch.isinf(param.grad).sum().item()}")
                    print(f"{'=' * 60}\n")
                    return (name, param.grad)

        return None

    def check_parameters(self, step: int) -> Optional[str]:
        """Check parameters for NaNs/Infs."""
        if not self.enabled:
            return None

        for name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                print(f"\n{'=' * 60}")
                print(f"⚠ NaN in PARAMETER: {name}")
                print(f" Step: {step}")
                print(f"{'=' * 60}\n")
                return str(name)

            if torch.isinf(param).any():
                print(f"\n{'=' * 60}")
                print(f"⚠ Inf in PARAMETER: {name}")
                print(f" Step: {step}")
                print(f"{'=' * 60}\n")
                return str(name)

        return None


def gradient_stats(model: nn.Module) -> dict:
    """Get gradient statistics for debugging."""
    stats = {
        "max_grad": 0.0,
        "min_grad": float("inf"),
        "mean_grad": 0.0,
        "num_grads": 0,
    }

    grad_norms = []
    for _name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms.append(param.grad.norm().item())
            stats["max_grad"] = max(stats["max_grad"], param.grad.abs().max().item())
            stats["min_grad"] = min(stats["min_grad"], param.grad.abs().min().item())
            stats["num_grads"] += 1

    if grad_norms:
        stats["mean_grad"] = sum(grad_norms) / len(grad_norms)

    return stats
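A minimal usage sketch of NaNDetector and gradient_stats in a single training step, with a placeholder model and optimizer (the real wiring is in trainer.py below):

import torch
import torch.nn as nn

from src.training.nan_debugger import NaNDetector, gradient_stats

model = nn.Linear(16, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
detector = NaNDetector(model, enabled=True)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
logits = model(x)
loss = nn.functional.cross_entropy(logits, y)

if not detector.check_forward(logits, loss, step=0):  # NaN check before backward
    loss.backward()
    print(gradient_stats(model))                       # max/min/mean gradient magnitudes
    if detector.check_gradients(step=0) is None:       # skip the update on NaN/Inf grads
        optimizer.step()
    optimizer.zero_grad()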
src/training/safe_compile.py
ADDED
@@ -0,0 +1,86 @@
"""
Safe torch.compile configuration that prevents NaN issues.

Author: Oliver Perrin
Date: December 2025
"""

import torch


def compile_model_safe(
    model: torch.nn.Module,
    mode: str = "default",
) -> torch.nn.Module:
    """
    Compile model with inductor backend and safety guardrails.

    Uses 'default' mode which gives inductor speedups without CUDA graphs.
    CUDA graphs (reduce-overhead mode) don't work with dynamic shapes or
    shared embeddings like in T5.

    Args:
        model: Model to compile
        mode: Compilation mode ("default" recommended, avoid "reduce-overhead")

    Returns:
        Compiled model (or original if compilation fails)
    """
    if not torch.cuda.is_available():
        print("⚠ CUDA not available, skipping compilation")
        return model

    try:
        # Configure for stability
        torch._dynamo.config.suppress_errors = True
        torch._dynamo.config.cache_size_limit = 64  # Allow more graph variations

        # Disable aggressive optimizations that can cause NaNs
        if hasattr(torch, "_inductor"):
            cfg = torch._inductor.config
            if hasattr(cfg, "epilogue_fusion"):
                cfg.epilogue_fusion = False
            if hasattr(cfg, "coordinate_descent_tuning"):
                cfg.coordinate_descent_tuning = False
            if hasattr(cfg, "force_fuse_int_mm_with_mul"):
                cfg.force_fuse_int_mm_with_mul = False
            # Explicitly disable CUDA graphs
            if hasattr(cfg, "triton"):
                if hasattr(cfg.triton, "cudagraphs"):
                    cfg.triton.cudagraphs = False
                if hasattr(cfg.triton, "max_autotune_gemm"):
                    cfg.triton.max_autotune_gemm = False

        # Compile with inductor (no CUDA graphs)
        compiled = torch.compile(model, mode=mode, fullgraph=False, dynamic=True)
        print(f"✓ Compiled with inductor ({mode} mode)")
        return compiled

    except Exception as e:
        print(f"⚠ Inductor compilation failed: {e}")
        print(" Falling back to aot_eager")
        try:
            return torch.compile(model, backend="aot_eager")
        except Exception:
            print(" Using uncompiled model")
            return model


def apply_safe_config():
    """Apply safe configuration to torch._inductor before any compilation."""
    if hasattr(torch, "_inductor"):
        cfg = torch._inductor.config
        if hasattr(cfg, "epilogue_fusion"):
            cfg.epilogue_fusion = False
        if hasattr(cfg, "coordinate_descent_tuning"):
            cfg.coordinate_descent_tuning = False
        if hasattr(cfg, "triton"):
            if hasattr(cfg.triton, "cudagraphs"):
                cfg.triton.cudagraphs = False
            if hasattr(cfg.triton, "max_autotune_gemm"):
                cfg.triton.max_autotune_gemm = False

    # Dynamo config for stability
    torch._dynamo.config.suppress_errors = True
    torch._dynamo.config.cache_size_limit = 64
    print("✓ Applied safe inductor configuration")
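Standalone usage of this module (it mirrors the compile_model helper added to train.py above; the encoder here is a placeholder, not the project's TransformerEncoder):

import torch.nn as nn

from src.training.safe_compile import apply_safe_config, compile_model_safe

apply_safe_config()  # set dynamo/inductor flags before any compilation happens

layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)    # placeholder module
encoder = compile_model_safe(encoder, mode="default")   # returns the original module if CUDA is unavailable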
src/training/trainer.py
CHANGED
@@ -10,6 +10,7 @@ Date: December 2025

 from __future__ import annotations

+import sys
 import time
 from collections import defaultdict
 from dataclasses import dataclass
@@ -23,6 +24,7 @@ from tqdm import tqdm

 from ..data.tokenization import Tokenizer
 from .metrics import accuracy, multilabel_f1, rouge_like
+from .nan_debugger import NaNDetector

 # --------------- Configuration ---------------
@@ -69,6 +71,14 @@ class Trainer:
         self.use_bfloat16 = self.use_amp and torch.cuda.is_bf16_supported()
         self.scaler = torch.GradScaler("cuda", enabled=(self.use_amp and not self.use_bfloat16))

+        # NaN detection
+        self.nan_detector = NaNDetector(model, enabled=True)
+        self.nan_skip_count = 0
+        self.max_nan_skips = 50
+
+        # Track current step for debugging
+        self._current_step = 0
+
         self._nan_counter = 0
         mlflow.set_experiment(config.experiment_name)
@@ -98,6 +108,8 @@ class Trainer:
             desc="Training",
             unit="epoch",
             position=0,
+            file=sys.stderr,
+            dynamic_ncols=True,
         )

         for epoch in epoch_pbar:
@@ -178,11 +190,14 @@
             unit="batch",
             leave=False,
             position=1,
+            file=sys.stderr,
+            dynamic_ncols=True,
         )

         context = torch.enable_grad() if train else torch.no_grad()
         with context:
             for step in pbar:
+                self._current_step = step
                 step_loss = 0.0

                 for task, loader in loaders.items():
@@ -241,7 +256,19 @@
         return averaged

     def _optimizer_step(self) -> None:
-        """Optimizer step with gradient clipping."""
+        """Optimizer step with gradient clipping and NaN detection."""
+        # Check gradients for NaN/Inf BEFORE clipping
+        nan_grad = self.nan_detector.check_gradients(self._current_step)
+        if nan_grad is not None:
+            param_name, _ = nan_grad
+            print(f"⚠ Skipping optimizer step due to NaN gradient in {param_name}")
+            self.optimizer.zero_grad()
+            self.nan_skip_count += 1
+            if self.nan_skip_count > self.max_nan_skips:
+                raise RuntimeError("Too many NaN gradients, stopping")
+            return
+
+        # Clip and step
         if self.use_bfloat16:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
             self.optimizer.step()
@@ -250,8 +277,16 @@
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_norm)
             self.scaler.step(self.optimizer)
             self.scaler.update()
+
         self.optimizer.zero_grad()

+        # Check parameters for NaN AFTER update
+        nan_param = self.nan_detector.check_parameters(self._current_step)
+        if nan_param is not None:
+            raise RuntimeError(
+                f"NaN in parameter {nan_param} after optimizer step at step {self._current_step}!"
+            )
+
     def _get_batch(
         self, iterators: Dict, loader: DataLoader, task: str
     ) -> Dict[str, torch.Tensor] | None:
@@ -274,14 +309,28 @@
     def _forward_task(
         self, task: str, batch: Dict[str, torch.Tensor]
     ) -> tuple[torch.Tensor, Dict[str, float]]:
-        """Route to task-specific forward pass."""
+        """Route to task-specific forward pass with NaN detection."""
         if task == "summarization":
-
+            loss, task_metrics = self._forward_summarization(batch)
         elif task == "emotion":
-
+            loss, task_metrics = self._forward_emotion(batch)
         elif task == "topic":
-
-
+            loss, task_metrics = self._forward_topic(batch)
+        else:
+            raise ValueError(f"Unknown task: {task}")
+
+        # Check for NaN in loss
+        if torch.isnan(loss):
+            self.nan_skip_count += 1
+            print(
+                f"⚠ NaN loss detected in {task} at step {self._current_step} (skip {self.nan_skip_count}/{self.max_nan_skips})"
+            )
+            if self.nan_skip_count > self.max_nan_skips:
+                raise RuntimeError(f"Too many NaN batches ({self.nan_skip_count}), stopping")
+            # Return zero loss to skip this batch
+            return torch.tensor(0.0, device=loss.device, requires_grad=True), task_metrics
+
+        return loss, task_metrics

     def _forward_summarization(
         self, batch: Dict[str, torch.Tensor]
tests/test_models/test_decoder.py
CHANGED
@@ -64,9 +64,9 @@ def test_decoder_layer_causal_mask_blocks_future():
     B, H, Tq, Tk = self_attn.shape
     for i in range(Tq):
         for j in range(i + 1, Tk):
-            assert torch.allclose(
-
-            )
+            assert torch.allclose(self_attn[:, :, i, j], torch.zeros(B, H)), (
+                f"Found nonzero attention to future position {j} from query {i}"
+            )


 def test_decoder_stack_and_greedy_decode_shapes():
|