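"""
Sparse 3D UNet VAE modules: residual blocks with nearest-neighbor and
spatial<->channel resampling, a sparse ConvNeXt block, and the encoder and
decoder that assemble them into a VAE over sparse voxel grids.
"""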
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module
from ...modules import sparse as sp
from ...modules.norm import LayerNorm32


class SparseResBlock3d(nn.Module):
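    """
    Sparse 3D residual block with optional resampling.

    Args:
        channels: Number of input channels.
        out_channels: Number of output channels. Defaults to ``channels``.
        downsample: Halve the spatial resolution.
        upsample: Double the spatial resolution; the block predicts an
            8-way subdivision mask from its input.
        resample_mode: 'nearest' resamples by nearest neighbor;
            'spatial2channel' trades the 2^3 = 8 cells of each step against
            the channel dimension.
        use_checkpoint: Use gradient checkpointing to save memory.
    """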
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        downsample: bool = False,
        upsample: bool = False,
        resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.downsample = downsample
        self.upsample = upsample
        self.resample_mode = resample_mode
        self.use_checkpoint = use_checkpoint
        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        if resample_mode == 'nearest':
            self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        elif resample_mode == 'spatial2channel' and not self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
        elif resample_mode == 'spatial2channel' and self.downsample:
            self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        if resample_mode == 'nearest':
            self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        elif resample_mode == 'spatial2channel' and self.downsample:
            # After spatial2channel downsampling the skip path carries
            # channels * 8 features; reduce them to out_channels by a grouped mean.
            # (Use self.out_channels: the out_channels argument may be None.)
            self.skip_connection = lambda x: x.replace(
                x.feats.reshape(x.feats.shape[0], self.out_channels, channels * 8 // self.out_channels).mean(dim=-1)
            )
        elif resample_mode == 'spatial2channel' and not self.downsample:
            # After channel2spatial upsampling the skip path carries
            # channels // 8 features; expand them to out_channels by repetition.
            self.skip_connection = lambda x: x.replace(
                x.feats.repeat_interleave(self.out_channels // (channels // 8), dim=1)
            )
        self.updown = None
        if self.downsample:
            if resample_mode == 'nearest':
                self.updown = sp.SparseDownsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseSpatial2Channel(2)
        elif self.upsample:
            self.to_subdiv = sp.SparseLinear(channels, 8)
            if resample_mode == 'nearest':
                self.updown = sp.SparseUpsample(2)
            elif resample_mode == 'spatial2channel':
                self.updown = sp.SparseChannel2Spatial(2)
    def _updown(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> sp.SparseTensor:
        if self.downsample:
            x = self.updown(x)
        elif self.upsample:
            x = self.updown(x, subdiv.replace(subdiv.feats > 0))
        return x

    def _forward(self, x: sp.SparseTensor) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        subdiv = None
        if self.upsample:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        # In spatial2channel mode conv1 must run before resampling so that the
        # channel count matches what the resampling layer expects.
        if self.resample_mode == 'spatial2channel':
            h = self.conv1(h)
        h = self._updown(h, subdiv)
        x = self._updown(x, subdiv)
        if self.resample_mode == 'nearest':
            h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.upsample:
            return h, subdiv
        return h

    def forward(self, x: sp.SparseTensor) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

class SparseResBlockDownsample3d(nn.Module):
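    """
    Sparse 3D residual block that downsamples the sparse grid by a factor of 2
    (``sp.SparseDownsample``); a fixed-configuration variant of SparseResBlock3d.
    """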
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        self.updown = sp.SparseDownsample(2)

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.updown(h)
        x = self.updown(x)
        h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

class SparseResBlockUpsample3d(nn.Module):
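    """
    Sparse 3D residual block that upsamples by a factor of 2
    (``sp.SparseUpsample``). When ``pred_subdiv`` is True the block predicts
    an 8-way subdivision mask from its input and returns it alongside the
    output; otherwise the mask may be supplied by the caller.
    """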
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        if self.pred_subdiv:
            self.to_subdiv = sp.SparseLinear(channels, 8)
        self.updown = sp.SparseUpsample(2)

    def _forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = self.conv1(h)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h

    def forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        # Accept an external subdivision mask so the block can be driven by
        # guide_subs when pred_subdiv=False (mirrors SparseResBlockC2S3d).
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
        else:
            return self._forward(x, subdiv)

class SparseResBlockS2C3d(nn.Module):
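    """
    Sparse 3D residual block that downsamples by a factor of 2 by packing each
    2x2x2 spatial neighborhood into the channel dimension
    (``sp.SparseSpatial2Channel``).
    """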
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        # The skip path sees channels * 8 features after spatial2channel
        # packing; reduce them to out_channels by a grouped mean.
        # (Use self.out_channels: the out_channels argument may be None.)
        self.skip_connection = lambda x: x.replace(
            x.feats.reshape(x.feats.shape[0], self.out_channels, channels * 8 // self.out_channels).mean(dim=-1)
        )
        self.updown = sp.SparseSpatial2Channel(2)

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        h = self.updown(h)
        x = self.updown(x)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        return h

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

class SparseResBlockC2S3d(nn.Module):
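    """
    Sparse 3D residual block that upsamples by a factor of 2 by unpacking the
    channel dimension into a 2x2x2 spatial neighborhood
    (``sp.SparseChannel2Spatial``). When ``pred_subdiv`` is True the block
    predicts the subdivision mask itself and returns it alongside the output.
    """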
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        # The skip path sees channels // 8 features after channel2spatial
        # unpacking; expand them to out_channels by repetition.
        # (Use self.out_channels: the out_channels argument may be None.)
        self.skip_connection = lambda x: x.replace(
            x.feats.repeat_interleave(self.out_channels // (channels // 8), dim=1)
        )
        if pred_subdiv:
            self.to_subdiv = sp.SparseLinear(channels, 8)
        self.updown = sp.SparseChannel2Spatial(2)

    def _forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h

    def forward(self, x: sp.SparseTensor, subdiv: Optional[sp.SparseTensor] = None) -> Union[sp.SparseTensor, Tuple[sp.SparseTensor, sp.SparseTensor]]:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
        else:
            return self._forward(x, subdiv)

class SparseConvNeXtBlock3d(nn.Module):
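    """
    Sparse 3D ConvNeXt-style block: a 3x3x3 sparse convolution followed by
    LayerNorm and a pointwise MLP, with a residual connection. The final MLP
    layer is zero-initialized, so the block starts out as an identity mapping.
    """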
    def __init__(
        self,
        channels: int,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.use_checkpoint = use_checkpoint
        self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.conv = sp.SparseConv3d(channels, channels, 3)
        self.mlp = nn.Sequential(
            nn.Linear(channels, int(channels * mlp_ratio)),
            nn.SiLU(),
            zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
        )

    def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        h = self.conv(x)
        h = h.replace(self.norm(h.feats))
        h = h.replace(self.mlp(h.feats))
        return h + x

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

class SparseUnetVaeEncoder(nn.Module):
    """
    Sparse 3D UNet VAE encoder. Stacks per-resolution blocks (resolved by name
    from this module) with downsampling blocks in between, then projects to the
    mean and log-variance of the latent posterior.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
        self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)

        self.blocks = nn.ModuleList([])
        for i in range(len(num_blocks)):
            self.blocks.append(nn.ModuleList([]))
            for j in range(num_blocks[i]):
                self.blocks[-1].append(
                    globals()[block_type[i]](
                        model_channels[i],
                        **block_args[i],
                    )
                )
            if i < len(num_blocks) - 1:
                self.blocks[-1].append(
                    globals()[down_block_type[i]](
                        model_channels[i],
                        model_channels[i+1],
                        **block_args[i],
                    )
                )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize linear layers with Xavier-uniform weights and zero biases.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
    def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
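        """
        Encode a sparse tensor into the latent space.

        Args:
            x: Input sparse tensor.
            sample_posterior: If True, sample z from the posterior via the
                reparameterization trick; otherwise return the posterior mean.
            return_raw: If True, also return the posterior mean and log-variance.
        """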
        h = self.input_layer(x)
        h = h.type(self.dtype)
        for i, res in enumerate(self.blocks):
            for j, block in enumerate(res):
                h = block(h)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.to_latent(h)

        # Sample from the posterior distribution
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z

class SparseUnetVaeDecoder(nn.Module):
    """
    Sparse 3D UNet VAE decoder. Mirrors the encoder: per-resolution blocks with
    upsampling blocks in between. When ``pred_subdiv`` is True the upsampling
    blocks predict subdivision masks; otherwise masks can be supplied via
    ``guide_subs``.
    """
    def __init__(
        self,
        out_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        up_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.use_fp16 = use_fp16
        self.pred_subdiv = pred_subdiv
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.low_vram = False

        self.output_layer = sp.SparseLinear(model_channels[-1], out_channels)
        self.from_latent = sp.SparseLinear(latent_channels, model_channels[0])

        self.blocks = nn.ModuleList([])
        for i in range(len(num_blocks)):
            self.blocks.append(nn.ModuleList([]))
            for j in range(num_blocks[i]):
                self.blocks[-1].append(
                    globals()[block_type[i]](
                        model_channels[i],
                        **block_args[i],
                    )
                )
            if i < len(num_blocks) - 1:
                self.blocks[-1].append(
                    globals()[up_block_type[i]](
                        model_channels[i],
                        model_channels[i+1],
                        pred_subdiv=pred_subdiv,
                        **block_args[i],
                    )
                )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize linear layers with Xavier-uniform weights and zero biases.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
    def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False):
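        """
        Decode a latent sparse tensor.

        Args:
            x: Latent sparse tensor.
            guide_subs: Externally supplied subdivision masks, one per
                upsampling step (requires pred_subdiv=False).
            return_subs: Also return the predicted subdivision masks
                (requires pred_subdiv=True).

        Returns:
            The decoded sparse tensor; during training with pred_subdiv=True,
            also the ground-truth and predicted subdivision masks.
        """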
        assert guide_subs is None or not self.pred_subdiv, "Only decoders with pred_subdiv=False can be used with guide_subs"
        assert not return_subs or self.pred_subdiv, "Only decoders with pred_subdiv=True can be used with return_subs"
        h = self.from_latent(x)
        h = h.type(self.dtype)
        subs_gt = []
        subs = []
        for i, res in enumerate(self.blocks):
            for j, block in enumerate(res):
                if i < len(self.blocks) - 1 and j == len(res) - 1:
                    # The last block of each resolution is the upsampling block.
                    if self.pred_subdiv:
                        if self.training:
                            subs_gt.append(h.get_spatial_cache('subdivision'))
                        h, sub = block(h)
                        subs.append(sub)
                    else:
                        h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
                else:
                    h = block(h)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.output_layer(h)

        if self.training and self.pred_subdiv:
            return h, subs_gt, subs
        else:
            if return_subs:
                return h, subs
            else:
                return h
    def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor:
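        """
        Run only the first ``upsample_times`` resolution stages and return the
        coordinates of the resulting sparse tensor, without decoding features.
        """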
        assert self.pred_subdiv, "Only decoders with pred_subdiv=True can be used with upsampling"
        h = self.from_latent(x)
        h = h.type(self.dtype)
        for i, res in enumerate(self.blocks):
            if i == upsample_times:
                return h.coords
            for j, block in enumerate(res):
                if i < len(self.blocks) - 1 and j == len(res) - 1:
                    h, _ = block(h)
                else:
                    h = block(h)
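

if __name__ == "__main__":
    # Construction sketch (illustrative only): the channel sizes and block
    # counts below are arbitrary choices, not values from any released
    # configuration. This module uses relative imports, so run it as part of
    # the package; building an actual input also requires the sp.SparseTensor
    # API defined elsewhere in this package.
    encoder = SparseUnetVaeEncoder(
        in_channels=6,
        model_channels=[32, 64, 128],
        latent_channels=8,
        num_blocks=[2, 2, 2],
        block_type=['SparseResBlock3d'] * 3,
        down_block_type=['SparseResBlockDownsample3d'] * 2,
        block_args=[{}] * 3,
    )
    decoder = SparseUnetVaeDecoder(
        out_channels=6,
        model_channels=[128, 64, 32],
        latent_channels=8,
        num_blocks=[2, 2, 2],
        block_type=['SparseResBlock3d'] * 3,
        up_block_type=['SparseResBlockUpsample3d'] * 2,
        block_args=[{}] * 3,
    )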