OliverPerrin committed · Commit 1ec7405 · Parent(s): 701dfc6
Update LexiMind: improved training, model architecture, and evaluation
- README.md +72 -49
- configs/config.yaml +1 -0
- configs/model/base.yaml +7 -3
- configs/training/dev.yaml +23 -13
- configs/training/full.yaml +22 -11
- configs/training/medium.yaml +23 -13
- docs/api.md +0 -79
- docs/architecture.md +10 -0
- docs/training.md +0 -80
- outputs/training_history.json +51 -51
- scripts/demo_gradio.py +222 -98
- scripts/download_data.py +53 -0
- scripts/process_books.py +231 -0
- scripts/train.py +56 -8
- scripts/visualize_training.py +341 -0
- src/data/dataloader.py +7 -4
- src/inference/pipeline.py +3 -0
- src/inference/postprocessing.py +0 -14
- src/models/decoder.py +57 -16
- src/models/encoder.py +28 -6
- src/models/factory.py +57 -29
- src/models/t5_layer_norm.py +41 -0
- src/training/early_stopping.py +60 -0
- src/training/gradient_monitor.py +102 -0
- src/training/safe_compile.py +38 -72
- src/training/trainer.py +158 -3
- tests/test_inference/test_pipeline.py +26 -4
- tests/test_models/test_visualizations.py +1 -1
README.md
CHANGED
@@ -8,7 +8,7 @@ app_file: scripts/demo_gradio.py
 pinned: false
 ---
 
-
+## LexiMind: A Multi-Task NLP Model
 
 LexiMind is a state-of-the-art Natural Language Processing model designed for complex document understanding. It features a **custom-built Transformer architecture** initialized with weights from Google's **FLAN-T5**, combining the flexibility of from-scratch implementation with the power of modern pre-trained models.
 
@@ -18,32 +18,37 @@ This project is built with industry-standard MLOps practices, including configur…
 
 ## Core Features
 
-* …
-* …
-* …
+* **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention.
+* **Emotion Classification:** Identifies emotions (Joy, Sadness, Anger, Fear, Love, Surprise) conveyed in a document.
+* **Topic Clustering:** Classifies documents into thematic categories (World, Sports, Business, Sci/Tech).
 
 ## Model Architecture
 
 LexiMind implements a **from-scratch Transformer** with modern architectural choices:
 
 ### Custom Transformer Features
-…
+
+* **Pre-Layer Normalization (Pre-LN):** RMSNorm applied before each sublayer for stable training
+* **FlashAttention:** Via PyTorch 2.0's `scaled_dot_product_attention` for efficient computation
+* **Learned Positional Embeddings:** Trainable position representations
+* **Multi-Head Attention:** 12 heads with 768-dimensional representations
+* **RMSNorm:** Modern normalization without bias (more efficient than LayerNorm)
 
 ### Pre-trained Weight Initialization
+
 The model loads weights from **Google's FLAN-T5-base**, which provides:
-…
+
+* Strong language understanding from instruction-tuning
+* Excellent performance on summarization and classification tasks
+* Encoder-decoder architecture matching our custom implementation
 
 ### Multi-Task Learning
+
 A shared encoder-decoder backbone with task-specific heads:
-…
+
+* **Summarization Head:** Language modeling head with weight tying
+* **Emotion Head:** Mean-pooled classification with dropout
+* **Topic Head:** Mean-pooled classification with dropout
 
 ## Technical Specifications
 
@@ -64,29 +69,32 @@ A shared encoder-decoder backbone with task-specific heads:
 
 ### Prerequisites
 
-* …
-* …
-* …
-* …
+* Python 3.10+
+* Poetry for dependency management
+* Docker (for containerized deployment)
+* An NVIDIA GPU with CUDA support (for training and accelerated inference)
 
 ### Installation
 
-1. …
-…
+1. **Clone the repository:**
+
+   ```bash
+   git clone https://github.com/OliverPerrin/LexiMind.git
+   cd LexiMind
+   ```
+
+2. **Install dependencies:**
+
+   ```bash
+   poetry install
+   ```
-3. …
-…
+3. **Download and preprocess data:**
+
+   ```bash
+   poetry run python scripts/download_data.py
+   poetry run python scripts/preprocess_data.py
+   ```
 
 ## Usage
 
@@ -95,12 +103,13 @@ A shared encoder-decoder backbone with task-specific heads:
 All training and model parameters are managed via Hydra. Configurations are located in the `configs/` directory.
 
 Available configurations:
-…
+
+* `model=base` - FLAN-T5-base (default, 12 layers)
+* `model=small` - Smaller model for testing (no pretrained weights)
+* `model=large` - FLAN-T5-large (24 layers, requires more VRAM)
+* `training=dev` - Quick development run
+* `training=medium` - Balanced training (~2-3 hours on RTX 4070)
+* `training=full` - Full training run
 
 ### Training
 
@@ -116,6 +125,9 @@ poetry run python scripts/train.py training=medium
 
 # Override parameters
 poetry run python scripts/train.py training.optimizer.lr=5e-5
+
+# Resume from a checkpoint
+poetry run python scripts/train.py training=full resume_from=checkpoints/epoch_5.pt
 ```
 
 Experiments are automatically tracked with MLflow. View results with `mlflow ui`.
@@ -148,7 +160,7 @@ docker run -p 7860:7860 leximind
 
 ## Project Structure
 
-```
+```text
 ├── configs/              # Hydra configuration files
 │   ├── model/            # Model architectures (base, small, large)
 │   ├── training/         # Training configs (dev, medium, full)
@@ -169,22 +181,33 @@ docker run -p 7860:7860 leximind
 
 ## Code Quality
 
-* …
-* …
-* …
+* **Ruff:** Fast linting and formatting
+* **MyPy:** Static type checking
+* **Pytest:** Full test suite covering data, models, and training
+* **Pre-commit hooks:** Automated quality checks
 
 ```bash
+# Install hooks
 poetry run pre-commit install
+
+# Lint
+poetry run ruff check .
+
+# Type check
+poetry run mypy .
+
+# Tests
+poetry run pytest
 ```
 
 ## Performance Optimizations
 
-…
+* **torch.compile:** JIT compilation with Inductor backend
+* **Mixed Precision:** bfloat16 training on Ampere/Ada GPUs
+* **TF32:** Enabled for RTX 30xx/40xx series
+* **KV-Cache:** Efficient autoregressive decoding
+* **FlashAttention:** Memory-efficient attention via SDPA
 
 ## License
 
-MIT License - see [LICENSE](LICENSE) for details.
+MIT License - see [LICENSE](LICENSE) for details.
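The heads listed under Multi-Task Learning are structurally simple. A minimal sketch of a mean-pooled classification head with dropout — a generic formulation for illustration, not necessarily the repo's exact module:

```python
import torch
from torch import nn


class MeanPooledHead(nn.Module):
    """Classify a sequence by mean-pooling encoder states over non-pad tokens."""

    def __init__(self, d_model: int, num_classes: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(d_model, num_classes)

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq, d_model); mask: (batch, seq), 1 for real tokens
        mask = mask.unsqueeze(-1).float()
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-6)
        return self.proj(self.dropout(pooled))


head = MeanPooledHead(d_model=768, num_classes=6)  # six emotions per the README
logits = head(torch.randn(2, 16, 768), torch.ones(2, 16))
print(logits.shape)  # torch.Size([2, 6])
```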
configs/config.yaml
CHANGED
@@ -14,5 +14,6 @@ hydra:
 checkpoint_out: "checkpoints/best.pt"
 labels_out: "artifacts/labels.json"
 history_out: "outputs/training_history.json"
+resume_from: null
 device: "cuda"
 seed: 17
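The new `resume_from` key defaults to null and is overridden from the CLI (see the README's training examples). A sketch of how a training entrypoint might consume it — illustrative only, the actual handling lives in the +56-line change to `scripts/train.py`:

```python
import torch
from omegaconf import DictConfig, OmegaConf


def maybe_resume(model: torch.nn.Module, cfg: DictConfig) -> None:
    """Load checkpoint weights if cfg.resume_from is set (null by default)."""
    if cfg.get("resume_from"):
        state = torch.load(cfg.resume_from, map_location=cfg.get("device", "cpu"))
        model.load_state_dict(state)


cfg = OmegaConf.create({"resume_from": None, "device": "cpu"})
maybe_resume(torch.nn.Linear(8, 8), cfg)  # no-op: resume_from is null
```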
configs/model/base.yaml
CHANGED
@@ -1,8 +1,10 @@
 # FLAN-T5-base architecture
-# …
+# 12 encoder layers, 12 decoder layers, 768 hidden dim
 d_model: 768
-…
-…
+# Align vocab with FLAN-T5 padded size to avoid weight truncation
+vocab_size: 32128
+num_encoder_layers: 12  # T5-base has 12 layers
+num_decoder_layers: 12  # T5-base has 12 layers
 num_attention_heads: 12
 ffn_dim: 2048  # T5 uses d_ff = 2048 for base model
 dropout: 0.1
@@ -10,3 +12,5 @@ activation: gated-gelu  # T5/FLAN-T5 uses gated-gelu (GELU activation with gating)
 use_pretrained: true
 pretrained_model_name: google/flan-t5-base
 use_relative_position_bias: true  # T5 uses relative position bias instead of absolute embeddings
+gradient_checkpointing: false
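The padded vocab size matters when copying FLAN-T5's embedding matrix: the tokenizer reports 32,100 tokens, but the checkpoint's embedding has 32,128 rows. A quick check (assumes `transformers` is installed and can reach the Hub):

```python
from transformers import AutoConfig, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
cfg = AutoConfig.from_pretrained("google/flan-t5-base")
print(len(tok))        # 32100 tokens in the tokenizer
print(cfg.vocab_size)  # 32128 rows in the embedding matrix (padded)
assert cfg.vocab_size == 32128  # must match vocab_size in base.yaml or weights truncate
```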
configs/training/dev.yaml
CHANGED
@@ -1,35 +1,45 @@
 # Development/Testing Configuration for FLAN-T5-base
 # Fast iteration for debugging and testing changes
-# …
+# VRAM Usage: ~8-9GB peak (12GB available)
+# Training time: ~10-15 minutes on RTX 4070 12GB
 # Use: python scripts/train.py training=dev
 
 dataloader:
-  batch_size: …
+  batch_size: 5  # Conservative for 12GB VRAM
   shuffle: true
-  num_workers: …
+  num_workers: 4
   pin_memory: true
   persistent_workers: true
-  prefetch_factor: …
+  prefetch_factor: 2
 
 optimizer:
   name: adamw
-  lr: …
+  lr: 5.0e-5  # Higher LR for faster convergence in dev
   weight_decay: 0.01
-  eps: 1.0e-…
+  eps: 1.0e-8
+  betas: [0.9, 0.999]
 
 scheduler:
   name: cosine
-  warmup_steps: …
+  warmup_steps: 100  # ~2% of training steps for smoother start
 
 trainer:
-  max_epochs: …
+  max_epochs: 3
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: …
+  gradient_accumulation_steps: 12  # Effective batch: 60 (5*12)
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
     summarization: 1.0
-    emotion: …
-    topic: …
-  max_train_samples: …
-  max_val_samples: …
+    emotion: 0.5
+    topic: 0.5
+  max_train_samples: 3000  # 3k samples for better validation
+  max_val_samples: 300
+  early_stopping_patience: 5  # Stop if no improvement
+  log_grad_norm_frequency: 100
+
+  # Disable compile for faster startup in dev
+  compile_encoder: false
+  compile_decoder: false
+
+tokenizer_max_length: 512
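The "Effective batch" comments in these configs are plain arithmetic worth re-checking whenever `batch_size` or `gradient_accumulation_steps` is edited; a tiny sanity sketch covering all three configs:

```python
# Effective batch = per-step batch size × gradient accumulation steps.
for name, batch_size, grad_accum in [("dev", 5, 12), ("medium", 6, 12), ("full", 6, 16)]:
    print(f"{name}: effective batch = {batch_size * grad_accum}")
# dev: 60, medium: 72, full: 96 — matching each config's inline comment
```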
configs/training/full.yaml
CHANGED
@@ -1,33 +1,44 @@
 # Full Training Configuration for FLAN-T5-base
-# Complete training run on all data
-# …
+# Complete training run on all available data
+# VRAM Usage: ~10-11GB peak (12GB available)
+# Training time: ~3-4 hours on RTX 4070 12GB with torch.compile
 # Use: python scripts/train.py training=full
 
 dataloader:
-  batch_size: …
+  batch_size: 6  # Conservative for 12GB VRAM with torch.compile overhead
   shuffle: true
-  num_workers: …
+  num_workers: 4
   pin_memory: true
   persistent_workers: true
-  prefetch_factor: …
+  prefetch_factor: 2
 
 optimizer:
   name: adamw
-  lr: …
+  lr: 3.0e-5  # Higher LR with larger effective batch
   weight_decay: 0.01
   eps: 1.0e-6
+  betas: [0.9, 0.999]
 
 scheduler:
   name: cosine
-  warmup_steps: 1000
+  warmup_steps: 1000  # ~1% warmup for stability
 
 trainer:
-  max_epochs: …
+  max_epochs: 8  # More epochs for full dataset
   gradient_clip_norm: 1.0
-  gradient_accumulation_steps: …
+  gradient_accumulation_steps: 16  # Effective batch: 96 (6*16)
   validation_max_length: 128
   label_smoothing: 0.1
   task_weights:
-    summarization: 1.…
+    summarization: 1.5  # Prioritize summarization quality
     emotion: 1.0
-    topic: …
+    topic: 0.8
+  # No max_samples - use full dataset
+  early_stopping_patience: 3  # Stop if plateaus
+  log_grad_norm_frequency: 100
+
+  # Enable torch.compile for maximum speed
+  compile_encoder: true
+  compile_decoder: true
+
+tokenizer_max_length: 512
|
configs/training/medium.yaml
CHANGED
|
@@ -1,35 +1,45 @@
|
|
| 1 |
# Medium Configuration for FLAN-T5-base
|
| 2 |
# Balanced approach - good results in reasonable time
|
| 3 |
-
#
|
|
|
|
| 4 |
# Use: python scripts/train.py training=medium
|
| 5 |
|
| 6 |
dataloader:
|
| 7 |
-
batch_size:
|
| 8 |
shuffle: true
|
| 9 |
-
num_workers:
|
| 10 |
pin_memory: true
|
| 11 |
persistent_workers: true
|
| 12 |
-
prefetch_factor:
|
| 13 |
|
| 14 |
optimizer:
|
| 15 |
name: adamw
|
| 16 |
-
lr: 3.0e-5
|
| 17 |
weight_decay: 0.01
|
| 18 |
eps: 1.0e-6
|
|
|
|
| 19 |
|
| 20 |
scheduler:
|
| 21 |
name: cosine
|
| 22 |
-
warmup_steps:
|
| 23 |
|
| 24 |
trainer:
|
| 25 |
-
max_epochs:
|
| 26 |
gradient_clip_norm: 1.0
|
| 27 |
-
gradient_accumulation_steps:
|
| 28 |
validation_max_length: 128
|
| 29 |
label_smoothing: 0.1
|
| 30 |
task_weights:
|
| 31 |
-
summarization: 1.
|
| 32 |
-
emotion:
|
| 33 |
-
topic:
|
| 34 |
-
max_train_samples:
|
| 35 |
-
max_val_samples:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Medium Configuration for FLAN-T5-base
|
| 2 |
# Balanced approach - good results in reasonable time
|
| 3 |
+
# VRAM Usage: ~9-10GB peak (12GB available)
|
| 4 |
+
# Training time: ~45-60 minutes on RTX 4070 12GB with torch.compile
|
| 5 |
# Use: python scripts/train.py training=medium
|
| 6 |
|
| 7 |
dataloader:
|
| 8 |
+
batch_size: 6 # Conservative for 12GB VRAM with torch.compile
|
| 9 |
shuffle: true
|
| 10 |
+
num_workers: 4
|
| 11 |
pin_memory: true
|
| 12 |
persistent_workers: true
|
| 13 |
+
prefetch_factor: 2
|
| 14 |
|
| 15 |
optimizer:
|
| 16 |
name: adamw
|
| 17 |
+
lr: 3.0e-5 # Balanced LR for quality
|
| 18 |
weight_decay: 0.01
|
| 19 |
eps: 1.0e-6
|
| 20 |
+
betas: [0.9, 0.999]
|
| 21 |
|
| 22 |
scheduler:
|
| 23 |
name: cosine
|
| 24 |
+
warmup_steps: 500 # ~2% warmup for 25k steps
|
| 25 |
|
| 26 |
trainer:
|
| 27 |
+
max_epochs: 5 # More epochs for better convergence
|
| 28 |
gradient_clip_norm: 1.0
|
| 29 |
+
gradient_accumulation_steps: 12 # Effective batch: 72 (6*12)
|
| 30 |
validation_max_length: 128
|
| 31 |
label_smoothing: 0.1
|
| 32 |
task_weights:
|
| 33 |
+
summarization: 1.2 # Slightly prioritize summarization
|
| 34 |
+
emotion: 0.8
|
| 35 |
+
topic: 0.8
|
| 36 |
+
max_train_samples: 25000 # 25k samples - good balance
|
| 37 |
+
max_val_samples: 2500
|
| 38 |
+
early_stopping_patience: 3
|
| 39 |
+
log_grad_norm_frequency: 100
|
| 40 |
+
|
| 41 |
+
# Enable torch.compile for 1.5-2x speedup
|
| 42 |
+
compile_encoder: true
|
| 43 |
+
compile_decoder: true
|
| 44 |
+
|
| 45 |
+
tokenizer_max_length: 512
|
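`log_grad_norm_frequency` pairs with the new `src/training/gradient_monitor.py` (+102 lines). A sketch of the core computation — the total gradient norm logged every N steps; the module's actual interface is an assumption:

```python
import torch


def global_grad_norm(model: torch.nn.Module) -> float:
    """L2 norm over all parameter gradients, the quantity clipped and logged."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().float().norm(2).item() ** 2
    return total**0.5


model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
step, log_every = 100, 100
if step % log_every == 0:  # log_grad_norm_frequency: 100
    print(f"grad_norm={global_grad_norm(model):.3f}")
```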
docs/api.md
DELETED
@@ -1,79 +0,0 @@
-# API & CLI Documentation
-
-## FastAPI Service
-The FastAPI application is defined in `src/api/app.py` and wires routes from
-`src/api/routes.py`. All dependencies resolve through `src/api/dependencies.py`, which lazily constructs the shared inference pipeline.
-
-### POST `/summarize`
-- **Request Body** (`SummaryRequest`):
-```json
-{
-  "text": "Your input document"
-}
-```
-- **Response** (`SummaryResponse`):
-```json
-{
-  "summary": "Generated abstractive summary",
-  "emotion_labels": ["joy", "surprise"],
-  "emotion_scores": [0.91, 0.63],
-  "topic": "news",
-  "topic_confidence": 0.82
-}
-```
-- **Behaviour:**
-  1. Text is preprocessed through `TextPreprocessor` (with optional sklearn transformer if configured).
-  2. The multitask model generates a summary via greedy decoding.
-  3. Emotion and topic heads produce logits which are converted to probabilities and mapped to
-     human-readable labels using `artifacts/labels.json`.
-  4. Results are returned as structured JSON suitable for a future Gradio interface.
-
-### Error Handling
-- If the checkpoint or label metadata is missing, the dependency raises an HTTP 503 error with
-  an explanatory message.
-- Validation errors (missing `text`) are handled automatically by FastAPI/Pydantic.
-
-## Command-Line Interface
-`scripts/inference.py` provides a CLI that mirrors the API behaviour.
-
-### Usage
-```bash
-python scripts/inference.py "Document to analyse" \
-  --checkpoint checkpoints/best.pt \
-  --labels artifacts/labels.json \
-  --tokenizer artifacts/hf_tokenizer \
-  --model-config configs/model/base.yaml \
-  --device cpu
-```
-
-Options:
-- `text` – zero or more positional arguments. If omitted, use `--file` to point to a newline
-  delimited text file.
-- `--file` – optional path containing one text per line.
-- `--checkpoint` – path to the trained model weights.
-- `--labels` – JSON containing emotion/topic vocabularies (defaults to `artifacts/labels.json`).
-- `--tokenizer` – optional tokenizer directory; defaults to the exported artifact if present.
-- `--model-config` – YAML describing the architecture.
-- `--device` – `cpu` or `cuda`. Passing `cuda` attempts to run inference on GPU.
-- `--summary-max-length` – overrides the default maximum generation length.
-
-### Output
-The CLI prints a JSON array where each entry contains the original text, summary, emotion labels
-with scores, and topic prediction. This format is identical to the REST response, facilitating
-integration tests and future Gradio UI rendering.
-
-## Future Gradio UI
-- The planned UI will call the same inference pipeline and display results interactively.
-- Given the response schema, the UI can show:
-  - Generated summary text.
-  - Emotion chips with probability bars.
-  - Topic confidence gauges.
-  - Placeholder panel for attention heatmaps and explanations.
-- Once implemented, documentation updates will add a `docs/ui.md` section and screenshots under
-  `docs/images/`.
-
-## Testing
-- `tests/test_api/test_routes.py` stubs the pipeline to ensure response fields and dependency
-  overrides behave as expected.
-- `tests/test_inference/test_pipeline.py` validates pipeline methods end-to-end with dummy models,
-  guaranteeing API and CLI consumers receive consistent payload shapes.
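For reference, calling the `/summarize` endpoint the deleted doc describes was a short script; a sketch with the `requests` library, assuming a local dev server on port 8000:

```python
import requests

resp = requests.post(
    "http://localhost:8000/summarize",  # assumed local dev URL
    json={"text": "Your input document"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["summary"])  # fields per the SummaryResponse schema above
```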
docs/architecture.md
CHANGED
@@ -1,6 +1,7 @@
 # LexiMind Architecture
 
 ## Overview
+
 LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers:
 
 1. **Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn
@@ -15,6 +16,7 @@ LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack.
 The custom Transformer is designed with **modern architectural choices** while maintaining compatibility with pre-trained weights from Google's **FLAN-T5**.
 
 ### Architecture Highlights
+
 - **Pre-Layer Normalization (Pre-LN):** RMSNorm applied *before* each sublayer for stable training
 - **RMSNorm:** More efficient than LayerNorm (no mean computation, no bias parameters)
 - **FlashAttention:** Via PyTorch 2.0's `F.scaled_dot_product_attention` for O(N) memory
@@ -22,7 +24,9 @@ The custom Transformer is designed with **modern architectural choices** while maintaining compatibility with pre-trained weights.
 - **Multi-Head Attention:** 12 heads with optional LoRA adapters and RoPE support
 
 ### Weight Loading from FLAN-T5
+
 The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture:
+
 - **Token embeddings:** Shared between encoder and decoder
 - **Attention projections:** Q, K, V, O weights (bias initialized to zero since T5 has no attention bias)
 - **FFN weights:** `wi_1` → `linear1`, `wo` → `linear2` (T5 uses gated FFN; we use the up/down projections)
@@ -32,6 +36,7 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture.
 **Note:** T5 uses *relative position bias* computed in attention, not absolute embeddings. Our learned positional embeddings are randomly initialized and train quickly during fine-tuning.
 
 ### File Structure
+
 - `src/models/encoder.py` – TransformerEncoder with Pre-LN RMSNorm blocks
 - `src/models/decoder.py` – TransformerDecoder with KV-cache for efficient generation
 - `src/models/attention.py` – Multi-Head Attention with FlashAttention, LoRA, and RoPE support
@@ -40,16 +45,19 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture.
 - `src/models/factory.py` – Builds models and loads FLAN-T5 weights
 
 ## Data, Tokenization, and Preprocessing
+
 - `src/data/tokenization.py` wraps `AutoTokenizer` (configured for FLAN-T5) to provide tensor-aware batching and helper utilities for decoder input shifting.
 - `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with optional scikit-learn transformers.
 - `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and collators.
 
 ### T5 Tokenizer Differences
+
 - **Vocab size:** 32,128 tokens (SentencePiece)
 - **Special tokens:** pad=0, eos=1 (no explicit BOS; decoder starts with pad token)
 - **Subword tokenization:** Unigram-based (vs BART's BPE)
 
 ## Training Pipeline
+
 - `src/training/trainer.py` coordinates multi-task optimization with:
   - Mixed precision training (bfloat16 on Ampere/Ada GPUs)
   - Gradient accumulation for larger effective batch sizes
@@ -58,12 +66,14 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture.
 - Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and ROUGE-like overlap
 
 ## Inference & Serving
+
 - `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic.
 - `src/inference/factory.py` rebuilds the full pipeline using the exported tokenizer artifact
 - The CLI (`scripts/inference.py`) drives the pipeline from the command line
 - Gradio demo (`scripts/demo_gradio.py`) provides a web interface
 
 ## Key Decisions
+
 - **Custom Transformer + Pre-trained Weights:** Building from scratch demonstrates deep understanding while leveraging FLAN-T5's language knowledge
 - **Pre-LN RMSNorm:** Modern architecture used by LLaMA, T5 v1.1, and other 2023-2025 models
 - **Tokenizer Artifact Preference:** Inference favors `artifacts/hf_tokenizer` for reproducibility
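The commit also adds `src/models/t5_layer_norm.py` (+41 lines). T5-style RMSNorm is small enough to sketch in full; this is the standard formulation (scale only, no mean subtraction, no bias), not necessarily the repo's exact code:

```python
import torch
from torch import nn


class T5LayerNorm(nn.Module):
    """RMSNorm as used by T5/FLAN-T5: rescale by root-mean-square, learned gain only."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute variance in float32 for numerical stability, as T5 does.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(self.weight.dtype)


norm = T5LayerNorm(768)
print(norm(torch.randn(2, 4, 768)).shape)  # torch.Size([2, 4, 768])
```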
docs/training.md
DELETED
@@ -1,80 +0,0 @@
-# Training Procedure
-
-## Data Sources
-- **Summarization** – expects JSONL files with `source` and `summary` fields under
-  `data/processed/summarization`.
-- **Emotion Classification** – multi-label samples loaded from JSONL files with
-  `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding.
-- **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`.
-
-Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`google/flan-t5-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`).
-
-## Dataloaders & Collators
-- `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation. Note: FLAN-T5 uses `pad_token_id=0` and `decoder_start_token_id=0`.
-- `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`.
-- `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`.
-
-These collators keep all tokenization centralized, reducing duplication and making it easy to swap in additional sklearn transformations through `TextPreprocessor` should we wish to extend cleaning or normalization.
-
-## Model Assembly
-- `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments.
-- Pretrained weights are loaded from FLAN-T5 using `_load_t5_weights()`, which transfers:
-  - Shared token embeddings (with proper scaling)
-  - Attention projections (q, k, v, o) for all encoder/decoder layers
-  - FFN weights (wi_0, wi_1 for gated activation, wo for output)
-  - Layer normalization parameters (mapped from T5's RMSNorm)
-- The model wraps:
-  - Transformer encoder/decoder stacks with **Pre-LN RMSNorm** architecture.
-  - LM head tied to decoder embeddings for summarization.
-  - Mean-pooled classification heads for emotion and topic tasks.
-
-## Optimisation Loop
-- `src/training/trainer.Trainer` orchestrates multi-task training.
-- Cross-entropy is used for summarization (seq2seq logits vs. shifted labels).
-- `BCEWithLogitsLoss` handles multi-label emotions.
-- `CrossEntropyLoss` handles topic classification.
-- Gradient clipping ensures stability, and per-task weights can be configured via
-  `TrainerConfig.task_weights` to balance gradients if needed.
-- Metrics tracked per task:
-  - **Summarization** – ROUGE-like overlap metric (`training.metrics.rouge_like`).
-  - **Emotion** – micro F1 score for multi-label predictions.
-  - **Topic** – categorical accuracy.
-
-## Checkpoints & Artifacts
-- `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`.
-- `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after
-  training. This file is required for inference so class indices map back to human-readable labels.
-- The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies using `scripts/export_tokenizer.py`.
-
-## Running Training
-1. Ensure processed datasets are available (see `data/processed/` structure).
-2. Export the FLAN-T5 tokenizer: `python scripts/export_tokenizer.py`
-3. Choose a configuration (e.g., `configs/training/dev.yaml`) for hyperparameters and data splits.
-4. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders.
-5. Use `build_multitask_model` to construct the model with FLAN-T5 weights, create an optimizer, and run
-   `Trainer.fit(train_loaders, val_loaders)`.
-6. Save checkpoints and update `artifacts/labels.json` with the dataset label order.
-
-```bash
-# Quick start
-python scripts/export_tokenizer.py        # Export FLAN-T5 tokenizer
-python scripts/train.py training=dev      # Run dev training (2 epochs)
-python scripts/train.py training=medium   # Run medium training (5 epochs)
-python scripts/train.py training=full     # Run full training (10 epochs)
-```
-
-## Why FLAN-T5?
-LexiMind's custom Transformer uses **Pre-LN (normalization before sublayers)** with **RMSNorm**. This modern architecture choice provides:
-- Better gradient flow during training
-- Improved training stability
-- Faster convergence
-
-FLAN-T5 uses the same Pre-LN RMSNorm architecture, making weight transfer straightforward. The previously used BART (Post-LN LayerNorm) had a fundamental architectural mismatch that caused training issues.
-
-> **Note:** T5's relative position bias is NOT transferred. The model uses learned positional encodings which train from scratch. This is fine since positional information is task-specific.
-
-## Future Enhancements
-- Integrate curriculum scheduling or task-balanced sampling once empirical results dictate.
-- Capture attention maps during training to support visualization in the planned Gradio UI.
-- Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it.
-- Experiment with FLAN-T5-large for improved performance on longer sequences.
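The `-100` padding mask described in the deleted collator notes is the standard PyTorch convention (`nn.CrossEntropyLoss` ignores index `-100` by default); a sketch of how a collator applies it, with illustrative token ids:

```python
import torch

pad_token_id = 0  # FLAN-T5's pad token
labels = torch.tensor([[37, 19, 8, 0, 0]])  # padded summary token ids
labels = labels.masked_fill(labels == pad_token_id, -100)
print(labels)  # tensor([[  37,   19,    8, -100, -100]])
# Positions set to -100 contribute nothing to the cross-entropy loss.
```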
outputs/training_history.json
CHANGED
@@ -1,59 +1,59 @@
 {
-… (previous run's epoch entries; values truncated in this view)
+  "train_epoch_6": {
+    "summarization_loss": 3.2071112584752606,
+    "summarization_rouge_like": 0.41666206128984185,
+    "emotion_loss": 0.13381094067425187,
+    "emotion_f1": 0.1527181073975268,
+    "topic_loss": 0.6847172836312407,
+    "topic_accuracy": 0.7834830254758819,
+    "total_loss": 5.492251664781721,
+    "epoch": 6.0
   },
+  "val_epoch_6": {
+    "summarization_loss": 2.988837990901862,
+    "summarization_rouge_like": 0.4475286348323649,
+    "emotion_loss": 0.1262940275061054,
+    "emotion_f1": 0.19359053170564663,
+    "topic_loss": 0.7910004459155627,
+    "topic_accuracy": 0.754854122191724,
+    "epoch": 6.0
   },
+  "train_epoch_7": {
+    "summarization_loss": 3.184010818695097,
+    "summarization_rouge_like": 0.41903763419721,
+    "emotion_loss": 0.12498181367997213,
+    "emotion_f1": 0.2043521878681856,
+    "topic_loss": 0.6483695249464139,
+    "topic_accuracy": 0.796684177822936,
+    "total_loss": 5.419693668500609,
+    "epoch": 7.0
   },
+  "val_epoch_7": {
+    "summarization_loss": 2.985372142407835,
+    "summarization_rouge_like": 0.44758863369550994,
+    "emotion_loss": 0.1185748163268729,
+    "emotion_f1": 0.2514045691051182,
+    "topic_loss": 0.7817700606483663,
+    "topic_accuracy": 0.7554132357426027,
+    "epoch": 7.0
   },
+  "train_epoch_8": {
+    "summarization_loss": 3.171688149997974,
+    "summarization_rouge_like": 0.4206951155149097,
+    "emotion_loss": 0.12107599671589805,
+    "emotion_f1": 0.2286830931525678,
+    "topic_loss": 0.6216138880150013,
+    "topic_accuracy": 0.8049539626051729,
+    "total_loss": 5.375899340986727,
+    "epoch": 8.0
   },
+  "val_epoch_8": {
+    "summarization_loss": 2.984391659270994,
+    "summarization_rouge_like": 0.44770155741256373,
+    "emotion_loss": 0.11704520378562873,
+    "emotion_f1": 0.26809326239605075,
+    "topic_loss": 0.7841400383105634,
+    "topic_accuracy": 0.7546508081732227,
+    "epoch": 8.0
   }
 }
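The new `scripts/visualize_training.py` (+341 lines) presumably charts this history; a minimal sketch of reading the file and plotting one metric — keys come from the JSON above, the plotting choices are assumptions:

```python
import json
from pathlib import Path

import matplotlib.pyplot as plt

history = json.loads(Path("outputs/training_history.json").read_text())
epochs = sorted({int(v["epoch"]) for v in history.values()})
train = [history[f"train_epoch_{e}"]["summarization_loss"] for e in epochs]
val = [history[f"val_epoch_{e}"]["summarization_loss"] for e in epochs]

plt.plot(epochs, train, label="train")
plt.plot(epochs, val, label="val")
plt.xlabel("epoch")
plt.ylabel("summarization loss")
plt.legend()
plt.savefig("outputs/summarization_loss.png")
```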
scripts/demo_gradio.py
CHANGED
|
@@ -4,10 +4,10 @@ Gradio demo for LexiMind multi-task NLP model.
|
|
| 4 |
Showcases the model's capabilities across three tasks:
|
| 5 |
- Summarization: Generates concise summaries of input text
|
| 6 |
- Emotion Detection: Multi-label emotion classification
|
| 7 |
-
- Topic Classification: Categorizes text into
|
| 8 |
|
| 9 |
Author: Oliver Perrin
|
| 10 |
-
Date: 2025-12-
|
| 11 |
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
|
@@ -38,24 +38,12 @@ logger = get_logger(__name__)
|
|
| 38 |
|
| 39 |
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
|
| 40 |
EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
|
|
|
|
| 41 |
|
| 42 |
SAMPLE_TEXTS = [
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
"patterns with unprecedented accuracy. From healthcare to finance, AI is "
|
| 47 |
-
"revolutionizing industries worldwide."
|
| 48 |
-
),
|
| 49 |
-
(
|
| 50 |
-
"The team's incredible comeback in the final quarter left fans in tears of joy. "
|
| 51 |
-
"After trailing by 20 points, they scored three consecutive touchdowns to secure "
|
| 52 |
-
"their first championship victory in over a decade."
|
| 53 |
-
),
|
| 54 |
-
(
|
| 55 |
-
"Global markets tumbled today as investors reacted to rising inflation concerns. "
|
| 56 |
-
"The Federal Reserve hinted at potential interest rate hikes, sending shockwaves "
|
| 57 |
-
"through technology and banking sectors."
|
| 58 |
-
),
|
| 59 |
]
|
| 60 |
|
| 61 |
# --------------- Pipeline Management ---------------
|
|
@@ -94,27 +82,62 @@ def get_pipeline():
|
|
| 94 |
def analyze(text: str) -> tuple[str, str, str]:
|
| 95 |
"""Run all three tasks and return formatted results."""
|
| 96 |
if not text or not text.strip():
|
| 97 |
-
return "
|
| 98 |
|
| 99 |
try:
|
| 100 |
pipe = get_pipeline()
|
| 101 |
|
| 102 |
# Run tasks
|
| 103 |
-
summary = pipe.summarize([text], max_length=128)[0].strip()
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
topic = pipe.predict_topics([text])[0]
|
| 106 |
|
| 107 |
-
# Format emotions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if emotions.labels:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 113 |
else:
|
| 114 |
-
emotion_str = "No strong emotions detected"
|
| 115 |
|
| 116 |
# Format topic
|
| 117 |
-
topic_str = f"**{topic.label}
|
| 118 |
|
| 119 |
return summary, emotion_str, topic_str
|
| 120 |
|
|
@@ -125,75 +148,138 @@ def analyze(text: str) -> tuple[str, str, str]:
|
|
| 125 |
|
| 126 |
def load_metrics() -> str:
|
| 127 |
"""Load evaluation metrics and format as markdown."""
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|------|--------|-------|
|
| 140 |
-
| **Emotion** | F1 Macro | **{r["emotion"]["f1_macro"]:.1%}** |
|
| 141 |
-
| **Topic** | Accuracy | **{r["topic"]["accuracy"]:.1%}** |
|
| 142 |
-
| **Summarization** | ROUGE-Like | {r["summarization"]["rouge_like"]:.1%} |
|
| 143 |
-
| **Summarization** | BLEU | {r["summarization"]["bleu"]:.1%} |
|
| 144 |
|
| 145 |
-
### Topic Classification (per-class)
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
| Sports | {r["topic"]["classification_report"]["Sports"]["precision"]:.1%} | {r["topic"]["classification_report"]["Sports"]["recall"]:.1%} | {r["topic"]["classification_report"]["Sports"]["f1-score"]:.1%} |
|
| 152 |
-
| World | {r["topic"]["classification_report"]["World"]["precision"]:.1%} | {r["topic"]["classification_report"]["World"]["recall"]:.1%} | {r["topic"]["classification_report"]["World"]["f1-score"]:.1%} |
|
| 153 |
-
"""
|
| 154 |
-
except Exception as e:
|
| 155 |
-
return f"Error loading metrics: {e}"
|
| 156 |
|
| 157 |
|
| 158 |
# --------------- Gradio Interface ---------------
|
| 159 |
|
| 160 |
with gr.Blocks(
|
| 161 |
-
title="LexiMind
|
| 162 |
theme=gr.themes.Soft(),
|
| 163 |
-
css=".output-box { min-height: 80px; }",
|
| 164 |
) as demo:
|
| 165 |
gr.Markdown(
|
| 166 |
"""
|
| 167 |
# 🧠 LexiMind
|
| 168 |
### Multi-Task Transformer for Document Analysis
|
| 169 |
|
| 170 |
-
A custom encoder-decoder Transformer trained on summarization
|
| 171 |
-
and topic classification. Built from scratch with PyTorch.
|
|
|
|
|
|
|
| 172 |
"""
|
| 173 |
)
|
| 174 |
|
| 175 |
# --------------- Try It Tab ---------------
|
| 176 |
with gr.Tab("🚀 Try It"):
|
| 177 |
with gr.Row():
|
| 178 |
-
with gr.Column(scale=
|
| 179 |
text_input = gr.Textbox(
|
| 180 |
-
label="Input Text",
|
| 181 |
-
lines=
|
| 182 |
-
placeholder="Enter text to analyze...",
|
| 183 |
value=SAMPLE_TEXTS[0],
|
| 184 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
with gr.Row():
|
| 186 |
-
|
| 187 |
-
gr.
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
| 192 |
|
| 193 |
with gr.Column(scale=2):
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
analyze_btn.click(
|
| 199 |
fn=analyze,
|
|
@@ -203,9 +289,35 @@ with gr.Blocks(
|
|
| 203 |
|
| 204 |
# --------------- Metrics Tab ---------------
|
| 205 |
with gr.Tab("📊 Metrics"):
|
| 206 |
-
gr.
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
# --------------- Architecture Tab ---------------
|
| 211 |
with gr.Tab("🔧 Architecture"):
|
|
@@ -213,28 +325,34 @@ with gr.Blocks(
|
|
| 213 |
"""
|
| 214 |
### Model Architecture
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
###
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
"""
|
| 228 |
)
|
| 229 |
-
with gr.Row():
|
| 230 |
-
gr.Image(
|
| 231 |
-
str(OUTPUTS_DIR / "attention_visualization.png"),
|
| 232 |
-
label="Self-Attention Pattern",
|
| 233 |
-
)
|
| 234 |
-
gr.Image(
|
| 235 |
-
str(OUTPUTS_DIR / "positional_encoding_heatmap.png"),
|
| 236 |
-
label="Positional Encodings",
|
| 237 |
-
)
|
| 238 |
|
| 239 |
# --------------- About Tab ---------------
|
| 240 |
with gr.Tab("ℹ️ About"):
|
|
@@ -242,22 +360,28 @@ with gr.Blocks(
|
|
| 242 |
"""
|
| 243 |
### About LexiMind
|
| 244 |
|
| 245 |
-
LexiMind is a
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
- **
|
| 249 |
-
- **
|
| 250 |
-
- **
|
| 251 |
-
- **Comprehensive evaluation** with multiple metrics
|
| 252 |
|
| 253 |
### Links
|
| 254 |
|
| 255 |
- 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
|
| 256 |
-
- 🤗 [HuggingFace
|
| 257 |
|
| 258 |
-
|
| 259 |
|
| 260 |
-
**Oliver Perrin**
|
| 261 |
"""
|
| 262 |
)
|
| 263 |
|
|
|
|
| 4 |
Showcases the model's capabilities across three tasks:
|
| 5 |
- Summarization: Generates concise summaries of input text
|
| 6 |
- Emotion Detection: Multi-label emotion classification
|
| 7 |
+
- Topic Classification: Categorizes text into topics
|
| 8 |
|
| 9 |
Author: Oliver Perrin
|
| 10 |
+
Date: 2025-12-05
|
| 11 |
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
|
|
|
| 38 |
|
| 39 |
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
|
| 40 |
EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
|
| 41 |
+
TRAINING_HISTORY_PATH = OUTPUTS_DIR / "training_history.json"
|
| 42 |
|
| 43 |
SAMPLE_TEXTS = [
|
| 44 |
+
"Global markets tumbled today as investors reacted to rising inflation concerns. The Federal Reserve hinted at potential interest rate hikes, sending shockwaves through technology and banking sectors. Analysts predict continued volatility as economic uncertainty persists.",
|
| 45 |
+
"Scientists at MIT have developed a breakthrough quantum computing chip that operates at room temperature. This advancement could revolutionize drug discovery, cryptography, and artificial intelligence. The research team published their findings in Nature.",
|
| 46 |
+
"The championship game ended in dramatic fashion as the underdog team scored in the final seconds to secure victory. Fans rushed the field in celebration, marking the team's first title in 25 years.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
]
|
| 48 |
|
| 49 |
# --------------- Pipeline Management ---------------
|
|
|
|
| 82 |
def analyze(text: str) -> tuple[str, str, str]:
|
| 83 |
"""Run all three tasks and return formatted results."""
|
| 84 |
if not text or not text.strip():
|
| 85 |
+
return "Please enter text above to analyze.", "", ""
|
| 86 |
|
| 87 |
try:
|
| 88 |
pipe = get_pipeline()
|
| 89 |
|
| 90 |
# Run tasks
|
| 91 |
+
summary = pipe.summarize([text], max_length=128)[0].strip()
|
| 92 |
+
if not summary:
|
| 93 |
+
summary = "(Unable to generate summary)"
|
| 94 |
+
|
| 95 |
+
emotions = pipe.predict_emotions([text], threshold=0.3)[0] # Lower threshold
|
| 96 |
topic = pipe.predict_topics([text])[0]
|
| 97 |
|
| 98 |
+
# Format emotions with emoji
|
| 99 |
+
emotion_emoji = {
|
| 100 |
+
"joy": "😊",
|
| 101 |
+
"love": "❤️",
|
| 102 |
+
"anger": "😠",
|
| 103 |
+
"fear": "😨",
|
| 104 |
+
"sadness": "😢",
|
| 105 |
+
"surprise": "😲",
|
| 106 |
+
"neutral": "😐",
|
| 107 |
+
"admiration": "🤩",
|
| 108 |
+
"amusement": "😄",
|
| 109 |
+
"annoyance": "😤",
|
| 110 |
+
"approval": "👍",
|
| 111 |
+
"caring": "🤗",
|
| 112 |
+
"confusion": "😕",
|
| 113 |
+
"curiosity": "🤔",
|
| 114 |
+
"desire": "😍",
|
| 115 |
+
"disappointment": "😞",
|
| 116 |
+
"disapproval": "👎",
|
| 117 |
+
"disgust": "🤢",
|
| 118 |
+
"embarrassment": "😳",
|
| 119 |
+
"excitement": "🎉",
|
| 120 |
+
"gratitude": "🙏",
|
| 121 |
+
"grief": "😭",
|
| 122 |
+
"nervousness": "��",
|
| 123 |
+
"optimism": "🌟",
|
| 124 |
+
"pride": "🦁",
|
| 125 |
+
"realization": "💡",
|
| 126 |
+
"relief": "😌",
|
| 127 |
+
"remorse": "😔",
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
if emotions.labels:
|
| 131 |
+
emotion_parts = []
|
| 132 |
+
for lbl, score in zip(emotions.labels[:5], emotions.scores[:5], strict=False):
|
| 133 |
+
emoji = emotion_emoji.get(lbl.lower(), "•")
|
| 134 |
+
emotion_parts.append(f"{emoji} **{lbl.title()}** ({score:.0%})")
|
| 135 |
+
emotion_str = "\n".join(emotion_parts)
|
| 136 |
else:
|
| 137 |
+
emotion_str = "😐 No strong emotions detected"
|
| 138 |
|
| 139 |
# Format topic
|
| 140 |
+
topic_str = f"**{topic.label}**\n\nConfidence: {topic.confidence:.0%}"
|
| 141 |
|
| 142 |
return summary, emotion_str, topic_str

...

def load_metrics() -> str:
    """Load evaluation metrics and format as markdown."""
+    # Load evaluation report
+    eval_metrics = {}
+    if EVAL_REPORT_PATH.exists():
+        try:
+            with open(EVAL_REPORT_PATH) as f:
+                eval_metrics = json.load(f)
+        except Exception:
+            pass
+
+    # Load training history
+    train_metrics = {}
+    if TRAINING_HISTORY_PATH.exists():
+        try:
+            with open(TRAINING_HISTORY_PATH) as f:
+                train_metrics = json.load(f)
+        except Exception:
+            pass
+
+    # Get final validation metrics
+    val_final = train_metrics.get("val_epoch_3", {})
+
+    md = """
+## 📈 Model Performance
+
+### Training Results (3 Epochs)
+
+| Task | Metric | Final Score |
+|------|--------|-------------|
+| **Topic Classification** | Accuracy | **{topic_acc:.1%}** |
+| **Emotion Detection** | F1 (training) | {emo_f1:.1%} |
+| **Summarization** | ROUGE-like | {rouge:.1%} |
+
+### Evaluation Results
+
+| Metric | Value |
+|--------|-------|
+| Topic Accuracy | **{eval_topic:.1%}** |
+| Emotion F1 (macro) | {eval_emo:.1%} |
+| ROUGE-like | {eval_rouge:.1%} |
+| BLEU | {eval_bleu:.3f} |
+
+---
+
+### Topic Classification Details

+| Category | Precision | Recall | F1 |
+|----------|-----------|--------|-----|
+""".format(
+        topic_acc=val_final.get("topic_accuracy", 0),
+        emo_f1=val_final.get("emotion_f1", 0),
+        rouge=val_final.get("summarization_rouge_like", 0),
+        eval_topic=eval_metrics.get("topic", {}).get("accuracy", 0),
+        eval_emo=eval_metrics.get("emotion", {}).get("f1_macro", 0),
+        eval_rouge=eval_metrics.get("summarization", {}).get("rouge_like", 0),
+        eval_bleu=eval_metrics.get("summarization", {}).get("bleu", 0),
+    )

+    # Add per-class metrics
+    topic_report = eval_metrics.get("topic", {}).get("classification_report", {})
+    for cat, metrics in topic_report.items():
+        if cat in ["macro avg", "weighted avg", "micro avg"]:
+            continue
+        if isinstance(metrics, dict):
+            md += f"| {cat} | {metrics.get('precision', 0):.1%} | {metrics.get('recall', 0):.1%} | {metrics.get('f1-score', 0):.1%} |\n"

+    return md


+def get_viz_path(filename: str) -> str | None:
+    """Get visualization path if file exists."""
+    path = OUTPUTS_DIR / filename
+    return str(path) if path.exists() else None

# --------------- Gradio Interface ---------------

with gr.Blocks(
+    title="LexiMind - Multi-Task NLP",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        """
        # 🧠 LexiMind
        ### Multi-Task Transformer for Document Analysis

+        A custom encoder-decoder Transformer trained on **summarization**, **emotion detection** (28 classes),
+        and **topic classification** (10 categories). Built from scratch with PyTorch.
+
+        > ⚠️ **Note**: Summarization is experimental - the model works best on news-style articles.
        """
    )

    # --------------- Try It Tab ---------------
    with gr.Tab("🚀 Try It"):
        with gr.Row():
+            with gr.Column(scale=3):
                text_input = gr.Textbox(
+                    label="📝 Input Text",
+                    lines=6,
+                    placeholder="Enter or paste text to analyze (works best with news articles)...",
                    value=SAMPLE_TEXTS[0],
                )
+                analyze_btn = gr.Button(
+                    "🔍 Analyze",
+                    variant="primary",
+                    size="sm",
+                )
+
+                gr.Markdown("**Sample Texts** (click to use):")
                with gr.Row():
+                    sample1_btn = gr.Button("📰 Markets", size="sm", variant="secondary")
+                    sample2_btn = gr.Button("🔬 Science", size="sm", variant="secondary")
+                    sample3_btn = gr.Button("🏆 Sports", size="sm", variant="secondary")
+
+                sample1_btn.click(fn=lambda: SAMPLE_TEXTS[0], outputs=text_input)
+                sample2_btn.click(fn=lambda: SAMPLE_TEXTS[1], outputs=text_input)
+                sample3_btn.click(fn=lambda: SAMPLE_TEXTS[2], outputs=text_input)

            with gr.Column(scale=2):
+                gr.Markdown("### Results")
+                summary_out = gr.Textbox(
+                    label="📝 Summary",
+                    lines=3,
+                    interactive=False,
+                )
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown("**😊 Emotions**")
+                        emotion_out = gr.Markdown(value="*Run analysis*")
+                    with gr.Column():
+                        gr.Markdown("**📂 Topic**")
+                        topic_out = gr.Markdown(value="*Run analysis*")

        analyze_btn.click(
            fn=analyze,
            ...
        )

    # --------------- Metrics Tab ---------------
    with gr.Tab("📊 Metrics"):
+        with gr.Row():
+            with gr.Column(scale=2):
+                gr.Markdown(load_metrics())
+            with gr.Column(scale=1):
+                confusion_path = get_viz_path("topic_confusion_matrix.png")
+                if confusion_path:
+                    gr.Image(confusion_path, label="Confusion Matrix", show_label=True)
+
+    # --------------- Visualizations Tab ---------------
+    with gr.Tab("🎨 Visualizations"):
+        gr.Markdown("### Model Internals")
+
+        with gr.Row():
+            attn_path = get_viz_path("attention_visualization.png")
+            if attn_path:
+                gr.Image(attn_path, label="Self-Attention Pattern")
+
+            pos_path = get_viz_path("positional_encoding_heatmap.png")
+            if pos_path:
+                gr.Image(pos_path, label="Positional Encodings")
+
+        with gr.Row():
+            multi_path = get_viz_path("multihead_attention_visualization.png")
+            if multi_path:
+                gr.Image(multi_path, label="Multi-Head Attention")
+
+            single_path = get_viz_path("single_vs_multihead.png")
+            if single_path:
+                gr.Image(single_path, label="Single vs Multi-Head Comparison")

    # --------------- Architecture Tab ---------------
    with gr.Tab("🔧 Architecture"):
        ...
        """
        ### Model Architecture

+        | Component | Configuration |
+        |-----------|---------------|
+        | **Base** | Custom Transformer (encoder-decoder) |
+        | **Initialization** | FLAN-T5-base weights |
+        | **Encoder** | 6 layers, 768 hidden dim, 12 heads |
+        | **Decoder** | 6 layers with cross-attention |
+        | **Activation** | Gated-GELU |
+        | **Position** | Relative position bias |
+
+        ### Training Configuration
+
+        | Setting | Value |
+        |---------|-------|
+        | **Optimizer** | AdamW (lr=2e-5, wd=0.01) |
+        | **Scheduler** | Cosine with 1000 warmup steps |
+        | **Batch Size** | 14 × 3 accumulation = 42 effective |
+        | **Precision** | TF32 (Ampere GPU) |
+        | **Compilation** | torch.compile (inductor) |

+        ### Datasets

+        | Task | Dataset | Size |
+        |------|---------|------|
+        | **Summarization** | CNN/DailyMail + BookSum | ~110K |
+        | **Emotion** | GoEmotions | ~43K (28 labels) |
+        | **Topic** | Yahoo Answers | ~200K (10 classes) |
        """
        )

    # --------------- About Tab ---------------
    with gr.Tab("ℹ️ About"):
        ...
        """
        ### About LexiMind

+        LexiMind is a **portfolio project** demonstrating end-to-end machine learning engineering:
+
+        ✅ Custom Transformer implementation from scratch
+        ✅ Multi-task learning with shared encoder
+        ✅ Production-ready inference pipeline
+        ✅ Comprehensive evaluation and visualization
+        ✅ CI/CD with GitHub Actions
+
+        ### Known Limitations

+        - **Summarization** quality is limited (needs more training epochs)
+        - **Emotion detection** has low F1 due to class imbalance in GoEmotions
+        - Best results on **news-style text** (training domain)

        ### Links

        - 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
+        - 🤗 [Model on HuggingFace](https://huggingface.co/OliverPerrin/LexiMind-Model)

+        ---

+        **Built by Oliver Perrin** | December 2025
        """
        )

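The `analyze_btn.click(...)` call above elides its `inputs`/`outputs` keyword arguments in this view. For reference, a minimal sketch of how this wiring conventionally looks in Gradio; the exact argument list is an assumption, since it is not shown in the diff, but `analyze` returns `(summary, emotion_str, topic_str)` and the three output components are defined above:

```python
# Hypothetical wiring sketch - the real inputs/outputs kwargs are elided in the diff above.
analyze_btn.click(
    fn=analyze,                                     # returns (summary, emotion_str, topic_str)
    inputs=text_input,                              # the single Textbox component
    outputs=[summary_out, emotion_out, topic_out],  # one component per returned value
)
```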
scripts/download_data.py
CHANGED
@@ -85,6 +85,59 @@ TOPIC_LABELS = [
 # --------------- Utility Functions ---------------


+def _normalize_label(label: object, label_names: list[str]) -> str:
+    """Convert a label index or raw value into a string name.
+
+    - Valid integer indices are mapped to label_names.
+    - Everything else is stringified for robustness.
+    """
+
+    if isinstance(label, int) and 0 <= label < len(label_names):
+        return label_names[label]
+    return str(label)
+
+
+def _emotion_records(dataset_split: Any, label_names: list[str]) -> list[dict[str, object]]:
+    """Yield emotion records with resilient label handling."""
+
+    records: list[dict[str, object]] = []
+    for row in dataset_split:
+        text = str(getattr(row, "text", None) or row.get("text", ""))
+        raw_labels = getattr(row, "label", None) or row.get("label") or row.get("labels", [])
+
+        # Normalize to list
+        if isinstance(raw_labels, list):
+            label_values = raw_labels
+        elif raw_labels is None:
+            label_values = []
+        else:
+            label_values = [raw_labels]
+
+        emotions = [_normalize_label(lbl, label_names) for lbl in label_values]
+        if text:
+            records.append({"text": text, "emotions": emotions})
+    return records
+
+
+def _topic_records(dataset_split: Any, label_names: list[str]) -> list[dict[str, object]]:
+    """Yield topic records with resilient label handling."""
+
+    records: list[dict[str, object]] = []
+    for row in dataset_split:
+        text = str(getattr(row, "text", None) or row.get("text", ""))
+        raw_label = getattr(row, "label", None) or row.get("label") or row.get("topic")
+
+        if isinstance(raw_label, list):
+            label_value = raw_label[0] if raw_label else ""
+        else:
+            label_value = raw_label
+
+        topic = _normalize_label(label_value, label_names) if label_value is not None else ""
+        if text:
+            records.append({"text": text, "topic": topic})
+    return records
+
+
 def _write_jsonl(records: list[dict], destination: Path, desc: str = "Writing") -> None:
     """Write records to JSONL file with progress bar."""
     destination.parent.mkdir(parents=True, exist_ok=True)
scripts/process_books.py
ADDED
@@ -0,0 +1,231 @@
+"""
+Process book collection with LexiMind model.
+
+Analyzes each book to generate:
+- Overall topic classification
+- Dominant emotions
+- Concise summary
+
+Results are saved to data/processed/books/library.json for future use.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.inference.factory import create_inference_pipeline
+from src.utils.logging import configure_logging, get_logger
+
+configure_logging()
+logger = get_logger(__name__)
+
+# --------------- Configuration ---------------
+
+BOOKS_DIR = PROJECT_ROOT / "data" / "raw" / "books"
+OUTPUT_PATH = PROJECT_ROOT / "data" / "processed" / "books" / "library.json"
+
+# Chunk books into manageable sections for analysis
+MAX_CHUNK_LENGTH = 1000  # characters per chunk
+MAX_CHUNKS = 5  # analyze first N chunks to get representative sample
+
+
+# --------------- Book Processing ---------------
+
+
+def clean_text(text: str) -> str:
+    """Clean and normalize book text."""
+    # Remove Project Gutenberg headers/footers (common patterns)
+    lines = text.split("\n")
+    start_idx = 0
+    end_idx = len(lines)
+
+    for i, line in enumerate(lines):
+        if "START OF" in line.upper() and "PROJECT GUTENBERG" in line.upper():
+            start_idx = i + 1
+            break
+
+    for i in range(len(lines) - 1, -1, -1):
+        if "END OF" in lines[i].upper() and "PROJECT GUTENBERG" in lines[i].upper():
+            end_idx = i
+            break
+
+    text = "\n".join(lines[start_idx:end_idx])
+
+    # Basic cleanup
+    text = text.strip()
+    text = " ".join(text.split())  # normalize whitespace
+
+    return text
+
+
+def chunk_text(text: str, chunk_size: int = MAX_CHUNK_LENGTH) -> list[str]:
+    """Split text into chunks for analysis."""
+    words = text.split()
+    chunks = []
+    current_chunk = []
+    current_length = 0
+
+    for word in words:
+        current_chunk.append(word)
+        current_length += len(word) + 1  # +1 for space
+
+        if current_length >= chunk_size:
+            chunks.append(" ".join(current_chunk))
+            current_chunk = []
+            current_length = 0
+
+    if current_chunk:
+        chunks.append(" ".join(current_chunk))
+
+    return chunks
+
+
+def process_book(book_path: Path, pipeline) -> dict:
+    """Analyze a single book and return metadata."""
+    logger.info(f"Processing {book_path.name}...")
+
+    # Read and clean
+    try:
+        text = book_path.read_text(encoding="utf-8", errors="ignore")
+    except Exception as exc:
+        logger.error(f"Failed to read {book_path.name}: {exc}")
+        return {}
+
+    text = clean_text(text)
+
+    if not text or len(text) < 100:
+        logger.warning(f"Skipping {book_path.name} - insufficient content")
+        return {}
+
+    # Chunk and sample
+    chunks = chunk_text(text)
+    sample_chunks = chunks[: min(MAX_CHUNKS, len(chunks))]
+
+    logger.info(f"  Analyzing {len(sample_chunks)} chunks (of {len(chunks)} total)...")
+
+    # Run inference on chunks
+    try:
+        topics = pipeline.predict_topics(sample_chunks)
+        emotions = pipeline.predict_emotions(sample_chunks, threshold=0.3)
+        summaries = pipeline.summarize(sample_chunks, max_length=64)
+
+        # Aggregate results
+        # Topic: most common prediction
+        topic_counts: dict[str, int] = {}
+        for t in topics:
+            topic_counts[t.label] = topic_counts.get(t.label, 0) + 1
+        dominant_topic = max(topic_counts.items(), key=lambda x: x[1])[0]
+
+        # Emotion: aggregate top emotions
+        all_emotions: dict[str, list[float]] = {}
+        for emotion in emotions:
+            for label, score in zip(emotion.labels, emotion.scores, strict=False):
+                if label not in all_emotions:
+                    all_emotions[label] = []
+                all_emotions[label].append(score)
+
+        # Average scores and take top 3
+        emotion_scores = {
+            label: sum(scores) / len(scores) for label, scores in all_emotions.items()
+        }
+        top_emotions = sorted(emotion_scores.items(), key=lambda x: x[1], reverse=True)[:3]
+
+        # Summary: combine first few chunk summaries
+        combined_summary = " ".join(summaries[:3])
+
+        result: dict[str, object] = {
+            "title": book_path.stem.replace("_", " ").title(),
+            "filename": book_path.name,
+            "topic": dominant_topic,
+            "emotions": [{"label": label, "score": float(score)} for label, score in top_emotions],
+            "summary": combined_summary,
+            "word_count": len(text.split()),
+            "chunks_analyzed": len(sample_chunks),
+        }
+
+        logger.info(
+            f"  ✓ {result['title']}: {result['topic']} | "
+            f"{', '.join(str(e['label']) for e in result['emotions'][:2] if isinstance(e, dict))}"  # type: ignore[index]
+        )
+
+        return result
+
+    except Exception as exc:
+        logger.error(f"Analysis failed for {book_path.name}: {exc}", exc_info=True)
+        return {}
+
+
+# --------------- Main ---------------
+
+
+def main():
+    """Process all books and save library."""
+    logger.info("Loading inference pipeline...")
+
+    pipeline, label_metadata = create_inference_pipeline(
+        tokenizer_dir="artifacts/hf_tokenizer/",
+        checkpoint_path="checkpoints/best.pt",
+        labels_path="artifacts/labels.json",
+    )
+
+    logger.info("Finding books...")
+    book_files = sorted(BOOKS_DIR.glob("*.txt"))
+
+    if not book_files:
+        logger.error(f"No books found in {BOOKS_DIR}")
+        return
+
+    logger.info(f"Found {len(book_files)} books")
+
+    # Process each book
+    library = []
+    for book_path in book_files:
+        result = process_book(book_path, pipeline)
+        if result:
+            library.append(result)
+
+    # Save results
+    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
+    with open(OUTPUT_PATH, "w") as f:
+        json.dump(
+            {
+                "books": library,
+                "metadata": {
+                    "total_books": len(library),
+                    "chunk_size": MAX_CHUNK_LENGTH,
+                    "chunks_per_book": MAX_CHUNKS,
+                },
+            },
+            f,
+            indent=2,
+        )
+
+    logger.info(f"\n✓ Library saved to {OUTPUT_PATH}")
+    logger.info(f"  Processed {len(library)} books")
+
+    # Print summary
+    print("\n" + "=" * 60)
+    print("BOOK LIBRARY SUMMARY")
+    print("=" * 60)
+
+    for book in library:
+        print(f"\n📚 {book['title']}")
+        print(f"  Topic: {book['topic']}")
+        emotions_str = ", ".join(f"{e['label']} ({e['score']:.0%})" for e in book["emotions"])
+        print(f"  Emotions: {emotions_str}")
+        print(f"  Summary: {book['summary'][:100]}...")
+
+    print("\n" + "=" * 60)
+
+
+if __name__ == "__main__":
+    main()
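Note that `chunk_text` budgets by characters over whitespace-split words, so chunk boundaries never split a word. Assuming the artifacts referenced in `main()` (`checkpoints/best.pt`, `artifacts/hf_tokenizer/`, `artifacts/labels.json`) exist locally, the script runs like the other project entry points:

```bash
poetry run python scripts/process_books.py
```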
scripts/train.py
CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
 import json
 import logging
 import os
+import re
 import sys
 import time
 import warnings

@@ -51,7 +52,7 @@ from src.data.tokenization import Tokenizer, TokenizerConfig
 from src.models.factory import ModelConfig, build_multitask_model
 from src.training.trainer import Trainer, TrainerConfig
 from src.training.utils import set_seed
-from src.utils.io import save_state
+from src.utils.io import load_state, save_state
 from src.utils.labels import LabelMetadata, save_label_metadata

 # --------------- Data Loading ---------------

@@ -93,12 +94,13 @@ def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:


 def compile_model(model: torch.nn.Module) -> torch.nn.Module:
-    """Compile model with inductor backend (...
+    """Compile model with inductor backend (optimized for speed)."""
+    print(f" -> Enabling torch.compile for {model.__class__.__name__}...")
     from src.training.safe_compile import apply_safe_config, compile_model_safe

     # Apply safe configuration first
     apply_safe_config()
-    # Compile with default mode (inductor...
+    # Compile with default mode (inductor) - most stable
     return compile_model_safe(model, mode="default")


@@ -148,10 +150,12 @@ def main(cfg: DictConfig) -> None:
     # --------------- Tokenizer & Datasets ---------------

     tok_cfg = data_cfg.get("tokenizer", {})
+    # Allow training overrides for max_length to run shorter dev sweeps
+    override_max_len = cfg.training.get("tokenizer_max_length")
     tokenizer = Tokenizer(
         TokenizerConfig(
             pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
-            max_length=int(tok_cfg.get("max_length", 512)),
+            max_length=int(override_max_len or tok_cfg.get("max_length", 512)),
             lower=bool(tok_cfg.get("lower", False)),
         )
     )

@@ -238,6 +242,7 @@ def main(cfg: DictConfig) -> None:
     device = torch.device(cfg.device)
     model_cfg = ModelConfig(
         d_model=cfg.model.d_model,
+        vocab_size=getattr(cfg.model, "vocab_size", None),  # Override tokenizer vocab if specified
         num_encoder_layers=cfg.model.num_encoder_layers,
         num_decoder_layers=cfg.model.num_decoder_layers,
         num_attention_heads=cfg.model.num_attention_heads,

@@ -255,12 +260,41 @@ def main(cfg: DictConfig) -> None:
         config=model_cfg,
     ).to(device)

+    # If training crashes: resume from checkpoint if provided (load before compile to avoid key mismatches)
+    start_epoch = 1
+    resume_path = cfg.get("resume_from")
+    if resume_path:
+        ckpt_path = Path(resume_path)
+        if ckpt_path.exists():
+            print(f"\n↩ Resuming from checkpoint: {ckpt_path}")
+            load_state(model, str(ckpt_path))
+            # Parse epoch number robustly from filename (e.g., epoch_5.pt)
+            epoch_num = None
+            try:
+                # Prefer stem (no suffix); fallback to any digit sequence in name
+                digits = re.findall(r"\d+", ckpt_path.stem)
+                if digits:
+                    epoch_num = int(digits[-1])
+            except Exception:
+                epoch_num = None
+
+            if epoch_num is not None:
+                start_epoch = epoch_num + 1
+                print(f" -> Starting from epoch {start_epoch}")
+            else:
+                print(" -> Could not parse epoch number; starting from epoch 1")
+                start_epoch = 1
+        else:
+            print(f"⚠ Resume checkpoint not found: {ckpt_path}. Starting from scratch.")
+
     # Compile encoder/decoder for faster training (skip heads - small overhead)
-    ...
+    compile_encoder = bool(cfg.training.get("compile_encoder", True))
+    compile_decoder = bool(cfg.training.get("compile_decoder", True))
+    if compile_encoder and model.encoder is not None:
         from src.models.encoder import TransformerEncoder

         model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
-    if model.decoder is not None:
+    if compile_decoder and model.decoder is not None:
         from src.models.decoder import TransformerDecoder

         model.decoder = cast(TransformerDecoder, compile_model(model.decoder))

@@ -268,21 +302,30 @@ def main(cfg: DictConfig) -> None:
     # --------------- Optimizer & Trainer ---------------

     opt_cfg = cfg.training.get("optimizer", {})
+    sched_cfg = cfg.training.get("scheduler", {})
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=float(opt_cfg.get("lr", 3e-5)),
        weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
     )

+    # Clamp start_epoch to max_epochs to avoid an empty training loop
+    max_epochs = int(trainer_cfg.get("max_epochs", 1))
+    if start_epoch > max_epochs:
+        print(f"⚠ resume_from points past max_epochs ({max_epochs}); nothing to train. Setting start_epoch to {max_epochs}")
+        start_epoch = max_epochs
+
     trainer = Trainer(
         model=model,
         optimizer=optimizer,
         config=TrainerConfig(
-            max_epochs=...
+            max_epochs=max_epochs,
             gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
             task_weights=trainer_cfg.get("task_weights"),
             label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
             gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
+            scheduler_type=str(sched_cfg.get("name", "constant")),
+            warmup_steps=int(sched_cfg.get("warmup_steps", 0)),
         ),
         device=device,
         tokenizer=tokenizer,

@@ -298,7 +341,12 @@ def main(cfg: DictConfig) -> None:
         save_state(model, str(path))

     print("\nStarting training...")
-    history = trainer.fit(...
+    history = trainer.fit(
+        train_loaders,
+        val_loaders,
+        checkpoint_callback=save_checkpoint,
+        start_epoch=start_epoch,
+    )

     # --------------- Save Outputs ---------------

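Because `resume_from` is read with `cfg.get(...)`, it can be supplied as an ordinary Hydra override, consistent with the other training overrides. A sketch (the checkpoint filename is illustrative; the epoch number is parsed from it):

```bash
# Resume after a crash; the epoch is parsed from the filename (epoch_2.pt -> start at epoch 3)
poetry run python scripts/train.py resume_from=checkpoints/epoch_2.pt
```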
scripts/visualize_training.py
ADDED
@@ -0,0 +1,341 @@
+"""
+Visualize training metrics from MLflow runs.
+
+Generates plots showing:
+- Loss curves (training/validation)
+- Task-specific metrics over time
+- Learning rate schedule
+- Training speed analysis
+
+Author: Oliver Perrin
+Date: December 2025
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import mlflow
+import mlflow.tracking
+import seaborn as sns
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.utils.logging import configure_logging, get_logger
+
+configure_logging()
+logger = get_logger(__name__)
+
+# Configure plotting style
+sns.set_style("whitegrid")
+plt.rcParams["figure.figsize"] = (12, 8)
+plt.rcParams["figure.dpi"] = 100
+
+OUTPUTS_DIR = PROJECT_ROOT / "outputs"
+MLRUNS_DIR = PROJECT_ROOT / "mlruns"
+
+
+def load_training_history() -> dict[str, object] | None:
+    """Load training history from JSON if available."""
+    history_path = OUTPUTS_DIR / "training_history.json"
+    if history_path.exists():
+        with open(history_path) as f:
+            data: dict[str, object] = json.load(f)
+        return data
+    return None
+
+
+def get_latest_run():
+    """Get the most recent MLflow run."""
+    mlflow.set_tracking_uri(f"file://{MLRUNS_DIR}")
+    client = mlflow.tracking.MlflowClient()
+
+    # Get the experiment (LexiMind)
+    experiment = client.get_experiment_by_name("LexiMind")
+    if not experiment:
+        logger.error("No 'LexiMind' experiment found")
+        return None
+
+    # Get all runs, sorted by start time
+    runs = client.search_runs(
+        experiment_ids=[experiment.experiment_id],
+        order_by=["start_time DESC"],
+        max_results=1,
+    )
+
+    if not runs:
+        logger.error("No runs found in experiment")
+        return None
+
+    return runs[0]
+
+
+def plot_loss_curves(run):
+    """Plot training and validation loss over time."""
+    client = mlflow.tracking.MlflowClient()
+
+    # Get metrics
+    train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
+    val_loss = client.get_metric_history(run.info.run_id, "val_total_loss")
+
+    fig, ax = plt.subplots(figsize=(12, 6))
+
+    if not train_loss:
+        # Create placeholder plot
+        ax.text(
+            0.5,
+            0.5,
+            "No training data yet\n\nWaiting for first epoch to complete...",
+            ha="center",
+            va="center",
+            fontsize=14,
+            color="gray",
+        )
+        ax.set_xlim(0, 1)
+        ax.set_ylim(0, 1)
+    else:
+        # Extract steps and values
+        train_steps = [m.step for m in train_loss]
+        train_values = [m.value for m in train_loss]
+
+        ax.plot(train_steps, train_values, label="Training Loss", linewidth=2, alpha=0.8)
+
+        if val_loss:
+            val_steps = [m.step for m in val_loss]
+            val_values = [m.value for m in val_loss]
+            ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2, alpha=0.8)
+
+        ax.legend(fontsize=11)
+
+    ax.set_xlabel("Epoch", fontsize=12)
+    ax.set_ylabel("Loss", fontsize=12)
+    ax.set_title("Training Progress: Total Loss", fontsize=14, fontweight="bold")
+    ax.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+    output_path = OUTPUTS_DIR / "training_loss_curve.png"
+    plt.savefig(output_path, dpi=150, bbox_inches="tight")
+    logger.info(f"✓ Saved loss curve to {output_path}")
+    plt.close()
+
+
+def plot_task_metrics(run):
+    """Plot metrics for each task."""
+    client = mlflow.tracking.MlflowClient()
+
+    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+    fig.suptitle("Task-Specific Training Metrics", fontsize=16, fontweight="bold")
+
+    # Summarization
+    ax = axes[0, 0]
+    train_sum = client.get_metric_history(run.info.run_id, "train_summarization_loss")
+    val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
+
+    if train_sum:
+        ax.plot(
+            [m.step for m in train_sum], [m.value for m in train_sum], label="Train", linewidth=2
+        )
+    if val_sum:
+        ax.plot([m.step for m in val_sum], [m.value for m in val_sum], label="Val", linewidth=2)
+    ax.set_title("Summarization Loss", fontweight="bold")
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel("Loss")
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+
+    # Emotion
+    ax = axes[0, 1]
+    train_emo = client.get_metric_history(run.info.run_id, "train_emotion_loss")
+    val_emo = client.get_metric_history(run.info.run_id, "val_emotion_loss")
+    train_f1 = client.get_metric_history(run.info.run_id, "train_emotion_f1")
+    val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
+
+    if train_emo:
+        ax.plot(
+            [m.step for m in train_emo],
+            [m.value for m in train_emo],
+            label="Train Loss",
+            linewidth=2,
+        )
+    if val_emo:
+        ax.plot(
+            [m.step for m in val_emo], [m.value for m in val_emo], label="Val Loss", linewidth=2
+        )
+
+    ax2 = ax.twinx()
+    if train_f1:
+        ax2.plot(
+            [m.step for m in train_f1],
+            [m.value for m in train_f1],
+            label="Train F1",
+            linewidth=2,
+            linestyle="--",
+            alpha=0.7,
+        )
+    if val_f1:
+        ax2.plot(
+            [m.step for m in val_f1],
+            [m.value for m in val_f1],
+            label="Val F1",
+            linewidth=2,
+            linestyle="--",
+            alpha=0.7,
+        )
+
+    ax.set_title("Emotion Detection", fontweight="bold")
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel("Loss")
+    ax2.set_ylabel("F1 Score")
+    ax.legend(loc="upper left")
+    ax2.legend(loc="upper right")
+    ax.grid(True, alpha=0.3)
+
+    # Topic
+    ax = axes[1, 0]
+    train_topic = client.get_metric_history(run.info.run_id, "train_topic_loss")
+    val_topic = client.get_metric_history(run.info.run_id, "val_topic_loss")
+    train_acc = client.get_metric_history(run.info.run_id, "train_topic_accuracy")
+    val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
+
+    if train_topic:
+        ax.plot(
+            [m.step for m in train_topic],
+            [m.value for m in train_topic],
+            label="Train Loss",
+            linewidth=2,
+        )
+    if val_topic:
+        ax.plot(
+            [m.step for m in val_topic], [m.value for m in val_topic], label="Val Loss", linewidth=2
+        )
+
+    ax2 = ax.twinx()
+    if train_acc:
+        ax2.plot(
+            [m.step for m in train_acc],
+            [m.value for m in train_acc],
+            label="Train Acc",
+            linewidth=2,
+            linestyle="--",
+            alpha=0.7,
+        )
+    if val_acc:
+        ax2.plot(
+            [m.step for m in val_acc],
+            [m.value for m in val_acc],
+            label="Val Acc",
+            linewidth=2,
+            linestyle="--",
+            alpha=0.7,
+        )
+
+    ax.set_title("Topic Classification", fontweight="bold")
+    ax.set_xlabel("Epoch")
+    ax.set_ylabel("Loss")
+    ax2.set_ylabel("Accuracy")
+    ax.legend(loc="upper left")
+    ax2.legend(loc="upper right")
+    ax.grid(True, alpha=0.3)
+
+    # Summary statistics
+    ax = axes[1, 1]
+    ax.axis("off")
+
+    # Get final metrics
+    summary_text = "Final Metrics (Last Epoch)\n" + "=" * 35 + "\n\n"
+
+    if val_topic and val_acc:
+        summary_text += f"Topic Accuracy: {val_acc[-1].value:.1%}\n"
+    if val_emo and val_f1:
+        summary_text += f"Emotion F1: {val_f1[-1].value:.1%}\n"
+    if val_sum:
+        summary_text += f"Summarization Loss: {val_sum[-1].value:.3f}\n"
+
+    ax.text(0.1, 0.5, summary_text, fontsize=12, family="monospace", verticalalignment="center")
+
+    plt.tight_layout()
+    output_path = OUTPUTS_DIR / "task_metrics.png"
+    plt.savefig(output_path, dpi=150, bbox_inches="tight")
+    logger.info(f"✓ Saved task metrics to {output_path}")
+    plt.close()
+
+
+def plot_learning_rate(run):
+    """Plot learning rate schedule if available."""
+    client = mlflow.tracking.MlflowClient()
+    lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
+
+    fig, ax = plt.subplots(figsize=(12, 5))
+
+    if not lr_metrics:
+        # Create placeholder
+        ax.text(
+            0.5,
+            0.5,
+            "No learning rate data yet\n\n(Will be logged in future training runs)",
+            ha="center",
+            va="center",
+            fontsize=14,
+            color="gray",
+        )
+        ax.set_xlim(0, 1)
+        ax.set_ylim(0, 1)
+    else:
+        steps = [m.step for m in lr_metrics]
+        values = [m.value for m in lr_metrics]
+
+        ax.plot(steps, values, linewidth=2, color="darkblue")
+
+        # Mark warmup region
+        warmup_steps = 1000  # From config
+        if warmup_steps < max(steps):
+            ax.axvline(warmup_steps, color="red", linestyle="--", alpha=0.5, label="Warmup End")
+            ax.legend()
+
+    ax.set_xlabel("Step", fontsize=12)
+    ax.set_ylabel("Learning Rate", fontsize=12)
+    ax.set_title("Learning Rate Schedule (Cosine with Warmup)", fontsize=14, fontweight="bold")
+    ax.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+    output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
+    plt.savefig(output_path, dpi=150, bbox_inches="tight")
+    logger.info(f"✓ Saved LR schedule to {output_path}")
+    plt.close()
+
+
+def main():
+    """Generate all training visualizations."""
+    logger.info("Loading MLflow data...")
+
+    run = get_latest_run()
+    if not run:
+        logger.error("No training run found. Make sure training has started.")
+        return
+
+    logger.info(f"Analyzing run: {run.info.run_id}")
+
+    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
+
+    logger.info("Generating visualizations...")
+
+    plot_loss_curves(run)
+    plot_task_metrics(run)
+    plot_learning_rate(run)
+
+    logger.info("\n" + "=" * 60)
+    logger.info("✓ All visualizations saved to outputs/")
+    logger.info("=" * 60)
+    logger.info("  - training_loss_curve.png")
+    logger.info("  - task_metrics.png")
+    logger.info("  - learning_rate_schedule.png")
+    logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
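The script reads the local `mlruns/` tracking store (the same one `mlflow ui` serves) and writes three PNGs to `outputs/`; it can be run any time after the first epoch has been logged:

```bash
poetry run python scripts/visualize_training.py
```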
src/data/dataloader.py
CHANGED
@@ -48,13 +48,16 @@ class SummarizationCollator:
         src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
         tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)

-        # Shift targets: tgt_ids = [BOS, A, B], labels = [A, B, EOS]
         ids = tgt_enc["input_ids"]
         mask = tgt_enc["attention_mask"]

-        labels = ids
-        labels[mask...
+        # Create labels for loss: mask padding with -100
+        labels = ids.clone()
+        labels[mask == 0] = -100
+
+        # Create decoder inputs from original ids (no -100)
+        # prepare_decoder_inputs shifts right and adds BOS
+        tgt_ids = self.tokenizer.prepare_decoder_inputs(ids)

         return {
             "src_ids": src_enc["input_ids"],
src/inference/pipeline.py
CHANGED
@@ -69,6 +69,7 @@ class InferenceConfig:

     summary_max_length: int = 128
     summary_repetition_penalty: float = 1.2  # Penalize repeated tokens
+    summary_formatting: bool = True  # Apply text cleanup/formatting to generated summaries
     emotion_threshold: float = 0.5
     device: str | None = None

@@ -164,6 +165,8 @@ class InferencePipeline:

         # Decode and format summaries
         raw_summaries = self.tokenizer.decode_batch(generated.tolist())
+        if not self.config.summary_formatting:
+            return raw_summaries
         return [_format_summary(s) for s in raw_summaries]

     # --------------- Emotion ---------------
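With the new flag, callers that want the raw decoder output (for example, scoring ROUGE against unformatted references) can opt out of the cleanup step. A sketch, assuming `InferenceConfig` can be constructed directly with keyword overrides as its field defaults suggest:

```python
from src.inference.pipeline import InferenceConfig

# Hypothetical construction - other fields keep their defaults
config = InferenceConfig(summary_max_length=64, summary_formatting=False)
```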
src/inference/postprocessing.py
DELETED
@@ -1,14 +0,0 @@
-"""
-Output postprocessing utilities for LexiMind.
-
-Provides text cleaning helpers for model outputs.
-
-Author: Oliver Perrin
-Date: December 2025
-"""
-
-from typing import List
-
-
-def strip_whitespace(texts: List[str]) -> List[str]:
-    return [text.strip() for text in texts]
src/models/decoder.py
CHANGED
|
@@ -18,10 +18,12 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast
|
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torch.nn as nn
|
|
|
|
| 21 |
|
| 22 |
from .attention import MultiHeadAttention, T5RelativePositionBias
|
| 23 |
from .feedforward import FeedForward
|
| 24 |
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
|
|
@@ -77,9 +79,9 @@ class TransformerDecoderLayer(nn.Module):
|
|
| 77 |
quantization=quantization,
|
| 78 |
)
|
| 79 |
|
| 80 |
-
self.norm1 =
|
| 81 |
-
self.norm2 =
|
| 82 |
-
self.norm3 =
|
| 83 |
|
| 84 |
self.dropout1 = nn.Dropout(dropout)
|
| 85 |
self.dropout2 = nn.Dropout(dropout)
|
|
@@ -189,6 +191,7 @@ class TransformerDecoder(nn.Module):
|
|
| 189 |
use_learned_pos_enc: bool = False,
|
| 190 |
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
|
| 191 |
use_relative_position_bias: bool = False, # T5-style relative position bias
|
|
|
|
| 192 |
):
|
| 193 |
super().__init__()
|
| 194 |
self.vocab_size = vocab_size
|
|
@@ -196,8 +199,10 @@ class TransformerDecoder(nn.Module):
|
|
| 196 |
self.pad_token_id = pad_token_id
|
| 197 |
self.num_heads = num_heads
|
| 198 |
self.use_relative_position_bias = use_relative_position_bias
|
|
|
|
| 199 |
|
| 200 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
|
|
|
|
| 201 |
|
| 202 |
# Positional encoding (disabled when using relative position bias for T5)
|
| 203 |
self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
|
|
@@ -238,8 +243,8 @@ class TransformerDecoder(nn.Module):
|
|
| 238 |
]
|
| 239 |
)
|
| 240 |
|
| 241 |
-
self.final_norm =
|
| 242 |
-
self.output_projection = nn.Linear(d_model, vocab_size)
|
| 243 |
self.input_dropout = nn.Dropout(dropout)
|
| 244 |
|
| 245 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
@@ -252,6 +257,18 @@ class TransformerDecoder(nn.Module):
|
|
| 252 |
"""
|
| 253 |
assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
|
| 254 |
pad_mask = input_ids != self.pad_token_id # (B, T)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, T, T)
|
| 256 |
return attn_mask
|
| 257 |
|
|
@@ -263,7 +280,7 @@ class TransformerDecoder(nn.Module):
|
|
| 263 |
memory_mask: Optional[torch.Tensor] = None,
|
| 264 |
collect_attn: bool = False,
|
| 265 |
skip_padding_mask: bool = False, # Set True during generation to avoid masking start token
|
| 266 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
|
| 267 |
"""
|
| 268 |
Args:
|
| 269 |
inputs: (B, T) token ids or (B, T, d_model) embeddings
|
|
@@ -304,6 +321,12 @@ class TransformerDecoder(nn.Module):
|
|
| 304 |
else:
|
| 305 |
# Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
|
| 306 |
tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
# Normalize memory_mask dtype/device and expand simple shapes
|
| 309 |
if memory_mask is not None:
|
|
@@ -313,7 +336,7 @@ class TransformerDecoder(nn.Module):
|
|
| 313 |
elif memory_mask.dim() == 3: # (B, T, S) -> (B, 1, T, S)
|
| 314 |
memory_mask = memory_mask.unsqueeze(1)
|
| 315 |
|
| 316 |
-
attn_list: List[Dict[str, torch.Tensor]] = []
|
| 317 |
|
| 318 |
# Compute relative position biases (T5-style)
|
| 319 |
# Note: T5 uses relative position bias for self-attention but NOT for cross-attention
|
|
@@ -328,19 +351,37 @@ class TransformerDecoder(nn.Module):
|
|
| 328 |
|
| 329 |
# Pass through decoder layers
|
| 330 |
for layer in self.layers:
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
if collect_attn:
|
| 341 |
attn_list.append(attn)
|
| 342 |
|
| 343 |
x = self.final_norm(x)
|
|
|
|
| 344 |
logits = self.output_projection(x) # (B, T, vocab)
|
| 345 |
|
| 346 |
if collect_attn:
|
|
|
|
| 18 |
|
| 19 |
import torch
|
| 20 |
import torch.nn as nn
|
| 21 |
+
from torch.utils.checkpoint import checkpoint
|
| 22 |
|
| 23 |
from .attention import MultiHeadAttention, T5RelativePositionBias
|
| 24 |
from .feedforward import FeedForward
|
| 25 |
from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
|
| 26 |
+
from .t5_layer_norm import T5LayerNorm
|
| 27 |
|
| 28 |
|
| 29 |
def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
|
|
|
|
| 79 |
quantization=quantization,
|
| 80 |
)
|
| 81 |
|
| 82 |
+
self.norm1 = T5LayerNorm(d_model)
|
| 83 |
+
self.norm2 = T5LayerNorm(d_model)
|
| 84 |
+
self.norm3 = T5LayerNorm(d_model)
|
| 85 |
|
| 86 |
self.dropout1 = nn.Dropout(dropout)
|
| 87 |
self.dropout2 = nn.Dropout(dropout)
|
|
|
|
| 191 |
use_learned_pos_enc: bool = False,
|
| 192 |
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
|
| 193 |
use_relative_position_bias: bool = False, # T5-style relative position bias
|
| 194 |
+
gradient_checkpointing: bool = False,
|
| 195 |
):
|
| 196 |
super().__init__()
|
| 197 |
self.vocab_size = vocab_size
|
|
|
|
| 199 |
self.pad_token_id = pad_token_id
|
| 200 |
self.num_heads = num_heads
|
| 201 |
self.use_relative_position_bias = use_relative_position_bias
|
| 202 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 203 |
|
| 204 |
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
|
| 205 |
+
# Note: T5 does NOT scale logits (scaling factor removed)
|
| 206 |
|
| 207 |
# Positional encoding (disabled when using relative position bias for T5)
|
| 208 |
self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
|
|
|
|
| 243 |
]
|
| 244 |
)
|
| 245 |
|
| 246 |
+
self.final_norm = T5LayerNorm(d_model)
|
| 247 |
+
self.output_projection = nn.Linear(d_model, vocab_size, bias=False) # T5 has no bias
|
| 248 |
self.input_dropout = nn.Dropout(dropout)
|
| 249 |
|
| 250 |
def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
|
|
        """
        assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
        pad_mask = input_ids != self.pad_token_id  # (B, T)
+
+        # Always allow attending to the first token (BOS), even if it is pad_token_id
+        # Avoid in-place mutation for better torch.compile compatibility
+        if pad_mask.size(1) > 0:
+            # Create a mask for the first column (B, 1)
+            first_col_mask = torch.zeros_like(pad_mask[:, :1], dtype=torch.bool)
+            first_col_mask[:] = True
+            # Combine: pad_mask OR (column == 0)
+            # We can do this by creating a column index tensor
+            col_indices = torch.arange(pad_mask.size(1), device=pad_mask.device).unsqueeze(0)
+            pad_mask = pad_mask | (col_indices == 0)
+
        attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (B, T, T)
        return attn_mask

...

        memory_mask: Optional[torch.Tensor] = None,
        collect_attn: bool = False,
        skip_padding_mask: bool = False,  # Set True during generation to avoid masking start token
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, Optional[torch.Tensor]]]]]:
        """
        Args:
            inputs: (B, T) token ids or (B, T, d_model) embeddings

...

        else:
            # Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
            tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
+            # If tgt_mask is just causal (T, T), expand it
+            if tgt_mask.dim() == 2:
+                tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
+            elif tgt_mask.dim() == 3:
+                tgt_mask = tgt_mask.unsqueeze(1)
+

        # Normalize memory_mask dtype/device and expand simple shapes
        if memory_mask is not None:

...

            elif memory_mask.dim() == 3:  # (B, T, S) -> (B, 1, T, S)
                memory_mask = memory_mask.unsqueeze(1)

+        attn_list: List[Dict[str, Optional[torch.Tensor]]] = []

        # Compute relative position biases (T5-style)
        # Note: T5 uses relative position bias for self-attention but NOT for cross-attention

...

        # Pass through decoder layers
        for layer in self.layers:
+            if self.gradient_checkpointing and self.training:
+                # Gradient checkpointing requires the inputs to require grad
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, tgt_mask=tgt_mask, memory_mask=memory_mask, collect_attn=collect_attn, self_attn_position_bias=self_position_bias, cross_attn_position_bias=cross_position_bias)
+                    return custom_forward
+
+                x, attn = cast(
+                    Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]],
+                    checkpoint(
+                        create_custom_forward(layer),
+                        x,
+                        memory,
+                        use_reentrant=False,
+                    ),
+                )
+            else:
+                x, attn = layer(
+                    x,
+                    memory,
+                    tgt_mask=tgt_mask,
+                    memory_mask=memory_mask,
+                    collect_attn=collect_attn,
+                    self_attn_position_bias=self_position_bias,
+                    cross_attn_position_bias=cross_position_bias,
+                )
            if collect_attn:
                attn_list.append(attn)

        x = self.final_norm(x)
+        # T5 does NOT scale logits - direct projection to vocabulary
        logits = self.output_projection(x)  # (B, T, vocab)

        if collect_attn:
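A note on the padding-mask change above: T5 uses the pad token as the decoder start token, so a naive `input_ids != pad_token_id` mask would hide position 0. A minimal standalone sketch of the same trick (the token ids here are made up, and this is not the model's own code path):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 42, 7, 0, 0]])  # pad used as BOS, then real padding

pad_mask = input_ids != pad_token_id             # (B, T) -> position 0 is False
col_indices = torch.arange(input_ids.size(1)).unsqueeze(0)
pad_mask = pad_mask | (col_indices == 0)         # re-enable column 0 without in-place ops

attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (B, T, T)
print(pad_mask)         # tensor([[ True,  True,  True, False, False]])
print(attn_mask.shape)  # torch.Size([1, 5, 5])
```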
src/models/encoder.py
CHANGED
@@ -13,15 +13,17 @@ Author: Oliver Perrin
 Date: 2025-10-23
 """

-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union, cast

 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint

 # Encoder implementation
 from .attention import MultiHeadAttention, T5RelativePositionBias
 from .feedforward import FeedForward
 from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
+from .t5_layer_norm import T5LayerNorm


 class TransformerEncoderLayer(nn.Module):

@@ -65,8 +67,8 @@ class TransformerEncoderLayer(nn.Module):
             quantization=quantization,
         )

-        self.norm1 = ...
-        self.norm2 = ...
+        self.norm1 = T5LayerNorm(d_model)
+        self.norm2 = T5LayerNorm(d_model)

         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)

@@ -153,12 +155,14 @@ class TransformerEncoder(nn.Module):
         use_learned_pos_enc: bool = False,
         activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
         use_relative_position_bias: bool = False,  # T5-style relative position bias
+        gradient_checkpointing: bool = False,
     ):
         super().__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.pad_token_id = pad_token_id
         self.use_relative_position_bias = use_relative_position_bias
+        self.gradient_checkpointing = gradient_checkpointing

         # Token embedding (only used if forward receives token ids)
         self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)

@@ -201,8 +205,8 @@
             ]
         )

-        # Final ...
-        self.final_norm = ...
+        # Final T5LayerNorm for Pre-LN stacks
+        self.final_norm = T5LayerNorm(d_model)

         # Dropout applied after embedding + positional encoding (paper uses this)
         self.input_dropout = nn.Dropout(dropout)

@@ -282,7 +286,25 @@
         # Pass through each encoder layer (optionally collect attn)
         for layer in self.layers:
-            ...
+            if self.gradient_checkpointing and self.training:
+                # Gradient checkpointing requires the inputs to require grad
+                # We use a lambda to pass keyword arguments
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
+                    return custom_forward
+
+                x, attn = cast(
+                    Tuple[torch.Tensor, Optional[torch.Tensor]],
+                    checkpoint(
+                        create_custom_forward(layer),
+                        x,
+                        use_reentrant=False,
+                    ),
+                )
+            else:
+                x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
+
             if collect_attn:
                 attn_weights_per_layer.append(attn)
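The checkpointing pattern in both encoder and decoder relies on a closure to carry keyword arguments past `torch.utils.checkpoint.checkpoint`, which only forwards positional inputs. A self-contained sketch of the same pattern with a stand-in layer (the `scale` kwarg is hypothetical, just to show a captured argument):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)

def create_custom_forward(module, scale: float):
    # kwargs are baked into the closure; checkpoint() only passes tensors through
    def custom_forward(*inputs):
        return module(*inputs) * scale
    return custom_forward

x = torch.randn(4, 16, requires_grad=True)  # inputs must require grad for checkpointing
out = checkpoint(create_custom_forward(layer, scale=2.0), x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])
```

`use_reentrant=False` is the non-reentrant implementation, which tolerates unused inputs and keyword-free calling better than the legacy reentrant one.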
src/models/factory.py
CHANGED
@@ -14,15 +14,15 @@ from __future__ import annotations

 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, Optional, cast
+from typing import Any, Literal, Optional, cast

 import torch
 from transformers import T5ForConditionalGeneration

 from ..data.tokenization import Tokenizer
 from ..utils.config import load_yaml
-from .decoder import TransformerDecoder
-from .encoder import TransformerEncoder
+from .decoder import TransformerDecoder, TransformerDecoderLayer
+from .encoder import TransformerEncoder, TransformerEncoderLayer
 from .heads import ClassificationHead, LMHead
 from .multitask import MultiTaskModel

@@ -35,6 +35,7 @@ class ModelConfig:
     """Configuration describing the transformer architecture."""

     d_model: int = 768
+    vocab_size: Optional[int] = None  # Override tokenizer vocab size (e.g., 32128 for FLAN-T5)
     num_encoder_layers: int = 12
     num_decoder_layers: int = 12
     num_attention_heads: int = 12

@@ -50,6 +51,7 @@
     use_relative_position_bias: bool = (
         False  # T5-style relative position bias (use True for T5/FLAN-T5)
     )
+    gradient_checkpointing: bool = False

     def __post_init__(self):
         if self.d_model % self.num_attention_heads != 0:

@@ -77,6 +79,7 @@ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
     data = load_yaml(str(path)).data
     return ModelConfig(
         d_model=int(data.get("d_model", 512)),
+        vocab_size=data.get("vocab_size", None),  # Optional vocab size override
         num_encoder_layers=int(data.get("num_encoder_layers", 6)),
         num_decoder_layers=int(data.get("num_decoder_layers", 6)),
         num_attention_heads=int(data.get("num_attention_heads", 8)),

@@ -88,6 +91,7 @@
         use_learned_pos_enc=bool(data.get("use_learned_pos_enc", True)),
         activation=str(data.get("activation", "gelu")),
         use_relative_position_bias=bool(data.get("use_relative_position_bias", False)),
+        gradient_checkpointing=bool(data.get("gradient_checkpointing", False)),
     )

@@ -107,11 +111,10 @@
     -> We zero-initialize the bias terms
     """
     print(f"Loading pretrained weights from {model_name}...")
-    t5 = T5ForConditionalGeneration.from_pretrained(model_name)
+    t5 = T5ForConditionalGeneration.from_pretrained(model_name)  # type: ignore[attr-defined]

     # Load shared embeddings (T5 uses shared embeddings for encoder and decoder)
     # Note: T5's vocab is padded to multiple of 128 for efficiency (32100 -> 32128)
-    # Our model uses the tokenizer's actual vocab size, so we only copy the valid tokens
     print("Transferring shared token embeddings...")
     shared_embeddings = t5.shared.weight.data
     our_vocab_size = encoder.embedding.weight.size(0)

@@ -124,6 +127,19 @@
         print(f"  Copying first {min_vocab} token embeddings...")
         encoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
         decoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
+
+        # Initialize any extra tokens (e.g., tokens 32100-32127) with small random values
+        if our_vocab_size > t5_vocab_size:
+            print(
+                f"  Initializing {our_vocab_size - t5_vocab_size} extra padding tokens with small values..."
+            )
+            # Use small random initialization for stability (mean of existing embeddings ± small noise)
+            mean_emb = shared_embeddings.mean(dim=0, keepdim=True)
+            encoder.embedding.weight.data[t5_vocab_size:].normal_(mean=0.0, std=0.02)
+            encoder.embedding.weight.data[t5_vocab_size:] += mean_emb
+            decoder.embedding.weight.data[t5_vocab_size:].copy_(
+                encoder.embedding.weight.data[t5_vocab_size:]
+            )
     else:
         encoder.embedding.weight.data.copy_(shared_embeddings)
         decoder.embedding.weight.data.copy_(shared_embeddings)

@@ -136,11 +152,13 @@
     print("Transferring encoder weights...")
     t5_encoder = t5.encoder

-    for ...
-    ...
+    for custom_layer_untyped, t5_layer in zip(encoder.layers, t5_encoder.block, strict=False):
+        custom_layer = cast(TransformerEncoderLayer, custom_layer_untyped)
+        t5_block = cast(Any, t5_layer)
+        t5_self_attn = t5_block.layer[0].SelfAttention
+        t5_ffn = t5_block.layer[1].DenseReluDense
+        t5_norm1 = t5_block.layer[0].layer_norm
+        t5_norm2 = t5_block.layer[1].layer_norm

         # Self-attention (T5 has no bias in attention projections)
         custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)

@@ -190,7 +208,7 @@
     if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
         print("Transferring encoder relative position bias...")
         t5_enc_rel_bias = (
-            t5_encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
+            cast(Any, t5_encoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
         )
         encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)

@@ -198,13 +216,15 @@
     print("Transferring decoder weights...")
     t5_decoder = t5.decoder

-    for ...
-    ...
+    for custom_layer_untyped, t5_layer in zip(decoder.layers, t5_decoder.block, strict=False):
+        custom_layer = cast(TransformerDecoderLayer, custom_layer_untyped)
+        t5_block = cast(Any, t5_layer)
+        t5_self_attn = t5_block.layer[0].SelfAttention
+        t5_cross_attn = t5_block.layer[1].EncDecAttention
+        t5_ffn = t5_block.layer[2].DenseReluDense
+        t5_norm1 = t5_block.layer[0].layer_norm
+        t5_norm2 = t5_block.layer[1].layer_norm
+        t5_norm3 = t5_block.layer[2].layer_norm

         # Self-attention
         custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)

@@ -265,7 +285,7 @@
     ):
         print("Transferring decoder self-attention relative position bias...")
         t5_dec_self_rel_bias = (
-            t5_decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
+            cast(Any, t5_decoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
         )
         decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
             t5_dec_self_rel_bias

@@ -278,7 +298,7 @@
         print("Transferring decoder cross-attention relative position bias...")
         # Cross-attention relative position bias is in EncDecAttention of first block
         t5_dec_cross_rel_bias = (
-            t5_decoder.block[0].layer[1].EncDecAttention.relative_attention_bias.weight.data
+            cast(Any, t5_decoder.block[0]).layer[1].EncDecAttention.relative_attention_bias.weight.data
        )
        decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
            t5_dec_cross_rel_bias

@@ -367,9 +387,9 @@ def _load_llama_weights(
     num_layers = min(len(encoder.layers), len(llama.model.layers))

     for i in range(num_layers):
-        llama_layer = llama.model.layers[i]
-        enc_layer = encoder.layers[i]
-        dec_layer = decoder.layers[i]
+        llama_layer = cast(Any, llama.model.layers[i])
+        enc_layer = cast(TransformerEncoderLayer, encoder.layers[i])
+        dec_layer = cast(TransformerDecoderLayer, decoder.layers[i])

         # --- Self-Attention ---
         # Llama: q_proj, k_proj, v_proj, o_proj

@@ -460,15 +480,19 @@ def build_multitask_model(
     if hasattr(tokenizer, "config") and hasattr(tokenizer.config, "max_length"):
         max_len = tokenizer.config.max_length
     elif hasattr(tokenizer, "model_max_length"):
-        max_len = tokenizer.model_max_length
+        max_len = cast(Any, tokenizer).model_max_length
     else:
         max_len = 512  # Default fallback

     # Cast activation to the literal type for mypy
     activation = cast(ActivationType, cfg.activation)

+    # Use cfg.vocab_size (32128) instead of tokenizer.vocab_size (32100)
+    # to match FLAN-T5's padded vocabulary
+    vocab_size = cfg.vocab_size if cfg.vocab_size is not None else tokenizer.vocab_size
+
     encoder = TransformerEncoder(
-        vocab_size=...
+        vocab_size=vocab_size,
         d_model=cfg.d_model,
         num_layers=cfg.num_encoder_layers,
         num_heads=cfg.num_attention_heads,

@@ -480,9 +504,10 @@
         use_learned_pos_enc=cfg.use_learned_pos_enc,
         activation=activation,
         use_relative_position_bias=cfg.use_relative_position_bias,
+        gradient_checkpointing=cfg.gradient_checkpointing,
     )
     decoder = TransformerDecoder(
-        vocab_size=...
+        vocab_size=vocab_size,
         d_model=cfg.d_model,
         num_layers=cfg.num_decoder_layers,
         num_heads=cfg.num_attention_heads,

@@ -494,6 +519,7 @@
         use_learned_pos_enc=cfg.use_learned_pos_enc,
         activation=activation,
         use_relative_position_bias=cfg.use_relative_position_bias,
+        gradient_checkpointing=cfg.gradient_checkpointing,
     )

     # Load pretrained weights if requested (but allow override for inference)

@@ -513,12 +539,14 @@
         )
         _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)

+    # T5 uses separate embeddings and lm_head (tie_word_embeddings=False)
+    # Both are initialized from pretrained weights if use_pretrained=True
+    # We do NOT tie them here - they remain independent for better flexibility
+
     model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
     model.add_head(
         "summarization",
-        LMHead(
-            d_model=cfg.d_model, vocab_size=tokenizer.vocab_size, tie_embedding=decoder.embedding
-        ),
+        LMHead(d_model=cfg.d_model, vocab_size=vocab_size, tie_embedding=decoder.embedding),
     )
     model.add_head(
         "emotion",
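The embedding transfer above copies the 32100 real FLAN-T5 rows and fills the padded tail (32100 → 32128) with mean-centered noise. A toy version of the slice-wise copy, using made-up sizes in place of the real 32100/32128/768:

```python
import torch

t5_vocab, our_vocab, d_model = 100, 128, 8   # stand-ins for 32100 / 32128 / 768
shared = torch.randn(t5_vocab, d_model)      # pretend pretrained embedding table
ours = torch.empty(our_vocab, d_model)

min_vocab = min(t5_vocab, our_vocab)
ours[:min_vocab] = shared[:min_vocab]                 # copy the real tokens
ours[t5_vocab:].normal_(mean=0.0, std=0.02)           # small noise for the padded tail
ours[t5_vocab:] += shared.mean(dim=0, keepdim=True)   # centered on the embedding mean

print(torch.allclose(ours[:t5_vocab], shared))  # True
```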
src/models/t5_layer_norm.py
ADDED
@@ -0,0 +1,41 @@
+"""T5-style Layer Normalization (RMSNorm without mean centering).
+
+T5 uses a variant of RMSNorm that does NOT subtract the mean.
+This is critical for matching T5's behavior.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class T5LayerNorm(nn.Module):
+    """
+    T5-style layer normalization without mean centering.
+
+    This is similar to RMSNorm but does NOT subtract the mean from x.
+    Formula: output = x / sqrt(mean(x^2) + eps) * weight
+
+    Args:
+        normalized_shape: Input shape (typically d_model)
+        eps: Small constant for numerical stability
+    """
+
+    def __init__(self, normalized_shape: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: (*, normalized_shape)
+
+        Returns:
+            Normalized tensor of same shape
+        """
+        # T5 uses variance = mean(x^2), does NOT subtract mean
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # Scale by learned weight (no bias in T5)
+        return self.weight * hidden_states
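A quick check that the new module implements the stated formula and genuinely differs from `nn.LayerNorm`, which mean-centers and carries a bias:

```python
import torch

from src.models.t5_layer_norm import T5LayerNorm  # path as added in this commit

x = torch.randn(2, 5, 16)
norm = T5LayerNorm(16)

# weight starts at 1, so the module output should equal the raw formula
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual))                      # True
print(torch.allclose(norm(x), torch.nn.LayerNorm(16)(x)))   # False: LayerNorm subtracts the mean
```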
src/training/early_stopping.py
ADDED
@@ -0,0 +1,60 @@
+"""Early stopping implementation for training.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
+
+
+class EarlyStopping:
+    """Stop training when validation loss stops improving.
+
+    Args:
+        patience: Number of epochs to wait before stopping
+        min_delta: Minimum change to qualify as improvement
+        mode: 'min' for loss (lower is better), 'max' for accuracy
+    """
+
+    def __init__(
+        self,
+        patience: int = 3,
+        min_delta: float = 0.001,
+        mode: str = "min"
+    ):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.mode = mode
+        self.counter = 0
+        self.best_value = float('inf') if mode == 'min' else float('-inf')
+        self.early_stop = False
+
+    def __call__(self, metric_value: float) -> bool:
+        """Check if training should stop.
+
+        Args:
+            metric_value: Current metric value (e.g., validation loss)
+
+        Returns:
+            True if training should stop, False otherwise
+        """
+        if self.mode == 'min':
+            improved = metric_value < (self.best_value - self.min_delta)
+        else:
+            improved = metric_value > (self.best_value + self.min_delta)
+
+        if improved:
+            self.best_value = metric_value
+            self.counter = 0
+            return False
+
+        self.counter += 1
+        if self.counter >= self.patience:
+            self.early_stop = True
+            return True
+
+        return False
+
+    def reset(self):
+        """Reset early stopping state."""
+        self.counter = 0
+        self.best_value = float('inf') if self.mode == 'min' else float('-inf')
+        self.early_stop = False
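Usage sketch for the new helper, with fabricated validation losses to show how `patience` and `min_delta` interact:

```python
from src.training.early_stopping import EarlyStopping

stopper = EarlyStopping(patience=2, min_delta=0.01, mode="min")

# 0.79 and 0.795 are within min_delta of the best (0.80), so they don't count
for epoch, val_loss in enumerate([1.00, 0.80, 0.79, 0.795, 0.801], start=1):
    if stopper(val_loss):
        print(f"stop at epoch {epoch}, best={stopper.best_value}")  # stop at epoch 4, best=0.8
        break
```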
src/training/gradient_monitor.py
ADDED
@@ -0,0 +1,102 @@
+"""Gradient monitoring utilities.
+
+Author: Oliver Perrin
+Date: December 2025
+"""
+
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+
+
+class GradientMonitor:
+    """Monitor gradient statistics during training.
+
+    Tracks gradient norms, helps detect gradient issues like vanishing/exploding.
+    """
+
+    def __init__(self, model: nn.Module, log_frequency: int = 100):
+        """Initialize gradient monitor.
+
+        Args:
+            model: Model to monitor
+            log_frequency: Log gradients every N steps
+        """
+        self.model = model
+        self.log_frequency = log_frequency
+        self.step_count = 0
+
+    def compute_grad_norm(self) -> Dict[str, float]:
+        """Compute gradient norm statistics.
+
+        Returns:
+            Dictionary with gradient statistics
+        """
+        total_norm = 0.0
+        max_norm = 0.0
+        num_params = 0
+
+        for p in self.model.parameters():
+            if p.grad is not None:
+                param_norm = p.grad.data.norm(2).item()
+                total_norm += param_norm ** 2
+                max_norm = max(max_norm, param_norm)
+                num_params += 1
+
+        total_norm = total_norm ** 0.5
+
+        return {
+            "grad_norm": total_norm,
+            "grad_norm_max": max_norm,
+            "num_params_with_grad": num_params,
+        }
+
+    def check_gradients(self) -> Dict[str, int]:
+        """Check for gradient issues (NaN, Inf, zero).
+
+        Returns:
+            Dictionary with counts of gradient issues
+        """
+        nan_count = 0
+        inf_count = 0
+        zero_count = 0
+
+        for p in self.model.parameters():
+            if p.grad is not None:
+                if torch.isnan(p.grad).any():
+                    nan_count += 1
+                if torch.isinf(p.grad).any():
+                    inf_count += 1
+                if (p.grad == 0).all():
+                    zero_count += 1
+
+        return {
+            "nan_grads": nan_count,
+            "inf_grads": inf_count,
+            "zero_grads": zero_count,
+        }
+
+    def log_gradients(self, step: Optional[int] = None) -> Optional[Dict[str, float]]:
+        """Log gradient statistics if it's time.
+
+        Args:
+            step: Current training step (uses internal counter if None)
+
+        Returns:
+            Gradient statistics if logged, None otherwise
+        """
+        if step is None:
+            step = self.step_count
+            self.step_count += 1
+
+        if step % self.log_frequency == 0:
+            stats = self.compute_grad_norm()
+            issues = self.check_gradients()
+
+            # Combine stats
+            all_stats = {**stats, **issues}
+
+            return all_stats
+
+        return None
src/training/safe_compile.py
CHANGED
@@ -1,86 +1,52 @@
-"""
-Safe torch.compile configuration that prevents NaN issues.
-...
+"""Safe defaults for `torch.compile` to reduce instability in tests and training."""

+from __future__ import annotations
+
+from typing import Any
+
 import torch

+
+def _set_attr(obj: object, name: str, value: Any) -> None:
+    """Set attribute on dynamic objects only if it exists (keeps static checkers quiet)."""
+
+    target = getattr(obj, name, None)
+    if target is not None:
+        setattr(obj, name, value)
+
+
 def compile_model_safe(
     model: torch.nn.Module,
     mode: str = "default",
+    dynamic: bool | None = None,
 ) -> torch.nn.Module:
-    """
-    Compile model with inductor backend and safety guardrails.
-
-    CUDA graphs (reduce-overhead mode) don't work with dynamic shapes or
-    shared embeddings like in T5.
-
-    Args:
-        model: Model to compile
-        mode: Compilation mode ("default" recommended, avoid "reduce-overhead")
-    ...
-        # Explicitly disable CUDA graphs
-        if hasattr(cfg, "triton"):
-            if hasattr(cfg.triton, "cudagraphs"):
-                cfg.triton.cudagraphs = False
-            if hasattr(cfg.triton, "max_autotune_gemm"):
-                cfg.triton.max_autotune_gemm = False
-
-        # Compile with inductor (no CUDA graphs)
-        compiled = torch.compile(model, mode=mode, fullgraph=False, dynamic=True)
-        print(f"✓ Compiled with inductor ({mode} mode)")
-        return compiled
-
-    except Exception as e:
-        print(f"⚠ Inductor compilation failed: {e}")
-        print("  Falling back to aot_eager")
-        try:
-            return torch.compile(model, backend="aot_eager")
-        except Exception:
-            print("  Using uncompiled model")
-            return model
-
-
-def apply_safe_config():
-    """Apply safe configuration to torch._inductor before any compilation."""
-    if hasattr(torch, "_inductor"):
-        cfg = torch._inductor.config
-        if hasattr(cfg, "epilogue_fusion"):
-            cfg.epilogue_fusion = False
-        if hasattr(cfg, "coordinate_descent_tuning"):
-            cfg.coordinate_descent_tuning = False
-        if hasattr(cfg, "triton"):
-            if hasattr(cfg.triton, "cudagraphs"):
-                cfg.triton.cudagraphs = False
-            if hasattr(cfg.triton, "max_autotune_gemm"):
-                cfg.triton.max_autotune_gemm = False
-
-    # Dynamo config for stability
-    torch._dynamo.config.suppress_errors = True
-    torch._dynamo.config.cache_size_limit = 64
+    """Safely compile model with inductor backend.
+
+    Parameters mirror `torch.compile` but default to conservative settings.
+    """
+
+    return torch.compile(model, backend="inductor", mode=mode, dynamic=dynamic)
+
+
+def apply_safe_config() -> None:
+    """Apply conservative torch._inductor and torch._dynamo settings if present."""
+
+    inductor = getattr(torch, "_inductor", None)
+    cfg = getattr(inductor, "config", None) if inductor is not None else None
+
+    if cfg is not None:
+        _set_attr(cfg, "epilogue_fusion", False)
+        _set_attr(cfg, "coordinate_descent_tuning", False)
+        triton_cfg = getattr(cfg, "triton", None)
+        if triton_cfg is not None:
+            _set_attr(triton_cfg, "cudagraphs", False)
+            _set_attr(triton_cfg, "max_autotune_gemm", False)
+
+    dynamo_cfg = getattr(torch, "_dynamo", None)
+    if dynamo_cfg is not None:
+        dyn_config = getattr(dynamo_cfg, "config", None)
+        if dyn_config is not None:
+            _set_attr(dyn_config, "suppress_errors", True)
+            _set_attr(dyn_config, "cache_size_limit", 64)
+
     print("✓ Applied safe inductor configuration")
src/training/trainer.py
CHANGED
@@ -2,7 +2,7 @@
 Multi-task Trainer for LexiMind.

 Handles training across summarization, emotion, and topic heads with mixed-precision,
-gradient accumulation, and MLflow logging.
+gradient accumulation, gradient monitoring, early stopping, and MLflow logging.

 Author: Oliver Perrin
 Date: December 2025

@@ -10,6 +10,7 @@ Date: December 2025

 from __future__ import annotations

+import math
 import sys
 import time
 from collections import defaultdict

@@ -19,13 +20,36 @@ from typing import Any, Callable, Dict, List
 import mlflow
 import torch
 import torch.nn.functional as F
+from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader
 from tqdm import tqdm

 from ..data.tokenization import Tokenizer
+from .early_stopping import EarlyStopping
+from .gradient_monitor import GradientMonitor
 from .metrics import accuracy, multilabel_f1, rouge_like
 from .nan_debugger import NaNDetector

+
+def _get_cosine_schedule_with_warmup(
+    optimizer: torch.optim.Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float = 0.1,
+) -> LambdaLR:
+    """Create cosine LR schedule with linear warmup."""
+
+    def lr_lambda(current_step: int) -> float:
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps)
+        )
+        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+    return LambdaLR(optimizer, lr_lambda)
+
+
 # --------------- Configuration ---------------

@@ -42,6 +66,15 @@ class TrainerConfig:
     experiment_name: str = "LexiMind"
     run_name: str | None = None
     gradient_accumulation_steps: int = 1
+    # Learning rate scheduler
+    scheduler_type: str = "cosine"  # "cosine", "linear", or "constant"
+    warmup_steps: int = 0
+    num_training_steps: int = 0  # Set automatically if 0
+    # Early stopping
+    early_stopping_patience: int | None = None  # None = disabled
+    early_stopping_min_delta: float = 0.001
+    # Gradient monitoring
+    log_grad_norm_frequency: int = 100  # Log gradient norms every N steps


 # --------------- Trainer ---------------

@@ -61,6 +94,8 @@ class Trainer:
         self.config = config
         self.device = device
         self.tokenizer = tokenizer
+        self.scheduler: LambdaLR | None = None  # Set in fit()
+        self.global_step = 0  # Track global step for scheduler

         # Task losses
         self.emotion_loss = torch.nn.BCEWithLogitsLoss()

@@ -76,6 +111,18 @@
         self.nan_skip_count = 0
         self.max_nan_skips = 50

+        # Gradient monitoring
+        self.grad_monitor = GradientMonitor(model, log_frequency=config.log_grad_norm_frequency)
+
+        # Early stopping
+        self.early_stopping: EarlyStopping | None = None
+        if config.early_stopping_patience is not None:
+            self.early_stopping = EarlyStopping(
+                patience=config.early_stopping_patience,
+                min_delta=config.early_stopping_min_delta,
+                mode="min"  # Lower loss is better
+            )
+
         # Track current step for debugging
         self._current_step = 0

@@ -87,6 +134,46 @@
         torch.backends.cuda.enable_flash_sdp(True)
         torch.backends.cuda.enable_mem_efficient_sdp(True)

+    def _setup_scheduler(self, train_loaders: Dict[str, DataLoader], start_epoch: int = 1) -> None:
+        """Initialize learning rate scheduler based on config."""
+        # Calculate steps per epoch once
+        max_batches = max(len(loader) for loader in train_loaders.values())
+        self.steps_per_epoch = max_batches // max(1, self.config.gradient_accumulation_steps)
+
+        if self.config.scheduler_type == "constant":
+            return  # No scheduler needed
+
+        # Some tests pass a MagicMock optimizer without param_groups; skip scheduler gracefully
+        try:
+            _ = self.optimizer.param_groups  # type: ignore[attr-defined]
+        except AttributeError:
+            self.scheduler = None
+            return
+
+        # Calculate total training steps
+        epochs_remaining = max(0, self.config.max_epochs - (start_epoch - 1))
+        num_training_steps = self.config.num_training_steps or (
+            self.steps_per_epoch * epochs_remaining
+        )
+
+        warmup_steps = self.config.warmup_steps
+        print(
+            f"✓ LR Scheduler: {self.config.scheduler_type} with {warmup_steps} warmup steps, {num_training_steps} total steps"
+        )
+
+        if self.config.scheduler_type == "cosine":
+            self.scheduler = _get_cosine_schedule_with_warmup(
+                self.optimizer, warmup_steps, num_training_steps
+            )
+        elif self.config.scheduler_type == "linear":
+
+            def linear_decay(step: int) -> float:
+                if step < warmup_steps:
+                    return float(step) / float(max(1, warmup_steps))
+                return max(0.0, 1.0 - (step - warmup_steps) / (num_training_steps - warmup_steps))
+
+            self.scheduler = LambdaLR(self.optimizer, linear_decay)
+
     # --------------- Training Loop ---------------

     def fit(

@@ -94,17 +181,24 @@
         train_loaders: Dict[str, DataLoader],
         val_loaders: Dict[str, DataLoader] | None = None,
         checkpoint_callback: Callable | None = None,
+        start_epoch: int = 1,
     ) -> Dict[str, Dict[str, float]]:
         """Train model across all tasks with progress tracking."""
         history: Dict[str, Dict[str, float]] = {}
         total_start = time.perf_counter()

+        # Setup LR scheduler
+        self._setup_scheduler(train_loaders, start_epoch=start_epoch)
+        # Initialize global_step to reflect completed epochs when resuming
+        if hasattr(self, "steps_per_epoch"):
+            self.global_step = max(0, (start_epoch - 1) * self.steps_per_epoch)
+
         with mlflow.start_run(run_name=self.config.run_name):
             self._log_config()

             # Epoch progress bar
             epoch_pbar = tqdm(
-                range(...
+                range(start_epoch, self.config.max_epochs + 1),
                 desc="Training",
                 unit="epoch",
                 position=0,

@@ -129,6 +223,15 @@
                 if "summarization" in val_loaders:
                     self._validate_generation(val_loaders["summarization"], epoch)

+                # Early stopping check
+                if self.early_stopping is not None:
+                    val_loss = val_metrics.get("total_loss", val_metrics.get("summarization_loss", float('inf')))
+                    if self.early_stopping(val_loss):
+                        tqdm.write(f"\n⚠ Early stopping triggered at epoch {epoch}")
+                        tqdm.write(f"  Best validation loss: {self.early_stopping.best_value:.4f}")
+                        tqdm.write(f"  Patience exhausted ({self.early_stopping.patience} epochs)")
+                        break
+
                 # Checkpoint
                 if checkpoint_callback:
                     checkpoint_callback(epoch, self.model, history)

@@ -256,7 +359,19 @@
         return averaged

     def _optimizer_step(self) -> None:
-        """...
+        """Perform optimizer step with gradient clipping."""
+        # Log gradient norms before clipping
+        grad_stats = self.grad_monitor.log_gradients(self.global_step)
+        if grad_stats is not None:
+            tqdm.write(
+                f"  [Step {self.global_step}] "
+                f"Grad norm: {grad_stats['grad_norm']:.4f}, "
+                f"Max: {grad_stats['grad_norm_max']:.4f}"
+            )
+            # Log to MLflow
+            for key, val in grad_stats.items():
+                mlflow.log_metric(f"grad_{key}", val, step=self.global_step)
+
         # Check gradients for NaN/Inf BEFORE clipping
         nan_grad = self.nan_detector.check_gradients(self._current_step)
         if nan_grad is not None:

@@ -280,6 +395,14 @@
         self.optimizer.zero_grad()

+        # Step the learning rate scheduler
+        if self.scheduler is not None:
+            self.scheduler.step()
+            self.global_step += 1
+            # Log learning rate
+            current_lr = self.scheduler.get_last_lr()[0]
+            mlflow.log_metric("learning_rate", current_lr, step=self.global_step)
+
         # Check parameters for NaN AFTER update
         nan_param = self.nan_detector.check_parameters(self._current_step)
         if nan_param is not None:

@@ -287,6 +410,31 @@
             f"NaN in parameter {nan_param} after optimizer step at step {self._current_step}!"
         )

+    def _clip_embedding_gradients(self, max_norm: float = 5.0) -> None:
+        """Clip embedding gradients only if they exceed threshold.
+
+        Less aggressive clipping to allow learning while preventing
+        overflow with inductor backend + gradient accumulation.
+        """
+        for name, param in self.model.named_parameters():
+            if param.grad is not None and "embedding" in name.lower():
+                grad = param.grad
+                # Only fix actual NaN/Inf, don't preemptively clip
+                if torch.isnan(grad).any() or torch.isinf(grad).any():
+                    # Count NaNs for monitoring
+                    nan_count = torch.isnan(grad).sum().item()
+                    inf_count = torch.isinf(grad).sum().item()
+                    if nan_count > 0 or inf_count > 0:
+                        # Replace with zeros only where invalid
+                        param.grad = torch.where(
+                            torch.isnan(grad) | torch.isinf(grad), torch.zeros_like(grad), grad
+                        )
+                else:
+                    # Normal gradient - only clip if extremely large
+                    grad_norm = param.grad.norm()
+                    if grad_norm > max_norm:
+                        param.grad = param.grad * (max_norm / (grad_norm + 1e-6))
+
     def _get_batch(
         self, iterators: Dict, loader: DataLoader, task: str
     ) -> Dict[str, torch.Tensor] | None:

@@ -341,6 +489,8 @@
             inputs["src_mask"] = batch["src_mask"]

         logits = self.model.forward("summarization", inputs)
+
+        # Compute loss with proper masking
         loss = F.cross_entropy(
             logits.view(-1, logits.size(-1)),
             batch["labels"].view(-1),

@@ -348,6 +498,11 @@
             label_smoothing=self.config.label_smoothing,
         )

+        # Sanity check logits
+        if self.global_step % 100 == 0:
+            with torch.no_grad():
+                tqdm.write(f"  [Step {self.global_step}] Summarization logits: mean={logits.mean().item():.2f}, std={logits.std().item():.2f}, loss={loss.item():.4f}")
+
         # Quick ROUGE estimate
         preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
         refs = self._decode_labels(batch["labels"])
tests/test_inference/test_pipeline.py
CHANGED
@@ -2,11 +2,18 @@

 from __future__ import annotations

+import sys
+import warnings
 from pathlib import Path
 from typing import cast

+import pytest
 import torch

+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
 from src.data.tokenization import Tokenizer, TokenizerConfig
 from src.inference.pipeline import (
     EmotionPrediction,

@@ -16,6 +23,21 @@
 )
 from src.utils.labels import LabelMetadata

+# Silence noisy DeprecationWarnings from underlying tokenizer bindings used in tests
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings(
+    "ignore",
+    message=r"builtin type SwigPy.*has no __module__ attribute",
+    category=DeprecationWarning,
+)
+warnings.filterwarnings(
+    "ignore",
+    category=DeprecationWarning,
+    module=r"importlib\\._bootstrap",
+)
+
+pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")
+

 def _local_tokenizer_config() -> TokenizerConfig:
     root = Path(__file__).resolve().parents[2]

@@ -48,7 +70,7 @@ class DummyDecoder(torch.nn.Module):
         device: torch.device,
         **kwargs: object,
     ) -> torch.Tensor:
-        seq = self.sequence.to(device)
+        seq = cast(torch.Tensor, self.sequence).to(device)
         if seq.numel() > max_len:
             seq = seq[:max_len]
         batch = memory.size(0)

@@ -70,9 +92,9 @@ class DummyModel(torch.nn.Module):
     ) -> torch.Tensor:  # pragma: no cover - simple dispatch
         batch = inputs["input_ids"].size(0)
         if task == "emotion":
-            return self._emotion_logits.unsqueeze(0).repeat(batch, 1)
+            return cast(torch.Tensor, self._emotion_logits).unsqueeze(0).repeat(batch, 1)
         if task == "topic":
-            return self._topic_logits.unsqueeze(0).repeat(batch, 1)
+            return cast(torch.Tensor, self._topic_logits).unsqueeze(0).repeat(batch, 1)
         raise KeyError(task)

@@ -85,7 +107,7 @@ def _build_pipeline() -> InferencePipeline:
         tokenizer=tokenizer,
         emotion_labels=metadata.emotion,
         topic_labels=metadata.topic,
-        config=InferenceConfig(summary_max_length=12),
+        config=InferenceConfig(summary_max_length=12, summary_formatting=False),
     )
tests/test_models/test_visualizations.py
CHANGED
@@ -34,7 +34,7 @@ def test_attention_visualization():
     V = torch.eye(seq_len, d_k).unsqueeze(0)  # Identity-like

     # Compute attention
-    ...
+    _output, weights = attention(Q, K, V, return_attn_weights=True)

     # Plot attention weights
     plt.figure(figsize=(8, 6))
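For the restored test call above, a minimal standalone check of the same contract, assuming a scaled dot-product attention that returns `(output, weights)` as the test's `attention` module does: each row of the attention weights should sum to 1.

```python
import torch
import torch.nn.functional as F

seq_len, d_k = 5, 8
Q = K = torch.randn(1, seq_len, d_k)
V = torch.eye(seq_len, d_k).unsqueeze(0)

# plain scaled dot-product attention, returning the weights like the test expects
scores = Q @ K.transpose(-2, -1) / d_k**0.5
weights = F.softmax(scores, dim=-1)
output = weights @ V

print(torch.allclose(weights.sum(dim=-1), torch.ones(1, seq_len)))  # True
```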