"""
train.py

Main training script for VITRA Vision-Language-Action (VLA) models.

Supports distributed training with FSDP (Fully Sharded Data Parallel) strategy.
"""
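
# Typical launch (a sketch only -- the launcher, GPU count, and config path below are
# assumptions about your environment, not fixed by this script):
#   torchrun --nproc_per_node=<NUM_GPUS> train.py --config <path/to/config.yaml> --task_name <run_name>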

import argparse
import copy
import datetime
import faulthandler
import json
import os
import random
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import wandb
from torch.utils.data import DataLoader

from vitra.datasets.materialize import get_vla_dataset_and_collator
from vitra.models.vla_builder import build_vla, load_vla_checkpoint
from vitra.training import VLAMetrics
from vitra.training.fsdp import VLAFSDPStrategy
from vitra.utils import (
    find_last_checkpoint,
    get_epoch_and_step_from_checkpoint,
    set_global_seed,
    setup_seed,
)
from vitra.utils.config_utils import load_config
from vitra.utils.overwatch import initialize_overwatch

# Silence the HuggingFace tokenizers fork warning and avoid tokenizer thread contention
# in DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Allow TF32 kernels for matmuls/convolutions (Ampere+ GPUs) for faster training.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Overwatch wraps rank-aware logging and distributed-rank helpers.
overwatch = initialize_overwatch(__name__)


def experiment(variant):
    """
    Main training experiment function for VITRA VLA models.

    Args:
        variant: Configuration dictionary containing all training parameters including:
            - Model architecture settings
            - Training hyperparameters
            - Dataset configurations
            - Logging and checkpoint paths
    """
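
    # Illustrative (non-exhaustive) shape of `variant`, inferred from the accesses below; the
    # authoritative schema lives in the YAML/JSON config files:
    #   {
    #     "task_name": str, "seed": int, "batch_size": int, "total_batch_size": int,
    #     "use_bf16": bool, "resume": bool, "model_load_path": str | None,
    #     "log_root": Path, "output_root": Path, "cache_root": Path,
    #     "vla_name": "VITRA_Paligemma", "fwd_pred_next_n": int, "save_steps": int,
    #     "wandb_project": str, "wandb_entity": str,
    #     "train_dataset": {"data_root_dir": ..., "data_mix": ..., "augmentation": ..., ...},
    #     "trainer": {"max_epochs": ..., "max_steps": ..., "learning_rate": ..., ...},
    #   }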

    # Bind this process to its local GPU and clear any cached allocations.
    torch.cuda.set_device(device_id := overwatch.local_rank())
    torch.cuda.empty_cache()

    overwatch.info("VITRA VLA Training :: Creating Folders", ctx_level=1)
    wandb_api_key = os.getenv("WANDB_API_KEY")
    if wandb_api_key is None:
        raise ValueError("Please set the WANDB_API_KEY environment variable.")
    wandb.login(key=wandb_api_key)

    # Create the output directories for logs, checkpoints, and cache.
    os.makedirs(variant["log_root"], exist_ok=True)
    os.makedirs(variant["output_root"], exist_ok=True)
    os.makedirs(variant["cache_root"], exist_ok=True)

    # Compose a unique run identifier from the task name and batch-size/precision settings,
    # e.g. "<task_name>_TB256_B8_bf16True".
    run_id = variant.get("task_name")
    batch_size = variant["batch_size"]
    total_batch_size = variant["total_batch_size"]
    run_id = f"{run_id}_TB{total_batch_size}_B{batch_size}_bf16{variant['use_bf16']}"

    checkpoint_dir = os.path.join(variant["output_root"], run_id)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Seed all RNGs; the returned worker_init_fn is passed to the DataLoader workers below.
    worker_init_fn = set_global_seed(variant["seed"], get_worker_init_fn=True)
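
    # The config may hold pathlib.Path values (set in update_configs), which json.dump cannot
    # serialize, so convert them to plain strings before writing config.json.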
    def posix_to_str(d):
        if isinstance(d, dict):
            return {k: posix_to_str(v) for k, v in d.items()}
        elif isinstance(d, list):
            return [posix_to_str(v) for v in d]
        elif isinstance(d, Path):
            return str(d)
        else:
            return d

    variant_str = copy.deepcopy(variant)
    copied_variant = posix_to_str(variant_str)

    # Only the main process writes the resolved config to disk.
    if overwatch.rank() == 0:
        with open(os.path.join(checkpoint_dir, "config.json"), "w") as f:
            json.dump(copied_variant, f, indent=2)
        overwatch.info(f"Config saved to {checkpoint_dir}", ctx_level=1)
        print(json.dumps(copied_variant, indent=2))

    dist.barrier()

    overwatch.info("Loading model", ctx_level=1)
    resume_step = 0
    resume_epoch = 0
    model_load_path = variant["model_load_path"]
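
    # When resuming without an explicit checkpoint path, fall back to the most recent checkpoint
    # found in this run's output directory.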
    if variant["resume"]:
        if model_load_path is None:
            model_load_path = find_last_checkpoint(checkpoint_dir)

        if model_load_path is not None:
            resume_epoch, resume_step = get_epoch_and_step_from_checkpoint(model_load_path)
            if overwatch.rank() == 0:
                overwatch.info(
                    f"Resume from {model_load_path}, epoch: {resume_epoch}, step: {resume_step}",
                    ctx_level=1,
                )

    # Build the VLA model and restore weights from a resume checkpoint or, failing that, from a
    # pretrained checkpoint if one is configured.
    model = build_vla(configs=variant)
    pretrain_path = variant.get("pretrain_path", None)
    if variant["resume"] and model_load_path is not None:
        model = load_vla_checkpoint(model, os.path.join(model_load_path, "weights.pt"))
    elif pretrain_path is not None:
        if os.path.isdir(pretrain_path):
            model = load_vla_checkpoint(model, os.path.join(pretrain_path, "weights.pt"))
        else:
            model = load_vla_checkpoint(model, pretrain_path)

    model = model.train()
    model.trainable_params_setup()
    model.model.use_bf16 = variant["use_bf16"]
    model.use_bf16 = variant["use_bf16"]

    # Debug mode freezes the underlying model's parameters (see --debug).
    if variant.get("debug", False):
        for p in model.model.parameters():
            p.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    if overwatch.rank() == 0:
        overwatch.info(f"Trainable Model Parameters: {total_params / 1e6:.2f}M/{all_params / 1e6:.2f}M")

    processor = model.processor
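
    # Each rank materializes its own dataset shard (shard_index == rank), so the DataLoader below
    # uses the returned batch_sampler rather than a DistributedSampler.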
    vla_dataset, collator, batch_sampler = get_vla_dataset_and_collator(
        variant["train_dataset"]["data_root_dir"],
        variant["train_dataset"]["data_mix"],
        augmentation=variant["train_dataset"]["augmentation"],
        shard_num=dist.get_world_size(),
        shard_index=dist.get_rank(),
        seed=variant["seed"],
        future_action_window_size=variant["fwd_pred_next_n"] - 1,
        processor=processor,
        batch_size=batch_size,
        normalization=variant["train_dataset"].get("normalization", True),
        flip_augmentation=variant["train_dataset"].get("flip_augmentation", 1.0),
        set_none_ratio=variant["train_dataset"].get("set_none_ratio", 0.0),
        action_type=variant["train_dataset"].get("action_type", "angle"),
        use_rel=variant["train_dataset"].get("use_rel", False),
        rel_mode=variant["train_dataset"].get("rel_mode", "step"),
        clip_len=variant["train_dataset"].get("clip_len", None),
        state_mask_prob=variant["train_dataset"].get("state_mask_prob", 0.1),
    )
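
    # The FSDP strategy bundles optimizer/LR-schedule hyperparameters, gradient clipping, and the
    # mixed-precision / sharding configuration used to train the VLA.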
    training_strategy = VLAFSDPStrategy(
        vla=model,
        device_id=overwatch.local_rank(),
        stage=None,
        epochs=variant["trainer"]["max_epochs"],
        max_steps=variant["trainer"]["max_steps"],
        global_batch_size=variant["total_batch_size"],
        per_device_batch_size=batch_size,
        learning_rate=variant["trainer"]["learning_rate"],
        weight_decay=variant["trainer"]["weight_decay"],
        max_grad_norm=variant["trainer"]["gradient_clip_val"],
        lr_scheduler_type=variant["trainer"]["lr_scheduler_type"],
        warmup_ratio=variant["trainer"]["warmup_ratio"],
        enable_gradient_checkpointing=variant["trainer"]["enable_gradient_checkpointing"],
        enable_mixed_precision_training=variant["trainer"]["enable_mixed_precision_training"],
        reduce_in_full_precision=variant["trainer"]["reduce_in_full_precision"],
        action_model_learning_rate=variant["trainer"].get("action_model_learning_rate", None),
        action_model_weight_decay=variant["trainer"].get("action_model_weight_decay", None),
        sharding_strategy=variant["trainer"].get("sharding_strategy", "shard-grad-op"),
        cognition_token_weight_decay=variant["trainer"].get("cognition_token_weight_decay", True),
        llm_freeze_step=variant["trainer"].get("llm_freeze_step", 0),
        move_word_embedding_to_action_model=variant["trainer"].get("move_word_embedding_to_action_head", False),
        optimizer_betas=variant["trainer"].get("optimizer_betas", (0.9, 0.999)),
    )

    # Only the PaliGemma-based VLA has an FSDP wrap/checkpointing policy defined here.
    if variant["vla_name"] == "VITRA_Paligemma":
        auto_wrap_policy, checkpointing_policy = get_fsdp_wrap_policy_and_checkpointing(variant["trainer"])
    else:
        raise NotImplementedError(f"Unsupported VLA name: {variant['vla_name']}")
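
    # run_setup presumably performs the FSDP wrapping and builds the optimizer/LR scheduler;
    # n_train_examples lets the strategy derive the number of batches per epoch.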
    training_strategy.run_setup(
        run_dir=checkpoint_dir,
        n_train_examples=len(vla_dataset),
        auto_wrap_policy_modules=auto_wrap_policy,
        checkpointing_policy_modules=checkpointing_policy,
    )

    # Restore optimizer and LR-scheduler state when resuming.
    if variant["resume"] and model_load_path is not None:
        training_strategy.load_optimizer_and_scheduler(model_load_path)

    trackers = ["wandb"]
    overwatch.info(f"Creating Metrics with Active Trackers => `{trackers}`")
    metrics = VLAMetrics(
        trackers,
        hparams=variant_str,
        run_id=run_id,
        run_dir=checkpoint_dir,
        wandb_project=variant["wandb_project"],
        wandb_entity=variant["wandb_entity"],
        resume_step=resume_step,
        resume_epoch=resume_epoch,
    )

    overwatch.info("Creating Dataloader", ctx_level=1)
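
    # Loader settings: CLI values (--num_workers / --prefetch_factor) take precedence over the
    # dataset config, and prefetch_factor must be None when num_workers == 0 (a PyTorch constraint).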
    num_workers = variant["num_workers"] if variant["num_workers"] is not None else variant["train_dataset"]["num_workers"]
    prefetch_factor = variant["prefetch_factor"] if variant["prefetch_factor"] is not None else variant["train_dataset"]["prefetch_factor"]

    if num_workers == 0 or prefetch_factor == 0:
        prefetch_factor = None

    if overwatch.rank() == 0:
        print(f"num_workers: {num_workers}, prefetch_factor: {prefetch_factor}")
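
    # Fast-forward the batch sampler to the resume point; resume_step counts optimizer steps, so it
    # is multiplied by grad_accumulation_steps to recover the number of consumed batches.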
    batch_sampler.set_epoch(resume_epoch, resume_step * training_strategy.grad_accumulation_steps)

    # Re-seed this process; the rank is passed so seeding can be made rank-dependent.
    setup_seed(variant["seed"], rank=dist.get_rank())

    dataloader = DataLoader(
        vla_dataset,
        batch_sampler=batch_sampler,
        collate_fn=collator,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        worker_init_fn=worker_init_fn,
        persistent_workers=num_workers > 0,
        pin_memory=num_workers > 0,
    )

    overwatch.info("Starting VLA Training Loop")
    training_strategy.run_training(
        dataloader,
        metrics,
        save_interval=variant["save_steps"],
        start_global_step=resume_step,
        start_epoch=resume_epoch,
    )

    overwatch.info("Done with Training =>> Finalizing Metrics")
    metrics.finalize()

    overwatch.info("... and that's all, folks!")
    dist.barrier()
    dist.destroy_process_group()


def get_fsdp_wrap_policy_and_checkpointing(configs):
    """
    Get the FSDP auto-wrapping policy and activation checkpointing policy for PaliGemma models.

    The auto-wrap policy determines which module types should be individually wrapped by FSDP,
    allowing for efficient memory usage and communication in distributed training.

    The checkpointing policy determines which modules should use activation checkpointing
    (gradient checkpointing) to trade computation for memory during training.

    Args:
        configs: Trainer configuration dictionary containing strategy settings

    Returns:
        Tuple of (auto_wrap_policy, checkpointing_policy):
            - auto_wrap_policy: Set of module classes to wrap with FSDP
            - checkpointing_policy: Set of module classes to apply gradient checkpointing to, or None
    """
    if "strategy" not in configs or configs["strategy"] == "ddp":
        raise NotImplementedError("FSDP strategy not specified or DDP selected.")
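
    # Imported locally, presumably so these transformers/vitra modules are only required when the
    # PaliGemma wrap policy is actually requested.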
    from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionTransformer

    from vitra.models.action_model import DiT
    from vitra.utils.nn_utils import MLPProjector

    policy = {
        SiglipEncoderLayer,
        SiglipVisionTransformer,
        DiT,
        Gemma2DecoderLayer,
        PaliGemmaMultiModalProjector,
        MLPProjector,
    }

    checkpointing_policy = (
        {Gemma2DecoderLayer}
        if configs["strategy"] == "fsdp_paligemma_with_checkpointing"
        else None
    )

    return policy, checkpointing_policy


def update_configs(configs, args):
    """
    Update configuration dictionary with command-line arguments.

    Command-line arguments take precedence over config file values. This function
    handles both top-level parameters and nested dictionaries (e.g., trainer settings).

    Args:
        configs: Base configuration dictionary loaded from YAML/JSON config file
        args: Parsed command-line arguments dictionary

    Returns:
        Updated configuration dictionary with command-line overrides applied
    """
    # A few flags override their config counterparts explicitly.
    if args["task_name"] is not None:
        configs["task_name"] = args["task_name"]

    configs["use_bf16"] = (
        args["use_bf16"]
        if args["use_bf16"] is not None
        else configs.get("use_bf16", False)
    )

    if args["data_mix"] is not None:
        configs["train_dataset"]["data_mix"] = args["data_mix"]

    # Root directories become pathlib.Path objects; the cache directory is namespaced by the
    # `model` config key.
    configs["output_root"] = Path(configs["output_root"])
    configs["log_root"] = Path(configs["log_root"])
    configs["cache_root"] = Path(configs["cache_root"]) / configs["model"]
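
    # Generic merge: CLI values override config values, and nested dicts (e.g. "trainer") are merged
    # key-by-key. Illustrative example: if the config file sets trainer["max_steps"] = 100000 and the
    # CLI passes --max_steps 2000, the merged config ends up with configs["trainer"]["max_steps"] == 2000.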
    for k, v in args.items():
        if k not in configs:
            print(f"'{k}' not found in config; adding it with value {v}.")
            configs[k] = v
        elif isinstance(v, dict):
            for sub_k, sub_v in v.items():
                if sub_v is not None:
                    configs[k][sub_k] = sub_v
        elif v is not None:
            configs[k] = v

    return configs


def parse_args():
    """
    Parse command-line arguments for training configuration.

    Arguments are organized into two groups:
        1. Global arguments (experiment settings, paths, data configuration)
        2. Trainer arguments (training hyperparameters and strategy)

    Returns:
        Dictionary with structure:
            {
                'config': str,
                'seed': int,
                ...other global args...,
                'trainer': {
                    'strategy': str,
                    'gradient_clip_val': float,
                    ...other trainer args...
                }
            }
    """

    parser = argparse.ArgumentParser(description="VITRA VLA Training Script")

    parser.add_argument(
        "--config",
        type=str,
        help="Path to YAML/JSON configuration file for training",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--log_root",
        default=None,
        type=str,
        help="Root directory for logging",
    )
    parser.add_argument(
        "--output_root",
        default=None,
        type=str,
        help="Root directory for checkpoints and outputs",
    )
    parser.add_argument(
        "--model_load_path",
        default=None,
        type=str,
        help="Path to checkpoint for resuming training",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        help="Unique identifier for this training run",
    )
    parser.add_argument(
        "--use_bf16",
        default=None,
        action="store_true",
        help="Enable bfloat16 mixed precision training",
    )
    parser.add_argument(
        "--data_mix",
        default=None,
        type=str,
        help="Dataset mixture configuration",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Enable debug mode (freezes model parameters)",
    )
    parser.add_argument(
        "--fwd_pred_next_n",
        default=None,
        type=int,
        help="Number of future action steps to predict",
    )
    parser.add_argument(
        "--batch_size",
        default=None,
        type=int,
        help="Per-device batch size",
    )
    parser.add_argument(
        "--total_batch_size",
        default=None,
        type=int,
        help="Global batch size across all devices",
    )
    parser.add_argument(
        "--num_workers",
        default=None,
        type=int,
        help="Number of data loading workers per process",
    )
    parser.add_argument(
        "--prefetch_factor",
        default=None,
        type=int,
        help="Number of batches to prefetch per worker",
    )
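
    # parse_known_args() is called before and after registering the trainer group so the argument
    # names can be split into "global" and "trainer" buckets by set difference.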
    global_names = set(vars(parser.parse_known_args()[0]).keys())

    trainer_parser = parser.add_argument_group("trainer", "Training strategy and hyperparameters")
    trainer_parser.add_argument(
        "--strategy",
        default=None,
        type=str,
        help="Training strategy (e.g., 'fsdp')",
    )
    trainer_parser.add_argument(
        "--gradient_clip_val",
        default=None,
        type=float,
        help="Maximum gradient norm for clipping",
    )
    trainer_parser.add_argument(
        "--max_steps",
        default=None,
        type=int,
        help="Maximum number of training steps (overrides epochs)",
    )

    trainer_names = set(vars(parser.parse_known_args()[0]).keys()) - global_names

    args = {}
    trainer_args = {}
    temp_args = vars(parser.parse_args())

    for k, v in temp_args.items():
        if k in global_names:
            args[k] = v
        elif k in trainer_names:
            trainer_args[k] = v

    args["trainer"] = trainer_args

    return args


if __name__ == "__main__":
    # Dump tracebacks for all threads on hard crashes (e.g., segfaults in native extensions).
    faulthandler.enable()

    args = parse_args()

    configs = load_config(args.get("config"))
    configs = update_configs(configs, args)
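
    # Initialize the NCCL process group; rank and world size are read from the environment
    # (as set by torchrun or an equivalent launcher).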
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    experiment(variant=configs)