Granite-Abstract: Continuous Head Embeddings for Reasoning

A language model with learnable continuous embeddings for abstract reasoning, trained with supervised fine-tuning (SFT) and reinforcement learning.

Overview

Granite-Abstract extends Granite with a continuous head that generates soft embeddings instead of discrete tokens. This enables the model to perform internal reasoning more efficiently while maintaining compatibility with standard generation.

Key characteristics:

  • Parallel continuous head for soft embeddings
  • Two-phase generation: abstract reasoning → natural output
  • Trainable continuous components via reinforcement learning
  • 500M parameters total (Granite 350M + Continuous Head 150M)

Architecture

Dual-Head System

The model operates with two parallel heads:

Input → Embedding Layer → Granite Backbone → Hidden State
                                                  ↓
                                     ┌────────────┴────────────┐
                                     ↓                         ↓
                              Continuous Head               LM Head
                             (soft embeddings)         (discrete tokens)
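A minimal sketch of this layout in PyTorch, assuming a Hugging Face-style backbone that accepts inputs_embeds and returns last_hidden_state; the class and attribute names below are illustrative, not the project's actual code:

import torch.nn as nn

class DualHeadModel(nn.Module):
    """Illustrative dual-head wrapper: one shared backbone, two output heads."""

    def __init__(self, backbone, hidden_size, vocab_size):
        super().__init__()
        self.backbone = backbone                                   # Granite backbone (frozen after SFT)
        self.embed_layer = backbone.get_input_embeddings()         # shared input embedding table
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)          # discrete-token logits
        self.continuous_head = nn.Linear(hidden_size, vocab_size, bias=False)  # soft-embedding logits

    def forward(self, inputs_embeds, abstract_mode=True):
        # Last-position hidden state from the shared backbone
        hidden = self.backbone(inputs_embeds=inputs_embeds).last_hidden_state[:, -1]
        head = self.continuous_head if abstract_mode else self.lm_head
        return head(hidden)   # Phase 1 uses continuous_head, Phase 3 uses lm_head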

Generation Process

Phase 1: Abstract Mode - Model generates continuous embeddings for internal reasoning

cont_logits = continuous_head(hidden_state)
top_logits, top_indices = topk(cont_logits, 256)
next_embedding = softmax(top_logits) @ embed_layer[top_indices]
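A self-contained sketch of this step; the function name, tensor shapes, and k=256 wiring are assumptions based on the pseudocode above:

import torch
import torch.nn.functional as F

def abstract_step(hidden_state, continuous_head, embed_weight, k=256):
    # hidden_state:  (hidden_size,) final hidden state at the last position
    # embed_weight:  (vocab_size, embed_dim) weight of the input embedding layer
    cont_logits = continuous_head(hidden_state)             # (vocab_size,)
    top_logits, top_indices = torch.topk(cont_logits, k)    # keep the k most likely tokens
    weights = F.softmax(top_logits, dim=-1)                 # (k,) mixture weights
    # Soft embedding: probability-weighted mixture of the top-k token embeddings
    next_embedding = weights @ embed_weight[top_indices]    # (embed_dim,)
    return next_embedding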

Phase 2: Transition - At the </think> token, the model switches to output mode

Phase 3: Natural Mode - Generates discrete tokens for human-readable output

logits = lm_head(hidden_state)
next_token = argmax(softmax(logits))
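Putting the three phases together, a sketch of the generation loop could look like the following. It reuses abstract_step from the Phase 1 sketch; the </think>-detection rule, attribute names, and shapes are assumptions, since the card does not spell out how the transition is triggered in abstract mode, and the loop recomputes the backbone pass each step (no KV cache) for clarity:

import torch

def generate(model, tokenizer, input_ids, max_length=256):
    think_end_id = tokenizer.convert_tokens_to_ids("</think>")
    inputs_embeds = model.embed_layer(input_ids)               # (1, seq, embed_dim)
    abstract, tokens, modes = True, [], []

    for _ in range(max_length):
        hidden = model.backbone(inputs_embeds=inputs_embeds).last_hidden_state[:, -1]
        if abstract:
            cont_logits = model.continuous_head(hidden)
            if cont_logits.argmax(dim=-1).item() == think_end_id:
                # Phase 2: </think> becomes most likely, so switch to natural mode
                abstract = False
                modes.append("T")
                next_ids = torch.tensor([[think_end_id]], device=hidden.device)
                next_emb = model.embed_layer(next_ids)
            else:
                # Phase 1: feed a soft embedding back in instead of a token embedding
                next_emb = abstract_step(hidden[0], model.continuous_head,
                                         model.embed_layer.weight).view(1, 1, -1)
                modes.append("A")
        else:
            # Phase 3: ordinary greedy decoding with the LM head
            next_token = model.lm_head(hidden).argmax(dim=-1)  # (1,)
            tokens.append(next_token.item())
            next_emb = model.embed_layer(next_token).unsqueeze(1)
            modes.append("N")
            if next_token.item() == tokenizer.eos_token_id:
                break
        inputs_embeds = torch.cat([inputs_embeds, next_emb], dim=1)

    return tokens, modes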

Training Process

Stage 1: Supervised Fine-Tuning (SFT)

  • Train Granite on reasoning tasks in the <think>...</think>Answer format (illustrated below)
  • Best checkpoint: epoch 2 (validation loss: 1.2729)
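For illustration, a target in this format might look like the following (the content is made up, not taken from the training data):

<think>5 + 3 means starting from 5 and counting up 3, which gives 8.</think>The answer is 8.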

Stage 2: Initialize Continuous Head

  • Copy embedding and LM head weights from SFT model
  • Freeze Granite backbone
  • Only the continuous components are trainable (see the sketch below)
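A minimal sketch of this stage, assuming the attribute names from the architecture sketch above (the real initialization code may differ):

def init_continuous_components(model, sft_model):
    # Copy the embedding table and LM head weights from the SFT checkpoint
    model.embed_layer.load_state_dict(sft_model.embed_layer.state_dict())
    model.lm_head.load_state_dict(sft_model.lm_head.state_dict())

    # Freeze the Granite backbone so only the continuous components receive gradients
    for p in model.backbone.parameters():
        p.requires_grad = False
    for p in model.continuous_head.parameters():
        p.requires_grad = True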

Stage 3: Reinforcement Learning

  • Algorithm: Group Relative Policy Optimization (GRPO) with Gaussian exploration noise
  • Sample 4 completions per prompt
  • Score with reward = answer_correctness + format_compliance
  • Update via policy gradient with AdamW (lr=1e-5)
  • Gaussian noise (σ=0.1) added during training for exploration (see the sketch below)
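A simplified sketch of one update step under these settings (group of 4, group-relative advantages, policy gradient with AdamW at lr=1e-5). The sampling and reward helpers (sample_with_noise, answer_correctness, format_compliance) are placeholders, and the point where the σ=0.1 noise enters is an assumption:

import torch

def grpo_step(model, optimizer, prompt, group_size=4, sigma=0.1):
    """One GRPO update for a single prompt (illustrative)."""
    log_probs, rewards = [], []
    for _ in range(group_size):
        # Sample a completion; Gaussian noise (sigma=0.1) is assumed to be injected
        # into the soft embeddings during abstract-mode generation for exploration.
        completion, logp = sample_with_noise(model, prompt, sigma)   # placeholder helper
        log_probs.append(logp)                                       # summed log-prob of the completion
        rewards.append(answer_correctness(completion)                # placeholder reward terms
                       + format_compliance(completion))

    rewards = torch.tensor(rewards)
    # Group-relative advantage: normalize each reward against its group of 4 samples
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # Policy gradient: raise the likelihood of above-average completions
    loss = -(torch.stack(log_probs) * advantages).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Only the trainable continuous components are optimized, e.g.:
#   optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-5)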

Performance

Benchmark Evaluation


Results on three standard benchmarks (1024 samples each):

Model                         MMLU     GSM8K    DROP     Overall
Gemma 3 (270M)                19.24%    0.39%    0.98%    6.87%
Granite 4 Nano (350M)          7.13%    4.00%    9.96%    7.03%
Granite 4 Nano SFT (350M)      6.17%    6.54%   10.06%    7.59%
Abstract Granite RL (500M)     5.76%    7.42%   10.45%    7.876%
DeepSeek R1 Qwen (1.5B)        0.20%    0.49%    0.10%    0.26%

Efficiency Analysis

Model Size vs Accuracy

Despite being far smaller than the 1.5B DeepSeek baseline, Abstract Granite RL achieves the highest overall score in the comparison above through training specialized for reasoning tasks.

Usage

Installation

pip install torch transformers

Basic Inference

from abstract_model import AbstractModel

model = AbstractModel.load_from_directory(
    output_dir='./models/granite-abstract',
    sft_model_path='./models/granite-sft',
    device='cuda'
)

messages = [{"role": "user", "content": "What is 5 + 3?"}]
formatted = model.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
input_ids = model.tokenizer(formatted, return_tensors='pt')['input_ids'].to('cuda')

result = model.forward(input_ids, max_length=256, temperature=0.7)
response = model.tokenizer.decode(result['generated_tokens'])
print(response)
print("Modes:", result['mode_sequence'])  # A=abstract, N=natural, T=transition