# Page-header residue kept for provenance: uploader "anthonym21";
# commit 935a6ef, message "Initial Commit with GRPO notebook".
"""SFT: teach Gemma-3-1B-IT to speak Slipstream (Slipstream-TQT).
Run in Colab (recommended) or any GPU machine.
Key requirements:
- transformers >= 4.50.0 for Gemma 3
- trl, peft, datasets, accelerate
Example:
python sft_gemma3_slipstream.py \
--base_model google/gemma-3-1b-it \
--dataset anthonym21/slipstream-tqt \
--output_dir ./gemma3-slipstream-sft \
--push_to_hub anthonym21/gemma-3-1b-it-slipstream-sft
"""
from __future__ import annotations
import argparse
from typing import Dict, List
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
def to_gemma_messages(system: str, user: str, assistant: str) -> List[Dict]:
    """Build a chat-message list with text-only content segments.

    Each message is ``{"role": ..., "content": [{"type": "text", "text": ...}]}``.

    Args:
        system: System prompt. When empty, whitespace-only, or ``None`` no
            system message is emitted.
        user: Text of the user turn.
        assistant: Text of the assistant turn.

    Returns:
        The messages in order: optional system, then user, then assistant.
    """

    def seg(text: str) -> List[Dict]:
        # Gemma 3 chat template supports multimodal; we use text-only segments.
        return [{"type": "text", "text": text}]

    msgs: List[Dict] = []
    # `or ""` tolerates None, matching how extract_slip_line treats its input.
    # The system text itself is passed through unstripped.
    if (system or "").strip():
        msgs.append({"role": "system", "content": seg(system)})
    msgs.append({"role": "user", "content": seg(user)})
    msgs.append({"role": "assistant", "content": seg(assistant)})
    return msgs
def extract_slip_line(text: str) -> str:
    """Extract the wire-format Slipstream line from a TQT response.

    The dataset examples often look like:

        THOUGHT: ...
        QUANTIZE: ...
        SLIP: SLIP v1 ...

    We train the model to emit ONLY the final `SLIP v1 ...` line.

    Candidates are tried in this order, first match wins:

    1. A line labelled ``SLIP:`` whose payload starts with ``SLIP v1``.
    2. A line that itself starts with ``SLIP v1``.
    3. The first line merely containing ``SLIP v1`` (returned from that
       substring onward).
    4. The last line of the text.

    Args:
        text: Raw assistant response; ``None`` is treated as empty.

    Returns:
        The extracted line with surrounding whitespace removed, or ``""``
        when the input is empty or whitespace-only.
    """
    label = "SLIP:"
    wire = "SLIP v1"

    t = (text or "").strip()
    if not t:
        return ""

    # Split and strip once; every tier below scans the same list.
    lines = [raw.strip() for raw in t.splitlines()]

    # Prefer an explicit `SLIP:` line.
    for s in lines:
        if s.startswith(label):
            payload = s[len(label):].strip()
            if payload.startswith(wire):
                return payload

    # Next, a bare wire line. Checked before the substring fallback so that
    # prose merely mentioning "SLIP v1" (e.g. in a THOUGHT line) cannot win
    # over the actual wire line.
    for s in lines:
        if s.startswith(wire):
            return s

    # Fallback: first line containing `SLIP v1`, cut at the match.
    for s in lines:
        j = s.find(wire)
        if j != -1:
            return s[j:].strip()

    # Nothing recognisable: fall back to the last line. `t` is non-empty
    # here, so `lines` has at least one element.
    return lines[-1]
def main() -> None:
    """Run LoRA supervised fine-tuning so the base model emits Slipstream lines.

    Steps: parse the CLI flags, load the tokenizer and dataset, render every
    example through the tokenizer's chat template (fixed system prompt, the
    "human" turn, and the extracted ``SLIP v1 ...`` line as the assistant
    turn), train with TRL's ``SFTTrainer`` plus a LoRA adapter, save the
    result to ``--output_dir`` and, when ``--push_to_hub`` names a repo,
    upload it.

    Raises:
        ValueError: from the formatting function, when a dataset example has
            no "human" or no "gpt" turn.
    """
    # Function-scope imports keep this block self-contained; `trl` is already
    # a dependency of this file.
    import dataclasses

    from trl import SFTConfig

    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", type=str, default="google/gemma-3-1b-it")
    ap.add_argument("--dataset", type=str, default="anthonym21/slipstream-tqt")
    ap.add_argument("--split", type=str, default="train")
    ap.add_argument("--output_dir", type=str, default="./gemma3-slipstream-sft")
    ap.add_argument("--max_seq_len", type=int, default=1024)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--per_device_train_batch_size", type=int, default=4)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=4)
    ap.add_argument("--learning_rate", type=float, default=2e-4)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--logging_steps", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=200)
    ap.add_argument("--push_to_hub", type=str, default="")
    ap.add_argument("--hub_private_repo", action="store_true")
    args = ap.parse_args()

    # A CUDA device alone does not imply bfloat16 support, so gate bf16 on the
    # explicit capability check and fall back to float32 otherwise.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Padding needs some token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset(args.dataset, split=args.split)

    system_prompt = (
        "You are a Slipstream protocol speaker. "
        "Given a user intent, output ONLY a single wire-format line: `SLIP v1 ...`."
    )

    def first_turn(conv, speaker: str) -> str:
        """Return the text of the first turn spoken by `speaker`."""
        for turn in conv:
            if turn["from"] == speaker:
                return turn["value"]
        # An explicit error instead of the StopIteration a bare next() would
        # leak out of the dataset-processing machinery.
        raise ValueError(f"example has no {speaker!r} turn in 'conversations'")

    def formatting_func(example):
        # Dataset structure: {"conversations": [{"from": "human"|"gpt", "value": "..."}]}
        conv = example["conversations"]
        user = first_turn(conv, "human")
        assistant = extract_slip_line(first_turn(conv, "gpt"))
        msgs = to_gemma_messages(system_prompt, user, assistant)
        return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
        device_map="auto",
    )

    # The sequence-length limit belongs on the SFT config, not on the trainer
    # constructor. Its field name differs between TRL releases, so use
    # whichever one the installed SFTConfig declares.
    # NOTE(review): confirm against the pinned TRL version.
    config_fields = {f.name for f in dataclasses.fields(SFTConfig)}
    length_field = "max_length" if "max_length" in config_fields else "max_seq_length"

    train_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        bf16=use_bf16,
        fp16=False,
        optim="adamw_torch",
        report_to=[],
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub or None,
        hub_private_repo=args.hub_private_repo,
        **{length_field: args.max_seq_len},
    )

    trainer = SFTTrainer(
        model=model,
        args=train_args,
        train_dataset=ds,
        formatting_func=formatting_func,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()

    print("SFT complete:", args.output_dir)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()