| """ Main training script """ | |
| import argparse | |
| import glob | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import wandb | |
| from data import get_data | |
| from distributed import init_distributed_device, world_info_from_env | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
| from train_utils import ( | |
| train_one_epoch, | |
| get_mp_policy_dtype, | |
| save_checkpoint, | |
| ) | |
| from transformers import ( | |
| get_constant_schedule_with_warmup, | |
| get_cosine_schedule_with_warmup, | |
| get_linear_schedule_with_warmup, | |
| ) | |
| from torch.distributed.fsdp import ( | |
| CPUOffload, | |
| MixedPrecision, | |
| ShardingStrategy, | |
| BackwardPrefetch, | |
| ) | |
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | |
| checkpoint_wrapper, | |
| CheckpointWrapper, | |
| CheckpointImpl, | |
| apply_activation_checkpointing, | |
| ) | |
| from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups | |
| from torch.distributed.distributed_c10d import _get_default_group | |
| import functools | |
| from open_flamingo import create_model_and_transforms | |
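
# Example launch (all values below are illustrative; substitute your own paths,
# batch sizes, and process counts). The script reads its distributed config from
# the environment (dist-url defaults to "env://"), so torchrun can be used:
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py \
#     --lm_path facebook/opt-1.3b \
#     --tokenizer_path facebook/opt-1.3b \
#     --laion_shards "/path/to/laion/{000000..000999}.tar" \
#     --mmc4_shards "/path/to/mmc4/{000000..000999}.tar" \
#     --batch_size_laion 128 \
#     --batch_size_mmc4 128 \
#     --train_num_samples_laion 100000 \
#     --train_num_samples_mmc4 100000 \
#     --report_to_wandb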

def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

def main():
    parser = argparse.ArgumentParser()

    # model configuration args
    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="facebook/opt-30b",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--cross_attn_every_n_layers",
        type=int,
        default=1,
        help="how often to add a cross-attention layer after each transformer layer",
    )

    # training args
    parser.add_argument(
        "--run_name",
        type=str,
        default="openflamingo3B",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete previous checkpoint when saving new checkpoint",
    )
    parser.add_argument("--batch_size_mmc4", type=int, default=128)
    parser.add_argument("--batch_size_laion", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
    parser.add_argument("--warmup_steps", default=5000, type=int)
    parser.add_argument("--weight_decay", default=0.1, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="fp32",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="whether to train with gradient/activation checkpointing",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="we define an 'epoch' as a fixed number of examples (train_num_samples_mmc4, train_num_samples_laion), not a pass through the entire dataset",
    )
    parser.add_argument("--offline", action="store_true")
    parser.add_argument(
        "--freeze_lm_embeddings",
        action="store_true",
        help="if True, we freeze the LM embeddings during training. Otherwise, we train the <image> and <|endofchunk|> embeddings.",
    )
    parser.add_argument(
        "--logging_steps", type=int, default=100, help="log loss every n steps"
    )

    # data args
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--train_num_samples_mmc4", type=int, default=10000)
    parser.add_argument("--train_num_samples_laion", type=int, default=10000)
    parser.add_argument("--dataset_resampled", action="store_true")
    parser.add_argument(
        "--mmc4_textsim_threshold",
        default=30,
        type=float,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )
    parser.add_argument(
        "--mmc4_max_num_images",
        default=6,
        type=int,
        help="max number of images per sequence in mmc4 / chatgpt",
    )
    parser.add_argument(
        "--mmc4_min_num_images",
        default=1,
        type=int,
        help="min number of images per sequence in mmc4 / chatgpt",
    )

    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    parser.add_argument(
        "--fsdp",
        default=False,
        action="store_true",
        help="Use FullyShardedDataParallel for distributed training.",
    )
    parser.add_argument(
        "--fsdp_use_orig_params",
        default=False,
        action="store_true",
        help="Passed into the FSDP constructor. Enables param_groups and gradient masking for weight_decay. Does not work with OPT.",
    )
    parser.add_argument(
        "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
    )

    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument(
        "--wandb_project",
        type=str,
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
    )
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )

    args = parser.parse_args()
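    # shard patterns are handed to webdataset, which accepts "pipe:" URLs, so S3
    # paths are rewritten to stream through the AWS CLI instead of being downloaded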
    # Validate args
    if args.laion_shards.startswith("s3"):
        args.laion_shards = f"pipe:aws s3 cp {args.laion_shards} -"
    if args.mmc4_shards.startswith("s3"):
        args.mmc4_shards = f"pipe:aws s3 cp {args.mmc4_shards} -"
    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
    if args.fsdp and not args.fsdp_use_orig_params:
        print(
            "Warning: FSDP is running without fsdp_use_orig_params flag. "
            + "This is not recommended because it means we will use uniform weight decay"
            + " and train all embeddings, not just the newly added ones. "
            + "Note: OPT models are not compatible with fsdp_use_orig_params flag."
        )
    if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
        print(
            "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding. "
            + "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py. "
            + "Copy and paste the code from the _optim_utils.py in this repo into the torch file. "
            + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
        )
    assert (args.train_num_samples_laion // args.batch_size_laion) == (
        args.train_num_samples_mmc4 // args.batch_size_mmc4
    ), "number of samples per epoch must be equal for mmc4 and laion"
    # Set up distributed training
    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
    random_seed(args.seed)
    # Initialize model
    model, image_processor, tokenizer = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.tokenizer_path if args.tokenizer_path else args.lm_path,
        cross_attn_every_n_layers=args.cross_attn_every_n_layers,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
    )
    random_seed(args.seed, args.rank)

    # Initialize logging
    print(f"Start running training on rank {args.rank}.")
    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )
    # Load model checkpoint on CPU
    if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
        # if args do not specify a checkpoint to resume from, check if checkpoints exist for this run
        # and automatically resume from the latest checkpoint
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            print(
                f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
            )

    resume_from_epoch = 0
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            print(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}
        resume_from_epoch = checkpoint["epoch"] + 1

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, strict=False)
    # Initialize FSDP / DDP, and ensure the model is on GPU
    print(f"Initializing distributed training with {args.world_size} GPUs.")
    if args.fsdp:
        print(
            f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )

        # init MixedPrecision
        if args.precision != "fp32":
            cast_dtype = get_mp_policy_dtype(args.precision)
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=cast_dtype,  # gradient communication
                buffer_dtype=cast_dtype,
            )
        else:
            mp_policy = None

        # init process groups
        if args.fsdp_sharding_strategy == "hybrid":
            intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
                _get_default_group()
            )
            args.my_group = intra_node_group  # for optimizer saving
            process_group = (intra_node_group, inter_node_group)  # for FSDP init
        else:
            args.my_group = None  # for optimizer saving
            process_group = None  # for FSDP init
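        # With HYBRID_SHARD, parameters are sharded across the ranks within a node
        # (intra_node_group) and replicated across nodes (inter_node_group), so
        # all-gathers stay on the faster intra-node interconnect.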
        # init FSDP
        wrapper_kwargs = dict(
            process_group=process_group,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=device_id,
            sync_module_states=True,  # broadcast loaded ckpt from rank 0 -> all ranks
            sharding_strategy=ShardingStrategy.FULL_SHARD
            if args.fsdp_sharding_strategy == "full"
            else ShardingStrategy.HYBRID_SHARD,
            use_orig_params=args.fsdp_use_orig_params,
            mixed_precision=mp_policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
        )
        model.wrap_fsdp(wrapper_kwargs, device_id)
        ddp_model = model

        print(
            f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )
        print(
            f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
        )
    else:
        model = model.to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

    # Initialize gradient checkpointing
    if args.gradient_checkpointing:
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=True,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            ddp_model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, FSDP)
            and not isinstance(m, CheckpointWrapper),
        )
    # Initialize optimizer
    params_to_optimize = ddp_model.named_parameters()
    params_to_optimize = list(
        filter(
            lambda x: x[1].requires_grad
            and not getattr(x[1], "exclude_from_optimizer", False),
            params_to_optimize,
        )
    )
    if not args.fsdp or args.fsdp_use_orig_params:
        # apply weight decay only to params in the xattn layers
        def get_grouped_params(model):
            params_with_wd, params_without_wd = [], []
            for n, p in params_to_optimize:
                if "gated_cross_attn" in n:
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": args.weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        optimizer = torch.optim.AdamW(
            get_grouped_params(params_to_optimize), lr=args.learning_rate
        )
    else:
        # unclear if we should be using no weight decay or small weight decay for all parameters
        optimizer = torch.optim.AdamW(
            (p for _, p in params_to_optimize),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    # load optimizer checkpoint
    if args.resume_from_checkpoint is not None:
        osd = checkpoint["optimizer_state_dict"]
        if args.fsdp:
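            # re-shard the saved optimizer state so it matches this rank's
            # FSDP-wrapped parameters before loading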
            osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
        optimizer.load_state_dict(osd)
    # Initialize data loaders
    laion_dataset = get_data(args, image_processor, tokenizer, "image_text")
    mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")
    total_training_steps = (
        (args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size)
    ) * args.num_epochs
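    # i.e. steps per epoch = samples per epoch // (per-GPU batch size * world size),
    # scaled by the number of epochs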
    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    # Initialize lr scheduler
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    # load lr scheduler checkpoint
    if args.resume_from_checkpoint is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    # Start training!
    ddp_model.train()

    for epoch in range(resume_from_epoch, args.num_epochs):
        laion_dataset.set_epoch(epoch)
        laion_loader = laion_dataset.dataloader
        mmc4_dataset.set_epoch(epoch)
        mmc4_loader = mmc4_dataset.dataloader

        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            laion_loader=laion_loader,
            mmc4_loader=mmc4_loader,
            device_id=device_id,
            wandb=wandb,
        )
        save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)

    # save final checkpoint
    save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)


if __name__ == "__main__":
    main()