| """SFT: teach Gemma-3-1B-IT to speak Slipstream (Slipstream-TQT). |
| |
| Run in Colab (recommended) or any GPU machine. |
| |
| Key requirements: |
| - transformers >= 4.50.0 for Gemma 3 |
| - trl, peft, datasets, accelerate |
| |
| Example: |
| python sft_gemma3_slipstream.py \ |
| --base_model google/gemma-3-1b-it \ |
| --dataset anthonym21/slipstream-tqt \ |
| --output_dir ./gemma3-slipstream-sft \ |
| --push_to_hub anthonym21/gemma-3-1b-it-slipstream-sft |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from typing import Dict, List |
|
|
| import torch |
| from datasets import load_dataset |
| from peft import LoraConfig |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments |
|
|
| from trl import SFTTrainer |
|
|
|
|
def to_gemma_messages(system: str, user: str, assistant: str) -> List[Dict]:
    """Build a three-role chat transcript in structured-content form.

    Every message is a dict with a ``role`` and a ``content`` list holding a
    single ``{"type": "text", "text": ...}`` segment.

    Args:
        system: System prompt; the system message is omitted when this is
            empty or whitespace-only (the text itself is passed unmodified).
        user: Text of the user turn.
        assistant: Text of the assistant turn.

    Returns:
        The messages in order: optional system, then user, then assistant.
    """
    # Collect (role, text) pairs first so the message shape is written once.
    turns = []
    if system.strip():
        turns.append(("system", system))
    turns += [("user", user), ("assistant", assistant)]

    return [
        {"role": role, "content": [{"type": "text", "text": body}]}
        for role, body in turns
    ]
|
|
def extract_slip_line(text: str) -> str:
    """Pull the ``SLIP v1 ...`` wire line out of a TQT-style response.

    Responses in the dataset often look like::

        THOUGHT: ...
        QUANTIZE: ...
        SLIP: SLIP v1 ...

    Only the final ``SLIP v1 ...`` payload is kept as the training target.

    Search order:
      1. The first line that, after dropping an optional ``SLIP:`` label,
         starts with ``SLIP v1``.
      2. Otherwise the first line that contains ``SLIP v1`` anywhere,
         returned from that marker onward.
      3. Otherwise the last line of the response.

    Returns an empty string for ``None``, empty, or whitespace-only input.
    """
    label = "SLIP:"
    marker = "SLIP v1"

    body = (text or "").strip()
    if not body:
        return ""

    rows = [row.strip() for row in body.splitlines()]

    # Pass 1: a line that begins with the marker (optionally behind "SLIP:").
    for row in rows:
        candidate = row[len(label):].strip() if row.startswith(label) else row
        if candidate.startswith(marker):
            return candidate

    # Pass 2: the marker embedded somewhere inside a line.
    for row in rows:
        at = row.find(marker)
        if at != -1:
            return row[at:].strip()

    # Nothing matched: fall back to the final line.
    return rows[-1]
|
|
|
|
def main() -> None:
    """Fine-tune a causal LM with LoRA so it emits Slipstream wire lines.

    Steps: parse command-line options, load the tokenizer, dataset and base
    model, render each conversation through the tokenizer's chat template
    (keeping only the ``SLIP v1 ...`` line as the assistant target), train
    with TRL's ``SFTTrainer``, save to ``--output_dir`` and, when
    ``--push_to_hub`` names a repo, push the result there.

    Raises:
        ValueError: (from the formatting function) when a dataset row has no
            ``human`` or no ``gpt`` turn.
    """
    # Local imports: only this function needs them.
    import dataclasses

    from trl import SFTConfig

    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", type=str, default="google/gemma-3-1b-it")
    ap.add_argument("--dataset", type=str, default="anthonym21/slipstream-tqt")
    ap.add_argument("--split", type=str, default="train")
    ap.add_argument("--output_dir", type=str, default="./gemma3-slipstream-sft")
    ap.add_argument("--max_seq_len", type=int, default=1024)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--per_device_train_batch_size", type=int, default=4)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=4)
    ap.add_argument("--learning_rate", type=float, default=2e-4)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--logging_steps", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=200)
    ap.add_argument("--push_to_hub", type=str, default="")
    ap.add_argument("--hub_private_repo", action="store_true")
    args = ap.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset(args.dataset, split=args.split)

    SYSTEM = (
        "You are a Slipstream protocol speaker. "
        "Given a user intent, output ONLY a single wire-format line: `SLIP v1 ...`."
    )

    def first_turn(conv, speaker):
        """Return the text of the first turn spoken by ``speaker``."""
        for turn in conv:
            if turn["from"] == speaker:
                return turn["value"]
        # An explicit error instead of a bare StopIteration escaping next().
        raise ValueError(f"conversation has no {speaker!r} turn")

    def formatting_func(example):
        """Render one dataset row as a chat-templated training string."""
        conv = example["conversations"]
        user = first_turn(conv, "human")
        assistant = extract_slip_line(first_turn(conv, "gpt"))
        msgs = to_gemma_messages(SYSTEM, user, assistant)
        return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    # Enable bf16 only where the device reports support; otherwise stay in
    # float32 so argument validation does not reject the run.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
        device_map="auto",
    )

    # The sequence-length limit is a config field, not an SFTTrainer keyword.
    # Its name differs between TRL releases, so pick whichever one exists.
    config_fields = {f.name for f in dataclasses.fields(SFTConfig)}
    length_field = "max_length" if "max_length" in config_fields else "max_seq_length"

    train_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        bf16=use_bf16,
        fp16=False,
        optim="adamw_torch",
        report_to=[],
        push_to_hub=bool(args.push_to_hub),
        hub_model_id=args.push_to_hub or None,
        hub_private_repo=args.hub_private_repo,
        **{length_field: args.max_seq_len},
    )

    trainer = SFTTrainer(
        model=model,
        args=train_args,
        train_dataset=ds,
        formatting_func=formatting_func,
        # Hand over the tokenizer prepared above (pad token included) so the
        # trainer tokenizes with the same object that rendered the template.
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()
    trainer.save_model(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()

    print("SFT complete:", args.output_dir)
|
|
|
|
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|