# trellis2/models/sc_vaes/sparse_unet_vae.py
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module
from ...modules import sparse as sp
from ...modules.norm import LayerNorm32
class SparseResBlock3d(nn.Module):
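    """
    Sparse residual block with an optional resolution change.

    With `resample_mode='nearest'` the resolution is changed via SparseDownsample /
    SparseUpsample; with `'spatial2channel'` spatial resolution is traded against
    channel width via SparseSpatial2Channel / SparseChannel2Spatial. When upsampling,
    an 8-way subdivision mask is predicted from the input features and returned
    alongside the output.
    """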
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
downsample: bool = False,
upsample: bool = False,
resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.downsample = downsample
self.upsample = upsample
self.resample_mode = resample_mode
self.use_checkpoint = use_checkpoint
assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
if resample_mode == 'nearest':
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        elif resample_mode == 'spatial2channel' and not self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
        elif resample_mode == 'spatial2channel' and self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
if resample_mode == 'nearest':
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        elif resample_mode == 'spatial2channel' and self.downsample:
            # Average-pool groups of folded channels down to out_channels.
            self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], self.out_channels, self.channels * 8 // self.out_channels).mean(dim=-1))
        elif resample_mode == 'spatial2channel' and not self.downsample:
            # Repeat the unfolded channels up to out_channels.
            self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(self.out_channels // (self.channels // 8), dim=1))
self.updown = None
if self.downsample:
if resample_mode == 'nearest':
self.updown = sp.SparseDownsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseSpatial2Channel(2)
elif self.upsample:
self.to_subdiv = sp.SparseLinear(channels, 8)
if resample_mode == 'nearest':
self.updown = sp.SparseUpsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseChannel2Spatial(2)
    def _updown(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
if self.downsample:
x = self.updown(x)
elif self.upsample:
x = self.updown(x, subdiv.replace(subdiv.feats > 0))
return x
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
subdiv = None
if self.upsample:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
if self.resample_mode == 'spatial2channel':
h = self.conv1(h)
h = self._updown(h, subdiv)
x = self._updown(x, subdiv)
if self.resample_mode == 'nearest':
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.upsample:
return h, subdiv
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockDownsample3d(nn.Module):
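    """
    Sparse residual block that halves spatial resolution via SparseDownsample(2);
    both the residual branch and the skip path are downsampled before the convolutions.
    """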
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
self.updown = sp.SparseDownsample(2)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.updown(h)
x = self.updown(x)
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockUpsample3d(nn.Module):
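    """
    Sparse residual block that doubles spatial resolution via SparseUpsample(2).

    If `pred_subdiv` is True, an 8-way subdivision mask is predicted from the input
    features and returned together with the output; otherwise a subdivision mask
    can be supplied externally via `subdiv`.
    """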
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.pred_subdiv = pred_subdiv
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
if self.pred_subdiv:
self.to_subdiv = sp.SparseLinear(channels, 8)
self.updown = sp.SparseUpsample(2)
    def _forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = self.conv1(h)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.pred_subdiv:
return h, subdiv
else:
return h
    def forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
        else:
            return self._forward(x, subdiv)
class SparseResBlockS2C3d(nn.Module):
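    """
    Sparse residual block that downsamples by folding 2x2x2 spatial neighborhoods
    into the channel dimension (SparseSpatial2Channel(2)); conv1 reduces the channel
    width so that the folded output matches `out_channels`.
    """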
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        # Average-pool groups of folded channels down to out_channels.
        self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], self.out_channels, self.channels * 8 // self.out_channels).mean(dim=-1))
self.updown = sp.SparseSpatial2Channel(2)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
h = self.updown(h)
x = self.updown(x)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseResBlockC2S3d(nn.Module):
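    """
    Sparse residual block that upsamples by unfolding channels into 2x2x2 spatial
    neighborhoods (SparseChannel2Spatial(2)); conv1 expands to out_channels * 8 so
    that the unfolded output has `out_channels` features. Optionally predicts the
    8-way subdivision mask from the input features (`pred_subdiv`).
    """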
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.pred_subdiv = pred_subdiv
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        # Repeat the unfolded channels up to out_channels.
        self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(self.out_channels // (self.channels // 8), dim=1))
if pred_subdiv:
self.to_subdiv = sp.SparseLinear(channels, 8)
self.updown = sp.SparseChannel2Spatial(2)
    def _forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
if self.pred_subdiv:
return h, subdiv
else:
return h
    def forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
else:
return self._forward(x, subdiv)
class SparseConvNeXtBlock3d(nn.Module):
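    """
    Sparse ConvNeXt-style block: a 3x3x3 sparse convolution followed by LayerNorm
    and a pointwise SiLU MLP, with a residual connection around the whole block.
    """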
def __init__(
self,
channels: int,
mlp_ratio: float = 4.0,
use_checkpoint: bool = False,
):
super().__init__()
self.channels = channels
self.use_checkpoint = use_checkpoint
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.conv = sp.SparseConv3d(channels, channels, 3)
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.SiLU(),
zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
)
def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = self.conv(x)
h = h.replace(self.norm(h.feats))
h = h.replace(self.mlp(h.feats))
return h + x
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseUnetVaeEncoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
in_channels: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
down_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
        self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)
self.blocks = nn.ModuleList([])
for i in range(len(num_blocks)):
self.blocks.append(nn.ModuleList([]))
for j in range(num_blocks[i]):
self.blocks[-1].append(
globals()[block_type[i]](
model_channels[i],
**block_args[i],
)
)
if i < len(num_blocks) - 1:
self.blocks[-1].append(
globals()[down_block_type[i]](
model_channels[i],
model_channels[i+1],
**block_args[i],
)
)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
        # Initialize linear layers with Xavier uniform weights and zero bias:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
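        """
        Encode a sparse tensor into the latent space.

        Args:
            x: input sparse tensor.
            sample_posterior: if True, sample z ~ N(mean, std); otherwise return the mean.
            return_raw: if True, also return the posterior mean and logvar.
        """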
h = self.input_layer(x)
h = h.type(self.dtype)
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
h = block(h)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.to_latent(h)
# Sample from the posterior distribution
mean, logvar = h.feats.chunk(2, dim=-1)
if sample_posterior:
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
else:
z = mean
z = h.replace(z)
if return_raw:
return z, mean, logvar
else:
return z
class SparseUnetVaeDecoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
out_channels: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
up_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.out_channels = out_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.use_fp16 = use_fp16
self.pred_subdiv = pred_subdiv
self.dtype = torch.float16 if use_fp16 else torch.float32
self.low_vram = False
self.output_layer = sp.SparseLinear(model_channels[-1], out_channels)
self.from_latent = sp.SparseLinear(latent_channels, model_channels[0])
self.blocks = nn.ModuleList([])
for i in range(len(num_blocks)):
self.blocks.append(nn.ModuleList([]))
for j in range(num_blocks[i]):
self.blocks[-1].append(
globals()[block_type[i]](
model_channels[i],
**block_args[i],
)
)
if i < len(num_blocks) - 1:
self.blocks[-1].append(
globals()[up_block_type[i]](
model_channels[i],
model_channels[i+1],
pred_subdiv=pred_subdiv,
**block_args[i],
)
)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
        # Initialize linear layers with Xavier uniform weights and zero bias:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor:
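        """
        Decode a latent sparse tensor.

        Args:
            x: latent sparse tensor.
            guide_subs: externally provided subdivision masks, one per upsampling stage
                (only valid when `pred_subdiv` is False).
            return_subs: also return the predicted subdivision logits
                (only valid when `pred_subdiv` is True).

        During training with `pred_subdiv`, returns the output together with the
        ground-truth subdivisions read from the spatial cache and the predicted
        subdivision logits.
        """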
        assert guide_subs is None or not self.pred_subdiv, "Only decoders with pred_subdiv=False can be used with guide_subs"
        assert not return_subs or self.pred_subdiv, "Only decoders with pred_subdiv=True can be used with return_subs"
h = self.from_latent(x)
h = h.type(self.dtype)
subs_gt = []
subs = []
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
if i < len(self.blocks) - 1 and j == len(res) - 1:
if self.pred_subdiv:
if self.training:
subs_gt.append(h.get_spatial_cache('subdivision'))
h, sub = block(h)
subs.append(sub)
else:
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
else:
h = block(h)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.output_layer(h)
if self.training and self.pred_subdiv:
return h, subs_gt, subs
else:
if return_subs:
return h, subs
else:
return h
def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor:
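        """
        Run the decoder for `upsample_times` upsampling stages and return the
        coordinates of the resulting sparse structure. `upsample_times` is expected
        to be smaller than the number of stages; otherwise the loop finishes without
        returning coordinates.
        """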
        assert self.pred_subdiv, "Only decoders with pred_subdiv=True can be used with upsampling"
h = self.from_latent(x)
h = h.type(self.dtype)
for i, res in enumerate(self.blocks):
if i == upsample_times:
return h.coords
for j, block in enumerate(res):
if i < len(self.blocks) - 1 and j == len(res) - 1:
h, sub = block(h)
else:
h = block(h)
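

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The channel widths, block counts,
# and block-type choices below are assumptions picked for demonstration, not
# the shipped TRELLIS.2 configuration; the block-type strings must name classes
# defined in this module. Building an input sp.SparseTensor is project-specific
# and is therefore only indicated in comments.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    encoder = SparseUnetVaeEncoder(
        in_channels=6,                      # assumed input feature width
        model_channels=[32, 64, 128],       # assumed per-stage widths
        latent_channels=8,
        num_blocks=[2, 2, 2],
        block_type=['SparseConvNeXtBlock3d'] * 3,
        down_block_type=['SparseResBlockDownsample3d'] * 2,
        block_args=[{}, {}, {}],
    )
    decoder = SparseUnetVaeDecoder(
        out_channels=6,
        model_channels=[128, 64, 32],
        latent_channels=8,
        num_blocks=[2, 2, 2],
        block_type=['SparseConvNeXtBlock3d'] * 3,
        up_block_type=['SparseResBlockUpsample3d'] * 2,
        block_args=[{}, {}, {}],
        pred_subdiv=True,
    )
    # x = <sp.SparseTensor built from voxel coords/features via the sparse module>
    # z = encoder(x, sample_posterior=True)
    # recon = decoder(z)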