Granite-Abstract: Continuous Head Embeddings for Reasoning
A language model with learnable continuous embeddings for abstract reasoning, trained with supervised fine-tuning (SFT) and reinforcement learning.
Overview
Granite-Abstract extends Granite with a continuous head that generates soft embeddings instead of discrete tokens. This enables the model to perform internal reasoning more efficiently while maintaining compatibility with standard generation.
Key characteristics:
- Parallel continuous head for soft embeddings
- Two-phase generation: abstract reasoning → natural output
- Trainable continuous components via reinforcement learning
- 500M parameters total (Granite 350M + Continuous Head 150M)
Architecture
Dual-Head System
The model operates with two parallel heads:
```
Input → Embedding Layer → Granite Backbone → Hidden State
                                                  │
                                  ┌───────────────┴───────────────┐
                                  │                               │
                          Continuous Head                      LM Head
                        (soft embeddings)               (discrete tokens)
```
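In code, the dual-head layout might look like the minimal PyTorch sketch below. The class name, constructor arguments, and the assumption that the continuous head is a linear projection to vocabulary logits are illustrative, not the repository's actual API:

```python
import torch.nn as nn

class DualHeadModel(nn.Module):
    """Sketch of the dual-head system: one shared backbone, two output heads."""
    def __init__(self, backbone, hidden_size, vocab_size):
        super().__init__()
        self.backbone = backbone                                               # Granite backbone
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)          # discrete tokens
        self.continuous_head = nn.Linear(hidden_size, vocab_size, bias=False)  # soft embeddings

    def forward(self, inputs_embeds):
        # Both heads read the same hidden states from the shared backbone
        hidden = self.backbone(inputs_embeds=inputs_embeds).last_hidden_state
        return self.continuous_head(hidden), self.lm_head(hidden)
```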
Generation Process
Phase 1: Abstract Mode - Model generates continuous embeddings for internal reasoning
```python
cont_logits = continuous_head(hidden_state)                      # (batch, vocab_size)
top_logits, top_indices = torch.topk(cont_logits, k=256, dim=-1)
weights = torch.softmax(top_logits, dim=-1)                      # (batch, 256)
# Soft embedding: probability-weighted mix of the top-256 token embeddings
next_embedding = torch.einsum('bk,bkd->bd', weights, embed_layer.weight[top_indices])
```
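Because the result is a probability-weighted mixture of token embeddings rather than a sampled token, it can be fed straight back into the backbone as the next input embedding, with no discretization between reasoning steps.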
Phase 2: Transition - At the `</think>` token, the model switches to output mode
Phase 3: Natural Mode - Generates discrete tokens for human-readable output
```python
logits = lm_head(hidden_state)
# softmax is monotonic, so taking argmax over the raw logits picks the same token
next_token = torch.argmax(logits, dim=-1)
```
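Putting the three phases together, a simplified decoding loop might look like the sketch below. It omits KV caching for clarity; `think_end_id` and the transition rule (switch when `</think>` dominates the soft mixture) are illustrative assumptions:

```python
import torch

def generate_two_phase(model, input_ids, think_end_id, max_length=256):
    """Sketch of two-phase decoding: soft embeddings until </think>, then tokens."""
    embeds = model.embed_layer(input_ids)              # embed the discrete prompt
    tokens, modes, mode = [], [], "A"                  # A=abstract, N=natural
    for _ in range(max_length):
        hidden = model.backbone(inputs_embeds=embeds).last_hidden_state[:, -1]
        if mode == "A":
            cont_logits = model.continuous_head(hidden)
            top_logits, top_idx = torch.topk(cont_logits, k=256, dim=-1)
            weights = torch.softmax(top_logits, dim=-1)
            next_emb = torch.einsum('bk,bkd->bd', weights,
                                    model.embed_layer.weight[top_idx])
            # Assumed transition rule: switch when </think> dominates the mixture
            if top_idx[0, weights[0].argmax()].item() == think_end_id:
                mode = "N"
        else:
            next_id = model.lm_head(hidden).argmax(dim=-1)   # greedy discrete token
            tokens.append(next_id.item())
            next_emb = model.embed_layer(next_id)
        modes.append(mode)
        embeds = torch.cat([embeds, next_emb.unsqueeze(1)], dim=1)
    return tokens, modes
```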
Training Process
Stage 1: Supervised Fine-Tuning (SFT)
- Train Granite on reasoning tasks with the `<think>...</think>Answer` format (example after this list)
- Best checkpoint: epoch 2 (validation loss: 1.2729)
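For concreteness, a single SFT target in this format might look like the following; the example content is invented for illustration:

```python
# Hypothetical SFT target: reasoning inside <think>...</think>, then the answer
target = "<think>5 + 3: both are single digits, so 5 + 3 = 8.</think>8"
```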
Stage 2: Initialize Continuous Head
- Copy embedding and LM head weights from SFT model
- Freeze Granite backbone
- Only continuous components are trainable (initialization sketched below)
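A minimal sketch of this initialization, assuming PyTorch modules with matching shapes (the attribute names are hypothetical):

```python
# Copy embedding and LM-head weights from the SFT checkpoint into the new head
abstract_model.embed_layer.weight.data.copy_(sft_model.embed_layer.weight.data)
abstract_model.continuous_head.weight.data.copy_(sft_model.lm_head.weight.data)

# Freeze the Granite backbone; only continuous components receive gradients
for param in abstract_model.backbone.parameters():
    param.requires_grad = False
for param in abstract_model.continuous_head.parameters():
    param.requires_grad = True
```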
Stage 3: Reinforcement Learning
- Algorithm: Gaussian Group Relative Policy Optimization (GRPO)
- Sample 4 completions per prompt
- Score with reward = answer_correctness + format_compliance
- Update via policy gradient with AdamW (lr=1e-5)
- Gaussian noise (σ=0.1) during training for exploration (see the sketch below)
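A simplified view of a single update step under these settings is sketched below. `model.sample`, `reward_fn`, and the `embedding_noise_std` argument are illustrative assumptions, not the repository's API:

```python
import torch

def grpo_step(model, optimizer, prompt, reward_fn, group_size=4, sigma=0.1):
    """One group-relative policy-gradient update (sketch, not the actual trainer)."""
    log_probs, rewards = [], []
    for _ in range(group_size):
        # Hypothetical sampling call; Gaussian noise (sigma=0.1) on the soft
        # embeddings provides exploration during abstract-mode decoding
        out = model.sample(prompt, embedding_noise_std=sigma)
        log_probs.append(out.log_prob)              # sum of per-step log-probs
        rewards.append(reward_fn(out.text))         # answer correctness + format

    rewards = torch.tensor(rewards)
    # Group-relative advantage: score each completion against the group mean
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    loss = -(torch.stack(log_probs) * advantages).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Matching the settings above, the optimizer would be `torch.optim.AdamW` over the trainable continuous components with `lr=1e-5`.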
Performance
Benchmark Evaluation
Results on three standard benchmarks (1024 samples each):
| Model | MMLU | GSM8K | DROP | Overall |
|---|---|---|---|---|
| Gemma 3 (270M) | 19.24% | 0.39% | 0.98% | 6.87% |
| Granite 4 Nano (350M) | 7.13% | 4.00% | 9.96% | 7.03% |
| Granite 4 Nano SFT (350M) | 6.17% | 6.54% | 10.06% | 7.59% |
| Abstract Granite RL (500M) | 5.76% | 7.42% | 10.45% | 7.88% |
| DeepSeek R1 Qwen (1.5B) | 0.20% | 0.49% | 0.10% | 0.26% |
Efficiency Analysis
At 500M parameters, Abstract Granite achieves the highest overall score in this evaluation, outperforming the 1.5B-parameter DeepSeek R1 Qwen baseline, a result attributable to its training being specialized for reasoning tasks.
Usage
Installation
```bash
pip install torch transformers
```
Basic Inference
```python
from abstract_model import AbstractModel

model = AbstractModel.load_from_directory(
    output_dir='./models/granite-abstract',
    sft_model_path='./models/granite-sft',
    device='cuda'
)

messages = [{"role": "user", "content": "What is 5 + 3?"}]
formatted = model.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
input_ids = model.tokenizer(formatted, return_tensors='pt')['input_ids'].to('cuda')

result = model.forward(input_ids, max_length=256, temperature=0.7)
response = model.tokenizer.decode(result['generated_tokens'])
print(response)
print("Modes:", result['mode_sequence'])  # A=abstract, N=natural, T=transition
```
Base model: ibm-granite/granite-4.0-h-350m-base