# model_tools / sparsify3.py
# Uploaded by Naphula ("Upload 2 files"), commit e7a095d (verified).
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
from enum import Enum
from typing import Optional
import torch
class SparsificationMethod(str, Enum):
    """Strategy that ``sparsify`` uses to decide which tensor entries to keep.

    Members subclass ``str``, so a plain string such as ``"magnitude"`` compares
    equal to the corresponding member.
    """

    # Keep the entries with the largest absolute value.
    magnitude = "magnitude"
    # Keep each entry independently with probability ``density``.
    random = "random"
    # Drop a ``gamma`` fraction of the largest entries, then the smallest ones.
    magnitude_outliers = "magnitude_outliers"
    # Random keep with a probability that grows with per-row magnitude rank.
    della_magprune = "della_magprune"
class RescaleNorm(str, Enum):
    """Norm that ``rescaled_masked_tensor`` preserves after masking.

    Members subclass ``str``, so plain strings such as ``"l2"`` compare equal
    to the corresponding member.
    """

    # Sum of absolute values.
    l1 = "l1"
    # ``Tensor.norm()`` of the whole tensor.
    l2 = "l2"
    # Largest absolute value.
    linf = "linf"
def rescaled_masked_tensor(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    norm: Optional[RescaleNorm],
    eps: float = 1e-7,
) -> torch.Tensor:
    """Multiply ``tensor`` by ``mask`` and optionally restore its original norm.

    Args:
        tensor (torch.Tensor): Values to mask.
        mask (torch.Tensor): Elementwise multiplier (callers in this module
            pass 0/1 masks).
        norm (RescaleNorm, optional): Norm to preserve: ``l1`` (sum of absolute
            values), ``l2`` (``Tensor.norm()``), or ``linf`` (largest absolute
            value). ``None`` skips rescaling entirely.
        eps (float): If either the original or the masked norm is below this,
            the masked tensor is returned unscaled to avoid dividing by ~0.

    Returns:
        torch.Tensor: ``tensor * mask``, scaled so its chosen norm matches that
        of ``tensor`` whenever both norms are at least ``eps``.

    Raises:
        NotImplementedError: If ``norm`` is neither ``None`` nor a known norm.
    """
    masked = tensor * mask
    if norm is None:
        return masked

    def measure(t: torch.Tensor) -> torch.Tensor:
        # Scalar size of ``t`` under the requested norm.
        if norm == RescaleNorm.l1:
            return t.abs().sum()
        if norm == RescaleNorm.l2:
            return t.norm()
        if norm == RescaleNorm.linf:
            return t.abs().max()
        raise NotImplementedError(norm)

    reference = measure(tensor)
    current = measure(masked)

    # Rescaling from or to a (near-)zero size is numerically meaningless.
    if reference < eps or current < eps:
        return masked
    return masked * (reference / current)
def magnitude(
    tensor: torch.Tensor, density: float, rescale_norm: Optional[RescaleNorm] = None
) -> torch.Tensor:
    """Masks out the smallest values, retaining a proportion of `density`.

    The ``int(density * numel)`` entries with the largest absolute value are
    kept; all others are zeroed. Works for any shape and memory layout,
    including non-contiguous (e.g. transposed) tensors.

    Args:
        tensor (torch.Tensor): Tensor to sparsify.
        density (float): Fraction of entries to keep. ``>= 1`` returns
            ``tensor`` itself, unchanged.
        rescale_norm (RescaleNorm, optional): Forwarded to
            ``rescaled_masked_tensor`` to restore the chosen norm after masking.

    Returns:
        torch.Tensor: Masked (and optionally rescaled) tensor with the same
        shape as ``tensor``.

    Raises:
        AssertionError: If ``density`` is so small that no entry would be kept.
    """
    if density >= 1:
        return tensor

    k = int(density * tensor.numel())
    assert k > 0, "not gonna zero out the whole tensor buddy"

    # Flatten in logical (row-major) order. ``reshape`` copies when needed,
    # whereas ``view(-1)`` raises on non-contiguous inputs such as ``t.T``.
    w = tensor.abs().reshape(-1)
    if w.device.type == "cpu":
        # Sort in float32 on CPU (lossless upcast for float16/bfloat16).
        w = w.float()
    topk = torch.argsort(w, descending=True)[:k]

    # Build the mask flat, in the same logical order as ``w``, then give it the
    # tensor's shape. ``zeros_like`` would inherit a non-contiguous layout, so
    # its flat memory order would not line up with the indices in ``topk``.
    mask = torch.zeros(tensor.numel(), dtype=tensor.dtype, device=tensor.device)
    mask[topk] = 1
    mask = mask.view(tensor.shape)

    return rescaled_masked_tensor(tensor, mask, rescale_norm)
def magnitude_outliers(
    tensor: torch.Tensor,
    density: float,
    rescale_norm: Optional[RescaleNorm] = None,
    gamma: float = 0.01,
):
    """Masks out smallest values in addition to large outliers.

    The `gamma` proportion of the largest weights are first removed, then the
    smallest weights are removed to achieve the desired density. If removing
    `gamma` of the largest weights alone would already drop below the target
    density, fewer large weights are removed instead.

    Args:
        tensor (torch.Tensor): The tensor to sparsify.
        density (float): The proportion of weights to retain. ``>= 1`` returns
            ``tensor`` itself, unchanged.
        rescale_norm (RescaleNorm, optional): Forwarded to
            ``rescaled_masked_tensor`` to restore the chosen norm after masking.
        gamma (float): Percent of largest weights to remove. ``0`` disables
            outlier removal (plain magnitude pruning).

    Returns:
        torch.Tensor: Masked (and optionally rescaled) tensor with the same
        shape as ``tensor``.
    """
    if density >= 1:
        return tensor

    num_elems = tensor.numel()
    target_n = int(density * num_elems)
    n_top = int(gamma * num_elems)
    n_bot = num_elems - target_n - n_top

    if n_bot < 0:
        # cut down on the number of large weights to remove in
        # order to hit the target density
        n_top += n_bot
        n_bot = 0

    # Flatten in logical order; ``reshape`` also handles non-contiguous inputs.
    w = tensor.abs().reshape(-1)
    if w.device.type == "cpu":
        # Sort in float32 on CPU (lossless upcast for float16/bfloat16).
        w = w.float()
    indices = torch.sort(w, descending=False).indices

    # Keep ascending-magnitude ranks [n_bot, num_elems - n_top). The end index
    # must be explicit: ``indices[n_bot:-n_top]`` is empty when n_top == 0
    # (since -0 == 0), which would zero the entire tensor, e.g. for gamma=0.
    mask = torch.zeros(num_elems, dtype=tensor.dtype, device=tensor.device)
    mask[indices[n_bot : num_elems - n_top]] = 1
    # The mask was filled in the same logical order as ``w``; restore the shape.
    mask = mask.view(tensor.shape)

    return rescaled_masked_tensor(tensor, mask, rescale_norm)
def bernoulli(
    tensor: torch.Tensor, density: float, rescale_norm: Optional[RescaleNorm] = None
) -> torch.Tensor:
    """Keep each entry independently with probability ``density``; zero the rest.

    Args:
        tensor (torch.Tensor): Tensor to sparsify.
        density (float): Per-entry keep probability. ``>= 1`` returns
            ``tensor`` itself, unchanged.
        rescale_norm (RescaleNorm, optional): Forwarded to
            ``rescaled_masked_tensor`` to restore the chosen norm after masking.

    Returns:
        torch.Tensor: Masked (and optionally rescaled) tensor, cast back to
        ``tensor.dtype``.
    """
    if density >= 1:
        return tensor

    # torch.bernoulli is not implemented for float16 on CPU, so CPU tensors
    # other than bfloat16 are processed in float32. Non-CPU tensors and CPU
    # bfloat16 tensors keep their own dtype.
    keep_native = tensor.device.type != "cpu" or tensor.dtype == torch.bfloat16
    compute_dtype = tensor.dtype if keep_native else torch.float32

    keep_prob = torch.full_like(input=tensor, fill_value=density, dtype=compute_dtype)
    mask = torch.bernoulli(keep_prob)

    rescaled = rescaled_masked_tensor(tensor.to(compute_dtype), mask, rescale_norm)
    return rescaled.to(tensor.dtype)
def della_magprune(
    tensor: torch.Tensor,
    density: float,
    epsilon: float,
    rescale_norm: Optional[RescaleNorm] = None,
) -> torch.Tensor:
    """Randomly drop entries with a keep probability that rises with magnitude rank.

    Magnitudes are ranked within each slice along dim 1 (the rows of a 2-D
    tensor). Tensors with fewer than two dims, including 0-d, are treated as a
    single row. The keep probability grows linearly from ``density - epsilon``
    for the smallest entry of a row to ``density + epsilon`` for the largest.
    A Bernoulli draw then decides which entries survive.

    Args:
        tensor (torch.Tensor): Tensor to sparsify.
        density (float): Mean keep probability. ``>= 1`` returns ``tensor``
            itself; ``<= 0`` returns zeros. Other values are clamped to
            ``[1e-4, 1 - 1e-4]``.
        epsilon (float): Half-width of the probability spread. If
            ``|epsilon|`` exceeds ``min(density, 1 - density) - 1e-4``, it is
            shrunk to that value (keeping its sign), so every probability
            stays inside (0, 1).
        rescale_norm (RescaleNorm, optional): Forwarded to
            ``rescaled_masked_tensor`` to restore the chosen norm after masking.

    Returns:
        torch.Tensor: Tensor with the same shape and dtype as the input.
    """
    if density >= 1:
        return tensor
    if density <= 0:
        return torch.zeros_like(tensor)

    # --- SAFETY GUARD START ---
    # Keep density strictly away from 0 and 1.
    density = max(1e-4, min(1.0 - 1e-4, density))
    # Epsilon must be < density AND < (1 - density). If the optimizer guessed a
    # bad epsilon, shrink it to the largest allowed magnitude.
    max_epsilon = min(density, 1.0 - density) - 1e-4
    if abs(epsilon) > max_epsilon:
        epsilon = max_epsilon if epsilon > 0 else -max_epsilon
    # --- SAFETY GUARD END ---

    orig_shape = tensor.shape
    orig_dtype = tensor.dtype
    # Same dtype policy as ``bernoulli``: CPU tensors other than bfloat16 are
    # processed in float32.
    work_dtype = (
        tensor.dtype
        if tensor.device.type != "cpu" or tensor.dtype == torch.bfloat16
        else torch.float32
    )

    if tensor.dim() < 2:
        # ``reshape`` (not ``unsqueeze``) so 0-d inputs also get a dim 1.
        tensor = tensor.reshape(1, -1)

    magnitudes = tensor.abs()
    sorted_indices = torch.argsort(magnitudes, dim=1, descending=False)
    # 0-based rank of every entry within its row. Each row's ranks are a
    # permutation of 0..n-1, so min-max normalisation is a division by n-1.
    # Work in float32: float16 overflows to inf above 65504, which would give
    # NaN probabilities, and bfloat16 cannot represent integers above 256
    # exactly.
    ranks = sorted_indices.argsort(dim=1).to(torch.float32)
    row_len = tensor.shape[1]
    rank_norm = ranks / max(row_len - 1, 1)

    probs = (density - epsilon) + rank_norm * (2 * epsilon)
    # The guard above keeps probs inside (0, 1); the clamp only absorbs rounding.
    mask = torch.bernoulli(probs.clamp(0, 1)).to(work_dtype)

    res = rescaled_masked_tensor(tensor.to(work_dtype), mask, rescale_norm)
    return res.to(orig_dtype).reshape(orig_shape)
def sparsify(
    tensor: torch.Tensor,
    density: float,
    method: SparsificationMethod,
    gamma: float = 0,
    epsilon: float = 0,
    rescale_norm: Optional[RescaleNorm] = None,
) -> torch.Tensor:
    """Sparsify ``tensor`` with the strategy selected by ``method``.

    Args:
        tensor (torch.Tensor): Tensor to sparsify.
        density (float): Fraction of entries to keep, as interpreted by the
            chosen method.
        method (SparsificationMethod): Strategy to apply. Because the enum
            subclasses ``str``, plain strings equal to a member's value also match.
        gamma (float): Outlier fraction; only used by ``magnitude_outliers``.
        epsilon (float): Probability spread; only used by ``della_magprune``.
        rescale_norm (RescaleNorm, optional): Norm to restore after masking;
            forwarded to every method.

    Returns:
        torch.Tensor: The sparsified tensor.

    Raises:
        NotImplementedError: If ``method`` is not a known strategy.
    """
    # Arguments shared by every strategy.
    common = {"density": density, "rescale_norm": rescale_norm}

    if method == SparsificationMethod.magnitude:
        return magnitude(tensor, **common)
    if method == SparsificationMethod.random:
        return bernoulli(tensor, **common)
    if method == SparsificationMethod.magnitude_outliers:
        return magnitude_outliers(tensor, gamma=gamma, **common)
    if method == SparsificationMethod.della_magprune:
        return della_magprune(tensor, epsilon=epsilon, **common)

    raise NotImplementedError(method)