anthonym21 committed on
Commit
c36dfe4
·
verified ·
1 Parent(s): 735e7d8

Step 53110: 27.1B tokens (Stage 2 in progress), loss=1.463, ppl=4.3

Browse files
config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SABERForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_saber.SABERConfig",
7
+ "AutoModelForCausalLM": "modeling_saber.SABERForCausalLM"
8
+ },
9
+ "curiosity_coeff": 0.01,
10
+ "d_anchor": 96,
11
+ "d_exp": 192,
12
+ "d_ff": 2164,
13
+ "d_model": 1536,
14
+ "dtype": "float32",
15
+ "enable_anchors": true,
16
+ "enable_experience": true,
17
+ "gradient_checkpointing": false,
18
+ "head_dim": 128,
19
+ "initializer_range": 0.02,
20
+ "max_position_embeddings": 2048,
21
+ "model_type": "saber",
22
+ "n_anchors": 64,
23
+ "n_heads": 12,
24
+ "n_layers": 20,
25
+ "predictability_mode": false,
26
+ "resonant_alpha_init": 3.0,
27
+ "resonant_layers": [
28
+ 0,
29
+ 2,
30
+ 4,
31
+ 6,
32
+ 8,
33
+ 10,
34
+ 12,
35
+ 14,
36
+ 16,
37
+ 18
38
+ ],
39
+ "rms_norm_eps": 1e-06,
40
+ "rope_theta": 10000.0,
41
+ "tie_word_embeddings": true,
42
+ "transformers_version": "5.3.0",
43
+ "use_cache": true,
44
+ "vocab_size": 50257
45
+ }
configuration_saber.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ configuration_saber.py — HuggingFace-compatible configuration for Eve-3-SABER-1B.
3
+
4
+ Usage:
5
+ from configuration_saber import SABERConfig
6
+
7
+ config = SABERConfig() # default 1B spec
8
+ config.save_pretrained("./eve-3-saber-1b")
9
+ config = SABERConfig.from_pretrained("./eve-3-saber-1b")
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import List, Optional
15
+
16
+ from transformers import PretrainedConfig
17
+
18
+
19
class SABERConfig(PretrainedConfig):
    r"""
    Configuration class for Eve-3-SABER-1B.

    SABER (Semantic Anchor-Biased Experience-Resonant) is a dense decoder-only
    transformer with three novel components:

    1. **Slip-Anchors** — a per-layer learnable codebook that biases K and V
       *after* RoPE, preserving FlashAttention compatibility.
    2. **Experience Stream** — a low-dimensional per-token state that flows
       *layer-to-layer* (not token-to-token), with a curiosity auxiliary loss.
    3. **Resonant FFN** — even-numbered layers augment SwiGLU with a learned
       sinusoidal modulation, blended via a trainable alpha.

    Args:
        vocab_size (int):
            Vocabulary size. Defaults to ``50257`` (GPT-2 tokenizer).
        d_model (int):
            Hidden/residual dimension. Defaults to ``2048``.
        n_heads (int):
            Number of attention heads. Defaults to ``16``.
        head_dim (int):
            Per-head dimension; must satisfy ``d_model == n_heads * head_dim``.
            Defaults to ``128``.
        n_layers (int):
            Number of transformer blocks. Defaults to ``24``.
        d_ff (int):
            SwiGLU inner dimension. Defaults to ``2855``, the value tuned via
            ``param_counter.py --tune-dff`` to hit ~1.0B params. The original
            spec value ``5461`` yields ~1.38B params; pass it explicitly if
            that size is wanted.
        max_position_embeddings (int):
            Maximum sequence length for RoPE. Defaults to ``2048``.
        rope_theta (float):
            Base for RoPE frequency computation. Defaults to ``10000.0``.
        rms_norm_eps (float):
            Epsilon for RMSNorm numerical stability. Defaults to ``1e-6``.
        initializer_range (float):
            Std-dev for weight initialization (Normal). Defaults to ``0.02``.
        tie_word_embeddings (bool):
            Whether to tie the LM head weights to the input embedding table.
            Defaults to ``True``.

        --- Slip-Anchor hyperparameters ---
        n_anchors (int):
            Codebook size. Defaults to ``64``.
        d_anchor (int):
            Anchor bottleneck dimension. Defaults to ``128``.

        --- Experience-stream hyperparameters ---
        d_exp (int):
            Experience stream dimension. Defaults to ``256``.
        curiosity_coeff (float):
            Weight of curiosity auxiliary loss. Defaults to ``0.01``.

        --- Resonant-FFN hyperparameters ---
        resonant_layers (Optional[List[int]]):
            Which layer indices use the resonant FFN. ``None`` means "all even
            layers (0, 2, 4, …)". Pass an explicit list to override. Note that
            ``predictability_mode=True`` replaces whatever is given here.
        resonant_alpha_init (float):
            Initial value of ``alpha_raw`` before sigmoid; ``sigmoid(3.0)≈0.95``
            starts training near pure SwiGLU. Defaults to ``3.0``.

        --- Predictability mode (GPT-5.2 Thinking) ---
        predictability_mode (bool):
            When ``True`` the following overrides are intended:
              * Anchor gate bias → ``-3`` (anchors nearly silent).
              * ``U_e`` scale → ``0.05`` (tiny experience updates).
              * ``resonant_layers`` → last 8 layers only (all layers when the
                model has fewer than 8).
            Only the ``resonant_layers`` override is applied by this class;
            the other two are left to model construction.
            Defaults to ``False``.

        --- Inference / training toggles ---
        use_cache (bool):
            Whether past KV states are returned (not used during training).
            Defaults to ``True``.
        gradient_checkpointing (bool):
            Enable activation checkpointing. Set via
            ``model.gradient_checkpointing_enable()`` rather than here in most
            cases. Defaults to ``False``.

        --- Ablation flags ---
        enable_anchors (bool):
            Enable the slip-anchor component. Defaults to ``True``.
        enable_experience (bool):
            Enable the experience-stream component. Defaults to ``True``.

    Raises:
        ValueError: if ``d_model != n_heads * head_dim``.
    """

    # Required by HuggingFace AutoModel registry
    model_type: str = "saber"

    # Map canonical HF attribute names to SABER field names so that
    # generic HF utilities (e.g. model.config.hidden_size) work transparently.
    # ``max_position_embeddings`` already uses the canonical name, so it needs
    # no alias entry.
    attribute_map = {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "num_attention_heads": "n_heads",
        "intermediate_size": "d_ff",
    }

    def __init__(
        self,
        # Core architecture
        vocab_size: int = 50257,
        d_model: int = 2048,
        n_heads: int = 16,
        head_dim: int = 128,
        n_layers: int = 24,
        d_ff: int = 2855,
        max_position_embeddings: int = 2048,
        rope_theta: float = 10_000.0,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = True,
        # Slip-anchor
        n_anchors: int = 64,
        d_anchor: int = 128,
        # Experience stream
        d_exp: int = 256,
        curiosity_coeff: float = 0.01,
        # Resonant FFN
        resonant_layers: Optional[List[int]] = None,
        resonant_alpha_init: float = 3.0,
        # Predictability mode
        predictability_mode: bool = False,
        # Inference / training toggles
        use_cache: bool = True,
        gradient_checkpointing: bool = False,
        # Ablation flags (component enable/disable)
        enable_anchors: bool = True,
        enable_experience: bool = True,
        **kwargs,
    ) -> None:
        # ------------------------------------------------------------------ #
        # Validate key relationships
        # ------------------------------------------------------------------ #
        if d_model != n_heads * head_dim:
            raise ValueError(
                f"d_model ({d_model}) must equal n_heads ({n_heads}) × "
                f"head_dim ({head_dim}) = {n_heads * head_dim}."
            )

        # ------------------------------------------------------------------ #
        # Core
        # ------------------------------------------------------------------ #
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range

        # ------------------------------------------------------------------ #
        # Slip-anchor
        # ------------------------------------------------------------------ #
        self.n_anchors = n_anchors
        self.d_anchor = d_anchor

        # ------------------------------------------------------------------ #
        # Experience stream
        # ------------------------------------------------------------------ #
        self.d_exp = d_exp
        self.curiosity_coeff = curiosity_coeff

        # ------------------------------------------------------------------ #
        # Resonant FFN — default to all even layers
        # ------------------------------------------------------------------ #
        if resonant_layers is None:
            resonant_layers = list(range(0, n_layers, 2))
        else:
            # Copy so later mutation of the caller's list cannot silently
            # change this config.
            resonant_layers = list(resonant_layers)
        self.resonant_layers = resonant_layers
        self.resonant_alpha_init = resonant_alpha_init

        # ------------------------------------------------------------------ #
        # Predictability mode overrides
        # ------------------------------------------------------------------ #
        self.predictability_mode = predictability_mode
        if predictability_mode:
            # Last 8 layers only. Clamp at 0 so that models with fewer than
            # 8 layers do not end up with negative (never-matching) indices.
            first_resonant = max(n_layers - 8, 0)
            self.resonant_layers = list(range(first_resonant, n_layers))

        # ------------------------------------------------------------------ #
        # Inference / training
        # ------------------------------------------------------------------ #
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing

        # Ablation flags — allow disabling novel components
        self.enable_anchors = enable_anchors
        self.enable_experience = enable_experience

        # ------------------------------------------------------------------ #
        # Pass through to PretrainedConfig (handles tie_word_embeddings, etc.)
        # ------------------------------------------------------------------ #
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    # ---------------------------------------------------------------------- #
    # Derived helpers (read-only properties, not serialized)
    # ---------------------------------------------------------------------- #

    @property
    def num_key_value_heads(self) -> int:
        """Alias for n_heads (SABER uses MHA, not GQA)."""
        return self.n_heads

    @property
    def n_resonant_layers(self) -> int:
        """Number of layers that use the resonant FFN."""
        return len(self.resonant_layers)

    def __repr__(self) -> str:  # noqa: D401
        """Compact multi-line summary of the main hyperparameters."""
        all_even = list(range(0, self.n_layers, 2))
        resonant_str = (
            f"all-even (n={self.n_resonant_layers})"
            if self.resonant_layers == all_even
            else str(self.resonant_layers)
        )
        return (
            f"SABERConfig(\n"
            f"  d_model={self.d_model}, n_heads={self.n_heads}, "
            f"head_dim={self.head_dim}, n_layers={self.n_layers},\n"
            f"  d_ff={self.d_ff}, vocab_size={self.vocab_size}, "
            f"max_seq={self.max_position_embeddings},\n"
            f"  n_anchors={self.n_anchors}, d_anchor={self.d_anchor}, "
            f"d_exp={self.d_exp},\n"
            f"  curiosity_coeff={self.curiosity_coeff}, "
            f"resonant_layers={resonant_str},\n"
            f"  resonant_alpha_init={self.resonant_alpha_init}, "
            f"predictability_mode={self.predictability_mode},\n"
            f"  tie_word_embeddings={self.tie_word_embeddings}, "
            f"use_cache={self.use_cache}\n"
            f")"
        )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "output_attentions": false,
4
+ "output_hidden_states": false,
5
+ "transformers_version": "5.3.0",
6
+ "use_cache": true
7
+ }
meta.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "step": 53110,
3
+ "tokens_seen": 27108895744,
4
+ "stage_idx": 1,
5
+ "wandb_run_id": null,
6
+ "total_target": 50000000000
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c91927e9f167f59f63c90c285590e3fd3c4ead4eb6474bd6ddeba597c6806ea4
3
+ size 1999952456
modeling_saber.py ADDED
@@ -0,0 +1,948 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_saber.py — Full PyTorch implementation of Eve-3-SABER-1B.
3
+
4
+ Architecture highlights
5
+ -----------------------
6
+ * Dense decoder-only transformer with pre-RMSNorm.
7
+ * RoPE (rotary position embeddings) applied to Q and K after head reshape.
8
+ * **Slip-Anchors**: learnable codebook biases K/V *after* RoPE, fully
9
+ compatible with FlashAttention / F.scaled_dot_product_attention.
10
+ * **Experience Stream**: a per-token, layer-traversing state with a curiosity
11
+ auxiliary loss (prediction-error on a stop-gradient summary).
12
+ * **Resonant FFN**: even-indexed layers augment SwiGLU with a learned
13
+ sinusoidal modulation blended by a trainable scalar alpha.
14
+ * Weight-tied LM head.
15
+ * Gradient-checkpointing support.
16
+
17
+ Intended usage (HuggingFace Trainer / SFTTrainer compatible):
18
+ from configuration_saber import SABERConfig
19
+ from modeling_saber import SABERForCausalLM
20
+
21
+ config = SABERConfig()
22
+ model = SABERForCausalLM(config)
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import math
28
+ from typing import List, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+
35
+ from transformers import PreTrainedModel
36
+ from transformers.generation import GenerationMixin
37
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
38
+ from transformers.utils import logging
39
+
40
+ from configuration_saber import SABERConfig
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # 1. RMSNorm
46
+ # ---------------------------------------------------------------------------
47
+
48
class SABERRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learnable gain and no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Per-feature gain, initialised to the identity scaling.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` (..., hidden_size) by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, apply the gain, and return in ``x``'s dtype."""
        x_fp32 = x.float()
        gain_fp32 = self.weight.float()
        scaled = self._norm(x_fp32) * gain_fp32
        return scaled.to(x.dtype)
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # 2. Rotary Position Embeddings (RoPE)
67
+ # ---------------------------------------------------------------------------
68
+
69
class SABERRotaryEmbedding(nn.Module):
    """
    Rotary position embeddings in the rotate-half formulation
    (Llama / GPT-NeoX style), using real-valued cos/sin tables.

    The cos/sin tables are cached up to ``max_seq_len`` and rebuilt on the fly
    if a longer sequence is encountered.
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        theta: float = 10_000.0,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute inverse frequencies (half of head_dim)
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
                      / head_dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len, device=device)

    def _build_cache(self, seq_len: int, device: Optional[torch.device] = None) -> None:
        """Build (or rebuild, when extending) the cos/sin cache for ``seq_len`` positions."""
        target = self.inv_freq.device if device is None else device
        t = torch.arange(seq_len, dtype=torch.float32, device=target)
        # Move inv_freq alongside t so the outer product never mixes devices.
        freqs = torch.outer(t, self.inv_freq.to(target))  # (seq_len, head_dim/2)
        emb = torch.cat([freqs, freqs], dim=-1)            # (seq_len, head_dim)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        self.max_seq_len = seq_len

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map the two halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to q and k.

        q, k:         (batch, n_heads, q_len, head_dim) — the *new* tokens only
        seq_len:      total sequence length including any cached prefix; the
                      cache is extended when this exceeds ``max_seq_len``
        position_ids: (batch, q_len) absolute positions, or None. When None,
                      the tokens are assumed to occupy the last ``q_len``
                      positions of a sequence of length ``seq_len``.

        Returns the rotated ``(q, k)`` with unchanged shapes.
        """
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len, device=q.device)

        if position_ids is not None:
            # Gather cos/sin for the specific positions in this batch.
            # cos_cached: (1, 1, max_seq, head_dim) → (max_seq, head_dim),
            # indexed with position_ids (B, L) → (B, L, head_dim),
            # then a head axis is inserted → (B, 1, L, head_dim).
            cos_2d = self.cos_cached[0, 0].to(q.dtype)
            sin_2d = self.sin_cached[0, 0].to(q.dtype)
            cos = cos_2d[position_ids].unsqueeze(1)
            sin = sin_2d[position_ids].unsqueeze(1)
        else:
            # q/k may be shorter than seq_len when a KV cache is in use, so
            # take the *trailing* q_len positions rather than the first
            # seq_len ones. Without a cache (q_len == seq_len) this is the
            # plain [0, seq_len) slice.
            q_len = q.shape[-2]
            start = max(seq_len - q_len, 0)
            cos = self.cos_cached[:, :, start:seq_len, :].to(q.dtype)  # (1, 1, L, head_dim)
            sin = self.sin_cached[:, :, start:seq_len, :].to(q.dtype)

        q_rot = q * cos + self._rotate_half(q) * sin
        k_rot = k * cos + self._rotate_half(k) * sin
        return q_rot, k_rot
146
+
147
+
148
+ # ---------------------------------------------------------------------------
149
+ # 3. Slip-Anchors
150
+ # ---------------------------------------------------------------------------
151
+
152
class SlipAnchors(nn.Module):
    """
    Slip-anchor module: shifts K and V by a bias derived from a learnable codebook.

    The hidden state is projected into the anchor space, softly matched
    against the codebook, and the resulting anchor mixture is projected into
    one K bias and one V bias per token, shared by every head.

    Parameters
    ----------
    d_model   : residual hidden dimension
    n_anchors : codebook size
    d_anchor  : anchor bottleneck dim
    head_dim  : per-head dimension
    n_heads   : number of attention heads (stored, not used in the computation)
    """

    def __init__(
        self,
        d_model: int,
        n_anchors: int,
        d_anchor: int,
        head_dim: int,
        n_heads: int,
    ) -> None:
        super().__init__()
        self.n_anchors = n_anchors
        self.d_anchor = d_anchor
        self.n_heads = n_heads
        self.head_dim = head_dim

        # Codebook of shape (n_anchors, d_anchor); filled in by _init_weights.
        self.anchors = nn.Parameter(torch.empty(n_anchors, d_anchor))
        # Hidden state → anchor space.
        self.W_anchor_down = nn.Linear(d_model, d_anchor, bias=False)
        # Anchor mixture → K bias and V bias (head_dim wide each).
        self.U_k = nn.Linear(d_anchor, head_dim, bias=False)
        self.U_v = nn.Linear(d_anchor, head_dim, bias=False)

        self._init_weights()

    def _init_weights(self) -> None:
        """Draw the codebook and all projection weights from N(0, 0.02²)."""
        for tensor in (
            self.anchors,
            self.W_anchor_down.weight,
            self.U_k.weight,
            self.U_v.weight,
        ):
            nn.init.normal_(tensor, std=0.02)

    def forward(
        self,
        h: torch.Tensor,   # (B, L, d_model)
        K: torch.Tensor,   # (B, n_heads, L, head_dim)
        V: torch.Tensor,   # (B, n_heads, L, head_dim)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(K + k_bias, V + v_bias)`` with the biases broadcast over heads."""
        # Unpacking doubles as a check that h is rank-3.
        _batch, _seq_len, _width = h.shape

        codebook = self.anchors

        # Soft assignment of each token to the anchors: (B, L, n_anchors)
        logits = torch.matmul(self.W_anchor_down(h), codebook.t())
        weights = F.softmax(logits, dim=-1)

        # Convex mixture of anchors per token: (B, L, d_anchor)
        context = torch.matmul(weights, codebook)

        # Per-token biases with a singleton head axis: (B, 1, L, head_dim)
        key_shift = self.U_k(context).unsqueeze(1)
        value_shift = self.U_v(context).unsqueeze(1)

        return K + key_shift, V + value_shift
225
+
226
+
227
+ # ---------------------------------------------------------------------------
228
+ # 4. Attention
229
+ # ---------------------------------------------------------------------------
230
+
231
class SABERAttention(nn.Module):
    """
    Multi-head attention with:
      * No projection biases.
      * RoPE applied to Q and K after head reshape.
      * Slip-anchor modulation of the new tokens' K and V after RoPE and
        *before* they are appended to the KV cache, so cached decoding sees
        the same per-token biases as a full, uncached forward pass.
      * F.scaled_dot_product_attention (FlashAttention 2 compatible).
    """

    def __init__(self, config: SABERConfig, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim

        # QKV and O projections — no bias throughout
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # Each attention layer owns its rotary-embedding instance.
        self.rotary_emb = SABERRotaryEmbedding(
            head_dim=self.head_dim,
            max_seq_len=config.max_position_embeddings,
            theta=config.rope_theta,
        )

        # Slip-anchors
        self.slip_anchors = SlipAnchors(
            d_model=self.d_model,
            n_anchors=config.n_anchors,
            d_anchor=config.d_anchor,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,                       # (B, L, d_model)
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Run attention over ``hidden_states``.

        Parameters
        ----------
        hidden_states  : (B, L, d_model) input for the L new tokens.
        attention_mask : optional mask forwarded to SDPA as ``attn_mask``.
        position_ids   : optional (B, L) absolute positions for RoPE. When
                         omitted and a cache is present, positions
                         ``past_len .. past_len + L - 1`` are used.
        past_key_value : optional ``(K, V)`` cache, each
                         (B, n_heads, past_len, head_dim).
        use_cache      : when True, the updated ``(K, V)`` is returned.
        output_attentions : when True, a ``None`` placeholder is appended
                         (SDPA does not expose attention weights).

        Returns
        -------
        ``(attn_out,)`` followed by ``present_kv`` if ``use_cache`` and by
        ``None`` if ``output_attentions``.
        """
        B, L, _ = hidden_states.shape

        # ---- QKV projections, reshaped to (B, n_heads, L, head_dim) ----
        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, L, self.n_heads, self.head_dim).transpose(1, 2)

        Q = _split_heads(self.q_proj(hidden_states))
        K = _split_heads(self.k_proj(hidden_states))
        V = _split_heads(self.v_proj(hidden_states))

        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0
        kv_seq_len = past_len + L

        # ---- Apply RoPE to Q and K ----
        # With a cache the new tokens sit at positions [past_len, kv_seq_len);
        # make that explicit so RoPE never falls back to positions [0, L).
        if position_ids is None and past_len > 0:
            position_ids = torch.arange(
                past_len, kv_seq_len, device=hidden_states.device
            ).unsqueeze(0)

        Q, K = self.rotary_emb(Q, K, seq_len=kv_seq_len, position_ids=position_ids)

        # ---- Slip-anchor modulation of the NEW K and V ----
        # The biases are computed per token from hidden_states (length L), so
        # they must be added before the cache concat; adding them afterwards
        # would mis-broadcast against the longer cached tensors and leave the
        # cached entries un-biased.
        if getattr(self.config, 'enable_anchors', True):
            K, V = self.slip_anchors(hidden_states, K, V)

        # ---- KV cache (stores anchor-modulated K/V) ----
        if past_key_value is not None:
            K = torch.cat([past_key_value[0], K], dim=2)
            V = torch.cat([past_key_value[1], V], dim=2)

        present_kv = (K, V) if use_cache else None

        # ---- Scaled dot-product attention (FlashAttention 2 compatible) ----
        if attention_mask is None and L > 1 and past_len > 0:
            # SDPA's is_causal aligns the causal triangle to the top-left
            # corner, which is wrong when queries are offset by a cached
            # prefix. Build the offset mask explicitly: query i (absolute
            # position past_len + i) may attend to keys j <= past_len + i.
            attention_mask = torch.tril(
                torch.ones(L, kv_seq_len, dtype=torch.bool, device=Q.device),
                diagonal=past_len,
            )

        # Let SDPA build the causal mask itself only in the plain, uncached
        # multi-token case.
        is_causal = attention_mask is None and L > 1
        attn_out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )  # (B, n_heads, L, head_dim)

        # ---- Merge heads and project ----
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, self.d_model)
        attn_out = self.o_proj(attn_out)

        outputs: Tuple = (attn_out,)
        if use_cache:
            outputs += (present_kv,)
        if output_attentions:
            # Attention weights are not explicitly computed when using SDPA
            outputs += (None,)

        return outputs
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # 5. Experience Stream
339
+ # ---------------------------------------------------------------------------
340
+
341
class ExperienceStream(nn.Module):
    """
    Experience-state update for one layer, plus a curiosity
    (prediction-error) auxiliary loss.

    The caller passes in the previous experience state and receives the
    updated one; this module holds no state of its own between calls.

    Parameters
    ----------
    d_model : residual hidden dimension
    d_exp   : experience state dimension
    """

    def __init__(self, d_model: int, d_exp: int) -> None:
        super().__init__()
        # Hidden state → experience-space summary.
        self.W_s = nn.Linear(d_model, d_exp, bias=False)
        # Previous experience state → predicted summary (curiosity signal).
        self.W_pred = nn.Linear(d_exp, d_exp, bias=False)
        # Summary → additive update for the experience state.
        self.W_e = nn.Linear(d_exp, d_exp, bias=False)
        # Per-channel decay logits; sigmoid(3.0) ~ 0.95 keeps most of the
        # previous state at initialisation.
        self.decay_raw = nn.Parameter(torch.full((d_exp,), 3.0))
        # LayerNorm keeps the state's magnitude from drifting.
        self.exp_norm = nn.LayerNorm(d_exp)

    def forward(
        self,
        h: torch.Tensor,                  # (B, L, d_model)
        experience_state: torch.Tensor,   # (B, L, d_exp) previous state
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns
        -------
        new_experience_state : (B, L, d_exp)
        curiosity_loss       : scalar tensor
        """
        summary = self.W_s(h)                          # (B, L, d_exp)

        # Curiosity: mean squared error between the summary and its prediction
        # from the previous state. The target is detached, so this loss does
        # not send gradients into W_s.
        target = summary.detach()
        prediction = self.W_pred(experience_state)     # (B, L, d_exp)
        curiosity_loss = (target - prediction).pow(2).mean()

        # State update: per-channel decay of the old state plus a SiLU-shaped
        # delta from the (non-detached) summary, followed by LayerNorm.
        keep = torch.sigmoid(self.decay_raw)           # (d_exp,)
        update = F.silu(self.W_e(summary))             # (B, L, d_exp)
        new_state = self.exp_norm(keep * experience_state + update)

        return new_state, curiosity_loss
397
+
398
+
399
+ # ---------------------------------------------------------------------------
400
+ # 6. Feed-forward networks
401
+ # ---------------------------------------------------------------------------
402
+
403
class StandardFFN(nn.Module):
    """Bias-free SwiGLU feed-forward network: ``W2(silu(W1 x) * W3 x)``."""

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff, bias=False)   # gate branch
        self.W3 = nn.Linear(d_model, d_ff, bias=False)   # up branch
        self.W2 = nn.Linear(d_ff, d_model, bias=False)   # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the up-projection with SiLU, then project down."""
        gate = F.silu(self.W1(x))
        up = self.W3(x)
        return self.W2(gate * up)
415
+
416
+
417
class ResonantFFN(nn.Module):
    """
    SwiGLU feed-forward network with a learned sinusoidal modulation.

    A single scalar ``alpha = sigmoid(alpha_raw)`` blends the plain SwiGLU
    output with a sinusoidally modulated copy of it:

        ffn_out = W2(silu(W1(x)) * W3(x))
        mod     = sin(W_freq(x))
        output  = alpha * ffn_out + (1 - alpha) * ffn_out * (1 + mod)

    With the default ``alpha_init=3.0``, alpha starts at ≈ 0.95, i.e. close
    to plain SwiGLU.
    """

    def __init__(self, d_model: int, d_ff: int, alpha_init: float = 3.0) -> None:
        super().__init__()
        # SwiGLU projections (gate, up, down).
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        self.W3 = nn.Linear(d_model, d_ff, bias=False)
        self.W2 = nn.Linear(d_ff, d_model, bias=False)
        # Projection whose sine gives the element-wise modulation.
        self.W_freq = nn.Linear(d_model, d_model, bias=False)
        # Scalar blend logit, passed through a sigmoid in forward().
        self.alpha_raw = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the alpha-weighted blend of plain and modulated SwiGLU outputs."""
        gate = F.silu(self.W1(x))
        ffn_out = self.W2(gate * self.W3(x))       # (B, L, d_model)

        mod = torch.sin(self.W_freq(x))            # (B, L, d_model)

        alpha = torch.sigmoid(self.alpha_raw)      # scalar in (0, 1)
        plain_part = alpha * ffn_out
        modulated_part = (1.0 - alpha) * (ffn_out * (1.0 + mod))
        return plain_part + modulated_part
453
+
454
+
455
+ # ---------------------------------------------------------------------------
456
+ # 7. Transformer Block
457
+ # ---------------------------------------------------------------------------
458
+
459
class SABERBlock(nn.Module):
    """
    Single SABER transformer block.

    Structure (pre-norm):
        h = h + Attention(RMSNorm(h))
        h = h + FFN(RMSNorm(h))
        experience_state, curiosity = ExperienceStream(h, experience_state)

    The FFN is a ``ResonantFFN`` when ``layer_idx`` is listed in
    ``config.resonant_layers`` and a ``StandardFFN`` otherwise.
    """

    def __init__(self, config: SABERConfig, layer_idx: int) -> None:
        """Build the norms, attention, FFN and experience stream of one block.

        Args:
            config: Model configuration (reads ``d_model``, ``d_ff``, ``d_exp``,
                ``rms_norm_eps``, ``resonant_layers``, ``resonant_alpha_init``).
            layer_idx: Index of this block within the stack.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.input_layernorm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)

        self.self_attn = SABERAttention(config, layer_idx=layer_idx)

        # Select FFN type based on layer index
        if layer_idx in config.resonant_layers:
            self.ffn: nn.Module = ResonantFFN(
                d_model=config.d_model,
                d_ff=config.d_ff,
                alpha_init=config.resonant_alpha_init,
            )
        else:
            self.ffn = StandardFFN(d_model=config.d_model, d_ff=config.d_ff)

        # Constructed unconditionally so the parameter set does not depend on
        # ``enable_experience``; forward() decides whether it is used.
        self.experience_stream = ExperienceStream(
            d_model=config.d_model,
            d_exp=config.d_exp,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,          # (B, L, d_model)
        experience_state: torch.Tensor,       # (B, L, d_exp)
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple:
        """Run attention, FFN and (optionally) the experience-stream update.

        Args:
            hidden_states: Block input.
            experience_state: Experience state carried from the previous block.
            attention_mask: Forwarded unchanged to the attention module.
            position_ids: Forwarded unchanged to the attention module.
            past_key_value: Cached key/value pair forwarded to the attention module.
            use_cache: Forwarded to the attention module.
            output_attentions: Forwarded to the attention module.

        Returns:
            ``(hidden_states, experience_state, curiosity_loss) + extra`` where
            ``extra`` is every element the attention module returned after its
            first one (present key/value and/or attention weights).
        """
        # ---- Pre-norm attention ----
        residual = hidden_states
        attn_outputs = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = residual + attn_outputs[0]   # residual connection

        # ---- Pre-norm FFN ----
        residual = hidden_states
        hidden_states = residual + self.ffn(self.post_attention_layernorm(hidden_states))

        # ---- Experience stream update ----
        if getattr(self.config, 'enable_experience', True):
            experience_state, curiosity_loss = self.experience_stream(
                hidden_states, experience_state
            )
        else:
            # Zero placeholder in the hidden-state dtype so that summing it
            # into the model-level accumulator does not change that
            # accumulator's dtype under reduced-precision training.
            curiosity_loss = torch.tensor(
                0.0, device=hidden_states.device, dtype=hidden_states.dtype
            )

        # Pack remaining outputs
        extra = attn_outputs[1:]   # present_kv and/or attention_weights
        return (hidden_states, experience_state, curiosity_loss) + extra
534
+
535
+
536
+ # ---------------------------------------------------------------------------
537
+ # 8. Base Model
538
+ # ---------------------------------------------------------------------------
539
+
540
class SABERModel(PreTrainedModel):
    """
    SABER base model: token embeddings → blocks → final RMSNorm.

    Does not include the LM head — use ``SABERForCausalLM`` for training.
    """

    config_class = SABERConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SABERBlock"]
    _supports_flash_attn_2 = True

    def __init__(self, config: SABERConfig) -> None:
        """Build the embedding table, ``config.n_layers`` blocks and the final norm."""
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [SABERBlock(config, layer_idx=i) for i in range(config.n_layers)]
        )
        self.norm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()  # weight init + gradient-checkpointing setup

    # ------------------------------------------------------------------ #
    # Weight initialization (called by post_init via _init_weights)
    # ------------------------------------------------------------------ #

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one sub-module.

        Linear and embedding weights are drawn from N(0, initializer_range²),
        linear biases are zeroed and RMSNorm weights are set to one.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, SABERRMSNorm):
            nn.init.ones_(module.weight)
        elif isinstance(module, SlipAnchors):
            # Handled inside SlipAnchors._init_weights; no-op here
            pass
        # ResonantFFN.alpha_raw: initialised inside the class (default=3.0)

    # ------------------------------------------------------------------ #
    # Accessors
    # ------------------------------------------------------------------ #

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding table."""
        self.embed_tokens = value

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPast, Tuple]:
        """Embed the input, run every block and apply the final norm.

        Args:
            input_ids: Token ids; required when ``inputs_embeds`` is not given.
            attention_mask: A 2-D ``(B, L)`` mask with 1 = attend / 0 = masked is
                converted to an additive ``(B, 1, 1, L)`` float mask. A mask of
                any other rank is forwarded to the blocks unchanged.
            position_ids: Optional explicit positions; by default they continue
                from the cached sequence length.
            past_key_values: Per-layer ``(key, value)`` pairs. ``None`` and an
                empty container both mean "no cache".
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            use_cache: Whether to return the per-layer key/value cache. Forced
                off while gradient checkpointing is active in training.
            output_attentions: Whether to collect per-layer attention weights.
            output_hidden_states: Whether to collect per-layer hidden states.
            return_dict: Selects the return form (see below).

        Returns:
            With ``return_dict=False`` a fixed 5-tuple
            ``(hidden_states, mean_curiosity, past_key_values,
            all_hidden_states, all_self_attns)``; otherwise the pair
            ``(BaseModelOutputWithPast, mean_curiosity)``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions or False
        output_hidden_states = output_hidden_states or False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # An empty cache carries no state. Normalise it to None so the
        # per-layer lookup below cannot index into an empty container.
        if not past_key_values:
            past_key_values = None

        # The checkpointed path runs every block with use_cache=False, so no
        # present key/values exist to collect; do not pretend to build a cache.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        # ---- Embeddings ----
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Provide either input_ids or inputs_embeds.")
            inputs_embeds = self.embed_tokens(input_ids)

        B, L, _ = inputs_embeds.shape

        # ---- Position ids ----
        if position_ids is None:
            # Cached keys have the sequence axis second to last.
            past_len = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
            position_ids = torch.arange(
                past_len, past_len + L,
                dtype=torch.long,
                device=inputs_embeds.device,
            ).unsqueeze(0).expand(B, -1)

        # ---- Attention mask conversion for SDPA ----
        # A 2-D (B, L) mask is turned into an additive float mask:
        #   0 → masked (dtype minimum), 1 → attended (0),
        # expanded to (B, 1, 1, L) for SDPA broadcasting.
        # Any other user-supplied mask is passed to the blocks as-is
        # (e.g. an already-additive float mask).
        causal_mask: Optional[torch.Tensor] = None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                causal_mask = (
                    (1.0 - attention_mask[:, None, None, :].float())
                    * torch.finfo(inputs_embeds.dtype).min
                )
            else:
                causal_mask = attention_mask

        # ---- Initialise experience state ----
        # Shape: (B, L, d_exp) — zeros at the start of each forward call.
        # Note: when using KV cache the sequence length L changes per step;
        # the experience state is not carried between calls, so incremental
        # decoding starts from zeros at every step.
        experience_state = torch.zeros(
            B, L, self.config.d_exp,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

        # ---- Layer loop ----
        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_cache = []
        total_curiosity = torch.tensor(0.0, device=inputs_embeds.device,
                                       dtype=inputs_embeds.dtype)

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_kv = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Wrap block forward through torch.utils.checkpoint.
                # Curiosity loss gradient flows normally; only activations
                # are recomputed.
                def _make_ckpt_fn(layer, experience_state):
                    # Bind this iteration's layer and experience state so the
                    # recomputation pass sees the same objects.
                    def _fn(hidden_states, causal_mask, position_ids):
                        return layer(
                            hidden_states,
                            experience_state=experience_state,
                            attention_mask=causal_mask,
                            position_ids=position_ids,
                            past_key_value=None,
                            use_cache=False,
                            output_attentions=output_attentions,
                        )
                    return _fn

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    _make_ckpt_fn(layer, experience_state),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    experience_state=experience_state,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_kv,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            experience_state = layer_outputs[1]
            total_curiosity = total_curiosity + layer_outputs[2]

            # Collect KV cache
            if use_cache:
                # present_kv is at index 3 (after hidden, exp_state, curiosity)
                next_cache.append(layer_outputs[3] if len(layer_outputs) > 3 else None)

            if output_attentions:
                # attn weights at last position when output_attentions=True
                all_self_attns += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Average curiosity loss across layers
        mean_curiosity = total_curiosity / self.config.n_layers

        next_cache_out = next_cache if use_cache else None

        if not return_dict:
            # Always emit a fixed-position tuple so SABERForCausalLM can
            # index reliably:
            #   [0] hidden_states
            #   [1] mean_curiosity
            #   [2] past_key_values  (None when use_cache=False)
            #   [3] all_hidden_states (None when output_hidden_states=False)
            #   [4] all_self_attns   (None when output_attentions=False)
            return (
                hidden_states,
                mean_curiosity,
                next_cache_out,
                all_hidden_states,
                all_self_attns,
            )

        # NOTE: this is a (ModelOutput, tensor) pair, not a bare ModelOutput.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache_out,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        ), mean_curiosity
750
+
751
+
752
+ # ---------------------------------------------------------------------------
753
+ # 9. Causal LM wrapper
754
+ # ---------------------------------------------------------------------------
755
+
756
class SABERForCausalLM(PreTrainedModel, GenerationMixin):
    """
    Eve-3-SABER-1B for causal language modelling.

    Compatible with HuggingFace ``Trainer``, ``SFTTrainer``, PEFT, and
    standard ``generate()`` pipelines.

    Loss = L_CE + curiosity_coeff * L_curiosity
    """

    config_class = SABERConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SABERBlock"]
    _supports_flash_attn_2 = True
    # Map required for AutoModel/AutoModelForCausalLM
    # Dict mapping parameter to its tied source (HF 5.x format)
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: SABERConfig) -> None:
        """Build the base model and the LM head, then run ``post_init``."""
        super().__init__(config)
        self.model = SABERModel(config)
        # LM head — shares its weight with the token embeddings when tied
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    # ------------------------------------------------------------------ #
    # Embedding accessors and weight tying
    # ------------------------------------------------------------------ #

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table of the base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding table of the base model."""
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def tie_weights(self, **kwargs) -> None:
        """Tie lm_head.weight ← embed_tokens.weight.

        The tie is skipped when ``config.tie_word_embeddings`` is false so an
        explicitly untied configuration keeps a separate LM head. Extra
        keyword arguments are accepted and ignored.
        """
        if getattr(self.config, "tie_word_embeddings", True):
            self.lm_head.weight = self.model.embed_tokens.weight

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[CausalLMOutputWithPast, Tuple]:
        """Compute logits and, when ``labels`` is given, the training loss.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_attentions, output_hidden_states:
                Forwarded unchanged to the base model.
            labels: Target token ids; positions equal to ``-100`` are ignored.
                Shifting by one position is done here.
            return_dict: Selects the return form (see below).

        Returns:
            * ``return_dict=False``: ``(loss?, logits, past_key_values?)`` where
              the optional entries are present only when not ``None``.
            * ``return_dict=True`` with labels: a plain ``dict`` with the keys
              ``loss``, ``logits``, ``past_key_values``, ``hidden_states``,
              ``attentions``, ``ce_loss`` and ``curiosity_loss``.
            * ``return_dict=True`` without labels: ``CausalLMOutputWithPast``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # ---- Base model ----
        # With return_dict=False the base model always returns the fixed
        # 5-tuple (hidden_states, curiosity, pkv, all_hs, all_attn); entries
        # that were not requested are None.
        base_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,  # always False so we get a plain tuple
        )
        hidden_states, curiosity_loss, pkv, all_hs, all_attn = base_out

        # ---- LM logits ----
        logits = self.lm_head(hidden_states)   # (B, L, vocab_size)

        # ---- Loss computation ----
        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Causal LM: predict token t+1 from position t.
            # Drop the last logit and the first label so they line up.
            shift_logits = logits[:, :-1, :].contiguous()   # (B, L-1, V)
            # Labels may live on another device than the logits.
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)  # (B, L-1)

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            ce_loss = loss_fct(
                # Flatten with the actual logit width instead of the
                # configured vocabulary size.
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            loss = ce_loss + self.config.curiosity_coeff * curiosity_loss

        if not return_dict:
            out = (logits,)
            if loss is not None:
                out = (loss,) + out
            if pkv is not None:
                out += (pkv,)
            return out

        # Return dict during training (allows extra keys), ModelOutput for inference
        if labels is not None:
            return {
                "loss": loss,
                "logits": logits,
                "past_key_values": pkv,
                "hidden_states": all_hs,
                "attentions": all_attn,
                "ce_loss": ce_loss,
                "curiosity_loss": curiosity_loss,
            }
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=pkv,
            hidden_states=all_hs,
            attentions=all_attn,
        )

    # ------------------------------------------------------------------ #
    # Generation helpers
    # ------------------------------------------------------------------ #

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> dict:
        """Assemble the keyword arguments for one decoding step.

        Args:
            input_ids: All token ids generated so far.
            past_key_values: Cache from the previous step. ``None`` and an
                empty container both mean this is the first step.
            attention_mask: 2-D mask covering the full sequence so far.
            inputs_embeds: Used instead of ``input_ids`` on the first step only.
            **kwargs: ``position_ids`` and ``use_cache`` are read when present.

        Returns:
            Dict with ``input_ids`` or ``inputs_embeds`` plus ``position_ids``,
            ``past_key_values``, ``use_cache`` and ``attention_mask``.
        """
        # An empty cache holds no state yet, so the whole prompt must still
        # be fed to the model.
        has_past = bool(past_key_values)

        if has_past:
            # Only pass the last token during incremental decoding
            input_ids = input_ids[:, -1:]

        # Build position_ids from the attention mask when not supplied
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked positions get a dummy position of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_past:
                position_ids = position_ids[:, -1:]

        model_inputs: dict = {}
        if inputs_embeds is not None and not has_past:
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs["input_ids"] = input_ids

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache", True),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        beam_idx: torch.LongTensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Re-order KV cache for beam search.

        Selects, along dimension 0 of every cached key and value, the rows
        given by ``beam_idx``.
        """
        return [
            (
                past_kv[0].index_select(0, beam_idx.to(past_kv[0].device)),
                past_kv[1].index_select(0, beam_idx.to(past_kv[1].device)),
            )
            for past_kv in past_key_values
        ]
941
+
942
+
943
# ---------------------------------------------------------------------------
# Auto-class registration (hint used by HF hub auto-loading)
# ---------------------------------------------------------------------------

SABERConfig.register_for_auto_class(auto_class="AutoConfig")
SABERForCausalLM.register_for_auto_class(auto_class="AutoModelForCausalLM")
optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a361bdf219883b82d1436a8b13b5c5c17e522a9f40a59e952cadb1d4063a93e8
3
+ size 3997491658
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "model_max_length": 1024,
9
+ "pad_token": "<|endoftext|>",
10
+ "tokenizer_class": "GPT2Tokenizer",
11
+ "unk_token": "<|endoftext|>"
12
+ }