"""
torch_utils.py

General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch.

Random `set_global_seed` functionality is taken directly from PyTorch-Lightning:
    > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py

This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
we inject randomness from non-PyTorch sources (e.g., numpy, random)!
    > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/

Terminology
    -> World Size :: Total number of processes distributed over (# nodes x # devices per node) -- assumed homogeneous!
    -> Rank :: Integer index of the current process in the total world size, in [0, world size)
    -> Local Rank :: Local index of the process on its node, in [0, # devices per node)

    e.g., 2 nodes x 4 GPUs each --> world size = 8, ranks 0-7 across nodes, local ranks 0-3 on each node.
"""

import os
import random
from typing import Callable, Optional, Tuple

import numpy as np
import torch


def setup_seed(seed: int, rank: int = 0) -> None:
    """Set random seeds (torch, numpy, random) for reproducibility, offsetting by `rank` so each process differs."""
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


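# Illustrative sketch of calling `setup_seed` with a per-process offset (kept as a comment so importing this module
# has no side effects); assumes a launcher like `torchrun`, which exports a `RANK` variable for each process:
#   setup_seed(42, rank=int(os.environ.get("RANK", 0)))

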
def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
    """Sets seed for all randomness libraries (random, numpy, torch) and, if requested, returns a `worker_init_fn`."""
    assert np.iinfo(np.uint32).min <= seed <= np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"

    # Export the seed so spawned processes can read it, then seed each library in turn.
    os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    return worker_init_function if get_worker_init_fn else None


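# Illustrative (hypothetical) wiring of the returned `worker_init_fn` into a multi-worker DataLoader -- `MyDataset`
# stands in for any Dataset whose __getitem__ draws randomness from numpy/random; kept as a comment so this module
# stays side-effect free on import:
#   worker_init_fn = set_global_seed(7, get_worker_init_fn=True)
#   loader = torch.utils.data.DataLoader(MyDataset(), batch_size=32, num_workers=4, worker_init_fn=worker_init_fn)

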
def worker_init_function(worker_id: int) -> None:
    """
    Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
        > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

    Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
    you can run iterative splitting on to get new (predictable) randomness.

    :param worker_id: Identifier for the given worker in [0, num_workers) for the Dataloader in question.
    """
    # Get the rank of this process (`LOCAL_RANK` from the launcher; 0 if unset) and the seed PyTorch gave the process.
    global_rank, process_seed = int(os.environ.get("LOCAL_RANK", "0")), torch.initial_seed()

    # Back out the "base" (original) seed -- PyTorch sets each worker's seed to `base_seed + worker_id`:
    #   > Ref: https://pytorch.org/docs/stable/data.html#data-loading-randomness
    base_seed = process_seed - worker_id

    # Mix the base seed, worker ID, and rank so every (process, worker) pair gets a distinct, reproducible stream.
    seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])

    # Use 128 bits (4 x 32-bit words) to seed NumPy -- `generate_state(k)` returns a `k`-element uint32 array.
    np.random.seed(seed_seq.generate_state(4))

    # Spawn independent child sequences for `torch` and stdlib `random`.
    torch_seed_seq, random_seed_seq = seed_seq.spawn(2)

    # `torch.manual_seed` takes a 64-bit integer, so generate a single uint64 word.
    torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])

    # Use 128 bits for `random`, collapsing two uint64 words into one Python int (object dtype avoids overflow).
    random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(random_seed)


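# Illustrative sketch of the `SeedSequence` "splitting" intuition above (hypothetical, kept as a comment; not used
# by this module): each spawned child yields an independent, reproducible stream.
#   parent = np.random.SeedSequence(42)
#   child_a, child_b = parent.spawn(2)
#   rng_a, rng_b = np.random.default_rng(child_a), np.random.default_rng(child_b)
#   rng_a.integers(0, 10, size=3)  # independent of rng_b's draws, but fully reproducible from the parent seed

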
def get_epoch_and_step_from_checkpoint(checkpoint_path: Optional[str]) -> Tuple[int, int]:
    """Parse epoch and step numbers from a checkpoint path of the form `epoch=<E>-step=<S>[.<ext>]`."""
    if checkpoint_path is None:
        return 0, 0

    try:
        basename = os.path.basename(checkpoint_path)
        arr = basename.split(".")[0].split("-")
        epoch = int(arr[0].split("=")[1])
        step = int(arr[1].split("=")[1])
        return epoch, step
    except Exception as e:
        print(f"Error parsing checkpoint path {checkpoint_path}: {e}")
        return 0, 0


def find_last_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Find the latest checkpoint (highest step number) under `<checkpoint_dir>/checkpoints`, or None if absent."""
    checkpoint_dir = os.path.join(checkpoint_dir, "checkpoints")
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoint_list = os.listdir(checkpoint_dir)
    print(f"All checkpoints: {checkpoint_list}")

    # Track (path, step, epoch) for the checkpoint with the highest step seen so far.
    last_checkpoint_info = None
    for folder in checkpoint_list:
        folder_path = os.path.join(checkpoint_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        # Only consider checkpoint folders that actually contain saved weights.
        files = os.listdir(folder_path)
        if "weights.pt" not in files:
            continue

        epoch, step = get_epoch_and_step_from_checkpoint(folder_path)
        if last_checkpoint_info is None or step > last_checkpoint_info[1]:
            last_checkpoint_info = (folder_path, step, epoch)

    return last_checkpoint_info[0] if last_checkpoint_info else None


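# Illustrative (hypothetical) resume sketch, kept as a comment; assumes checkpoints are laid out as
# `<run_dir>/checkpoints/epoch=<E>-step=<S>/weights.pt` and that `weights.pt` holds the model's state_dict:
#   last_checkpoint = find_last_checkpoint(run_dir)
#   start_epoch, start_step = get_epoch_and_step_from_checkpoint(last_checkpoint)
#   if last_checkpoint is not None:
#       model.load_state_dict(torch.load(os.path.join(last_checkpoint, "weights.pt")))

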
def check_bloat16_supported() -> bool:
    """Check whether bf16 training is usable -- requires CUDA >= 11.0, a bf16-capable GPU, and NCCL >= 2.10."""
    try:
        import packaging.version
        import torch.cuda.nccl as nccl
        import torch.distributed as dist

        return (
            (torch.version.cuda is not None)
            and torch.cuda.is_bf16_supported()
            and (packaging.version.parse(torch.version.cuda).release >= (11, 0))
            and dist.is_nccl_available()
            and (nccl.version() >= (2, 10))
        )

    except Exception:
        return False
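

# Illustrative sketch (kept as a comment) of using the check above to choose a mixed-precision dtype, falling back
# to fp16 when bf16 is unavailable:
#   mixed_precision_dtype = torch.bfloat16 if check_bloat16_supported() else torch.float16
#   with torch.autocast("cuda", dtype=mixed_precision_dtype):
#       ...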