diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..b7faf403d915ca307532bb0eb9cceaf0214e8e5b --- /dev/null +++ b/.gitignore @@ -0,0 +1,207 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ diff --git a/README.md b/README.md index 86606fe56dfdbb3307ea71e79190bfae90517aa4..3ba267f8f085a01c395b2467d97aa8275d09f3c0 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,14 @@ --- title: TRELLIS.2 -emoji: 📚 -colorFrom: yellow -colorTo: pink +emoji: 🏢 +colorFrom: indigo +colorTo: blue sdk: gradio sdk_version: 6.1.0 app_file: app.py pinned: false license: mit +short_description: High-fidelity 3D Generation from images --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..eab91423821fc00b72038b6d38f3b03171ea50d9 --- /dev/null +++ b/app.py @@ -0,0 +1,335 @@ +import gradio as gr +import spaces + +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +from datetime import datetime +import shutil +import cv2 +from typing import * +import torch +import numpy as np +from PIL import Image +from trellis2.modules.sparse import SparseTensor +from trellis2.pipelines import Trellis2ImageTo3DPipeline +from trellis2.renderers import EnvMap +from trellis2.utils import render_utils +import o_voxel + + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +os.makedirs(TMP_DIR, exist_ok=True) + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + 
shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. + """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: + shape_slat, tex_slat, res = latents + return { + 'shape_slat_feats': shape_slat.feats.cpu().numpy(), + 'tex_slat_feats': tex_slat.feats.cpu().numpy(), + 'coords': shape_slat.coords.cpu().numpy(), + 'res': res, + } + + +def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: + shape_slat = SparseTensor( + feats=torch.from_numpy(state['shape_slat_feats']).cuda(), + coords=torch.from_numpy(state['coords']).cuda(), + ) + tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) + return shape_slat, tex_slat, state['res'] + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed. + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +@spaces.GPU(duration=120) +def image_to_3d( + image: Image.Image, + seed: int, + resolution: str, + ss_guidance_strength: float, + ss_guidance_rescale: float, + ss_sampling_steps: int, + ss_rescale_t: float, + shape_slat_guidance_strength: float, + shape_slat_guidance_rescale: float, + shape_slat_sampling_steps: int, + shape_slat_rescale_t: float, + tex_slat_guidance_strength: float, + tex_slat_guidance_rescale: float, + tex_slat_sampling_steps: int, + tex_slat_rescale_t: float, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +) -> str: + """ + Convert an image to a 3D model. + + Args: + image (Image.Image): The input image. + seed (int): The random seed. + ss_guidance_strength (float): The guidance strength for sparse structure generation. + ss_sampling_steps (int): The number of sampling steps for sparse structure generation. + shape_slat_guidance_strength (float): The guidance strength for shape slat generation. + shape_slat_sampling_steps (int): The number of sampling steps for shape slat generation. + tex_slat_guidance_strength (float): The guidance strength for texture slat generation. + tex_slat_sampling_steps (int): The number of sampling steps for texture slat generation. + + Returns: + str: The path to the preview video of the 3D model. + str: The path to the 3D model. 
+ """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + outputs, latents = pipeline.run( + image, + seed=seed, + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "guidance_strength": ss_guidance_strength, + "guidance_rescale": ss_guidance_rescale, + "rescale_t": ss_rescale_t, + }, + shape_slat_sampler_params={ + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + }, + tex_slat_sampler_params={ + "steps": tex_slat_sampling_steps, + "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, + "rescale_t": tex_slat_rescale_t, + }, + pipeline_type={ + "512": "512", + "1024": "512->1024", + "1536": "512->1536", + }[resolution], + return_latent=True, + ) + images = render_utils.make_pbr_vis_frames( + render_utils.render_snapshot(outputs[0], resolution=1024, r=2, fov=36, envmap=envmap), + resolution=1024 + ) + state = pack_state(latents) + torch.cuda.empty_cache() + return state, [Image.fromarray(image) for image in images] + + +@spaces.GPU(duration=120) +def extract_glb( + state: dict, + decimation_target: int, + texture_size: int, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +) -> Tuple[str, str]: + """ + Extract a GLB file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + decimation_target (int): The target face count for decimation. + texture_size (int): The texture resolution. + + Returns: + str: The path to the extracted GLB file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shape_slat, tex_slat, res = unpack_state(state) + mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] + glb = o_voxel.postprocess.to_glb( + vertices=mesh.vertices, + faces=mesh.faces, + attr_volume=mesh.attrs, + coords=mesh.coords, + attr_layout=pipeline.pbr_attr_layout, + grid_size=res, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + decimation_target=decimation_target, + texture_size=texture_size, + use_tqdm=True, + )[0] + now = datetime.now() + timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" + os.makedirs(user_dir, exist_ok=True) + glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') + glb.export(glb_path) + torch.cuda.empty_cache() + return glb_path, glb_path + + +css = """ +.stepper-wrapper { + padding: 0; +} + +.stepper-container { + padding: 0; + align-items: center; +} + +.step-button { + flex-direction: row; +} + +.step-connector { + transform: none; +} + +.step-number { + width: 16px; + height: 16px; +} + +.step-label { + position: relative; + bottom: 0; +} +""" + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/trellis.2) + * Upload an image and click "Generate" to create a 3D asset. + * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it. 
+ """) + + with gr.Row(): + with gr.Column(scale=1, min_width=360): + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) + + resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="512") + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + decimation_target = gr.Slider(10000, 500000, label="Decimation Target", value=100000, step=10000) + texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) + + with gr.Accordion(label="Advanced Settings", open=False): + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) + gr.Markdown("Stage 2: Shape Generation") + with gr.Row(): + shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) + shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + gr.Markdown("Stage 3: Material Generation") + with gr.Row(): + tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) + tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) + tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + + generate_btn = gr.Button("Generate") + + with gr.Column(scale=10): + with gr.Walkthrough(selected=0) as walkthrough: + with gr.Step("Preview", id=0): + preview_output = gr.Gallery(label="3D Asset Preview", height=800, show_label=True, preview=True) + extract_btn = gr.Button("Extract GLB") + with gr.Step("Extract", id=1): + glb_output = gr.Model3D(label="Extracted GLB", height=800, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0)) + download_btn = gr.DownloadButton(label="Download GLB") + + with gr.Column(scale=1, min_width=172): + examples = gr.Examples( + examples=[ + f'assets/example_image/{image}' + for image in os.listdir("assets/example_image") + ], + inputs=[image_prompt], + fn=preprocess_image, + outputs=[image_prompt], + run_on_click=True, + examples_per_page=18, + ) + + output_buf = gr.State() + + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + lambda: gr.Walkthrough(selected=0), outputs=walkthrough + ).then( + image_to_3d, + inputs=[ + image_prompt, seed, resolution, + ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, + shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, + tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, + ], + outputs=[output_buf, preview_output], + ) + + extract_btn.click( + lambda: gr.Walkthrough(selected=1), 
outputs=walkthrough + ).then( + extract_glb, + inputs=[output_buf, decimation_target, texture_size], + outputs=[glb_output, download_btn], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = Trellis2ImageTo3DPipeline.from_pretrained('JeffreyXiang/TRELLIS.2-4B') + pipeline.cuda() + + envmap = EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )) + + demo.launch(css=css, mcp_server=True) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9d50b1fb825b6b8a7c0a71d7560d37ec1dc8d7b4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +--extra-index-url https://download.pytorch.org/whl/cu124 + +torch==2.6.0 +torchvision==0.21.0 +triton==3.2.0 +pillow==12.0.0 +imageio==2.37.2 +imageio-ffmpeg==0.6.0 +tqdm==4.67.1 +easydict==1.13 +opencv-python-headless==4.12.0.88 +trimesh==4.10.1 +transformers==4.46.3 +git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl?download=true +https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl?download=true +https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl?download=true +https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrast-0.3.5-cp310-cp310-linux_x86_64?download=true +https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl?download=true diff --git a/trellis2/__init__.py b/trellis2/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd --- /dev/null +++ b/trellis2/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . 
import utils diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..d4fed035ed7c1f6352d93e9787f1aaed072876d5 --- /dev/null +++ b/trellis2/models/__init__.py @@ -0,0 +1,78 @@ +import importlib + +__attributes = { + # Sparse Structure + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + + # SLat Generation + 'SLatFlowModel': 'structured_latent_flow', + 'ElasticSLatFlowModel': 'structured_latent_flow', + + # SC-VAEs + 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae', + 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae', + 'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae', + 'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae' +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file), strict=False) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel + + from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder + from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder diff --git a/trellis2/models/sc_vaes/fdg_vae.py b/trellis2/models/sc_vaes/fdg_vae.py new file mode 100755 index 0000000000000000000000000000000000000000..c9b5b072b8e8b366df9cb8d8a558eaf3bc9957e6 --- /dev/null +++ b/trellis2/models/sc_vaes/fdg_vae.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .sparse_unet_vae import ( + SparseResBlock3d, + SparseConvNeXtBlock3d, + + SparseResBlockDownsample3d, + SparseResBlockUpsample3d, + SparseResBlockS2C3d, + SparseResBlockC2S3d, +) +from .sparse_unet_vae import ( + SparseUnetVaeEncoder, + SparseUnetVaeDecoder, +) +from 
...representations import Mesh +from o_voxel.convert import flexible_dual_grid_to_mesh + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + if self.training: + h, subs_gt, subs = decoded + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected_logits = h.replace(h.feats[..., 3:6]) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(flexible_dual_grid_to_mesh( + h.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=True + )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)] + return mesh, vertices, intersected_logits, subs_gt, subs + else: + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + h.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py new file mode 100755 index 0000000000000000000000000000000000000000..b9902a155a8a85c5c616a4503be92f43f6fdde27 --- /dev/null +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -0,0 +1,522 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module +from ...modules import sparse as sp +from ...modules.norm import LayerNorm32 + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest', 
+ use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + self.resample_mode = resample_mode + self.use_checkpoint = use_checkpoint + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + if resample_mode == 'nearest': + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + elif resample_mode =='spatial2channel' and not self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + elif resample_mode =='spatial2channel' and self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + if resample_mode == 'nearest': + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + elif resample_mode =='spatial2channel' and self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + elif resample_mode =='spatial2channel' and not self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + self.updown = None + if self.downsample: + if resample_mode == 'nearest': + self.updown = sp.SparseDownsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseSpatial2Channel(2) + elif self.upsample: + self.to_subdiv = sp.SparseLinear(channels, 8) + if resample_mode == 'nearest': + self.updown = sp.SparseUpsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseChannel2Spatial(2) + + def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.downsample: + x = self.updown(x) + elif self.upsample: + x = self.updown(x, subdiv.replace(subdiv.feats > 0)) + return x + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + subdiv = None + if self.upsample: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + if self.resample_mode == 'spatial2channel': + h = self.conv1(h) + h = self._updown(h, subdiv) + x = self._updown(x, subdiv) + if self.resample_mode == 'nearest': + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.upsample: + return h, subdiv + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockDownsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = 
sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = sp.SparseDownsample(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.updown(h) + x = self.updown(x) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockUpsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + if self.pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseUpsample(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockS2C3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + self.updown = sp.SparseSpatial2Channel(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = self.updown(h) + x = self.updown(x) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: 
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseChannel2Spatial(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False) + else: + return self._forward(x, subdiv) + + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = sp.SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + zero_module(nn.Linear(int(channels * mlp_ratio), channels)), + ) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseUnetVaeEncoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. 
+ """ + def __init__( + self, + in_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = sp.SparseLinear(in_channels, model_channels[0]) + self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[down_block_type[i]]( + model_channels[i], + model_channels[i+1], + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False): + h = self.input_layer(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.to_latent(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + +class SparseUnetVaeDecoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. 
+ """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = sp.SparseLinear(model_channels[-1], out_channels) + self.from_latent = sp.SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor: + assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs" + assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs" + + h = self.from_latent(x) + h = h.type(self.dtype) + subs_gt = [] + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + if self.training: + subs_gt.append(h.get_spatial_cache('subdivision')) + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if self.training and self.pred_subdiv: + return h, subs_gt, subs + else: + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor: + assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling" + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + \ No newline at end of file diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py new file mode 
100755 index 0000000000000000000000000000000000000000..66d204c89bedabc2afd1795cdfc6f5d58a6b1ac0 --- /dev/null +++ b/trellis2/models/sparse_elastic_mixin.py @@ -0,0 +1,24 @@ +from contextlib import contextmanager +from typing import * +import math +from ..modules import sparse as sp +from ..utils.elastic_utils import ElasticModuleMixin + + +class SparseTransformerElasticMixin(ElasticModuleMixin): + def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs): + return x.feats.shape[0] + + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0): + if mem_ratio == 1.0: + yield 1.0 + return + num_blocks = len(self.blocks) + num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks) + exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks + for i in range(num_blocks): + self.blocks[i].use_checkpoint = i < num_checkpoint_blocks + yield exact_mem_ratio + for i in range(num_blocks): + self.blocks[i].use_checkpoint = False diff --git a/trellis2/models/sparse_structure_flow.py b/trellis2/models/sparse_structure_flow.py new file mode 100755 index 0000000000000000000000000000000000000000..a8e9bebcad41a97a35e962296af55f361db3c687 --- /dev/null +++ b/trellis2/models/sparse_structure_flow.py @@ -0,0 +1,248 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..trainers.utils import str_to_dtype +from ..modules.utils import convert_module_to, manual_cast +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.attention import RotaryPositionEmbedder + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + elif pe_mode == "rope": + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. 
+ """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h diff --git a/trellis2/models/sparse_structure_vae.py 
b/trellis2/models/sparse_structure_vae.py new file mode 100755 index 0000000000000000000000000000000000000000..c3e09136cf294c4c1b47b0f09fa6ee57bad2166d --- /dev/null +++ b/trellis2/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. 
+ norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py new file mode 100755 index 0000000000000000000000000000000000000000..c7b0aa61622159e95beb8171808c1617bff4bd0b --- /dev/null +++ b/trellis2/models/structured_latent_flow.py @@ -0,0 +1,208 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..trainers.utils import str_to_dtype +from ..modules.utils import convert_module_to, manual_cast +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder +from .sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + 
self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. + """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) 
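            # With std = 1/sqrt(in_channels), the input projection maps
            # unit-variance, roughly independent features to outputs with
            # variance in_channels * std^2 = 1, so activations start on scale.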
+ nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + x: sp.SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[sp.SparseTensor] = None, + **kwargs + ) -> sp.SparseTensor: + if concat_cond is not None: + x = sp.sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = sp.VarLenTensor.from_tensor_list(cond) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + + +class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel): + """ + SLat Flow Model with elastic memory management. + Used for training with low VRAM. + """ + pass diff --git a/trellis2/modules/attention/__init__.py b/trellis2/modules/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e90e901f9b942e18aaa5022b3ea167784e52d42a --- /dev/null +++ b/trellis2/modules/attention/__init__.py @@ -0,0 +1,3 @@ +from .full_attn import * +from .modules import * +from .rope import * diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py new file mode 100755 index 0000000000000000000000000000000000000000..a6d5180cfc92d0121aa9edd8b7f07f72c4f7ad9f --- /dev/null +++ b/trellis2/modules/attention/config.py @@ -0,0 +1,32 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_attn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_attn_debug is not None: + DEBUG = env_attn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..f56b835bf385dd423102ee7809bbae68d4d29a95 --- /dev/null +++ b/trellis2/modules/attention/full_attn.py @@ -0,0 +1,144 @@ +from typing import * +import torch +import math +from . import config + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. 
+ """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if config.BACKEND == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif config.BACKEND == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif config.BACKEND == 'flash_attn_3': + if 'flash_attn_3' 
not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 1: + out = flash_attn_3.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn_3.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif config.BACKEND == 'sdpa': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif config.BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {config.BACKEND}") + + return out diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..492784c7ba8f572c4820b604f51c924ed564ab00 --- /dev/null +++ b/trellis2/modules/attention/modules.py @@ -0,0 +1,102 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention +from .rope import RotaryPositionEmbedder + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = 
None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py new file mode 100755 index 0000000000000000000000000000000000000000..1cf6c5b321c2417443ffad3804df97ff5fbe6658 --- /dev/null +++ b/trellis2/modules/attention/rope.py @@ -0,0 +1,48 @@ +from typing import * +import torch +import torch.nn as nn + + +class RotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + """ + Args: + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases \ No newline at end of file diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py new file mode 100755 index 0000000000000000000000000000000000000000..78675d0f850d2e34d3c90b5d6bc14db708e5b400 --- /dev/null +++ b/trellis2/modules/norm.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from .utils import manual_cast + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis2/modules/sparse/__init__.py b/trellis2/modules/sparse/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e73f232abc6f31cafeac4172a96150906ba20b7b --- /dev/null +++ b/trellis2/modules/sparse/__init__.py @@ -0,0 +1,69 @@ +from . import config +import importlib + +__attributes = { + 'VarLenTensor': 'basic', + 'varlen_cat': 'basic', + 'varlen_unbind': 'basic', + 'SparseTensor': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_cross_attention': 'attention', + 'SparseRotaryPositionEmbedder': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide': 'spatial', + 'SparseSpatial2Channel': 'spatial', + 'SparseChannel2Spatial': 'spatial', + 'sparse_nearest_interpolate': 'spatial', + 'sparse_trilinear_interpolate': 'spatial', + 'encode_seq': 'serialize', + 'decode_seq': 'serialize', +} + +__submodules = ['transformer', 'conv'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + from .serialize import * + import transformer + import conv diff --git a/trellis2/modules/sparse/attention/__init__.py b/trellis2/modules/sparse/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..18ab3cc0a2c96b430009c3b709db3eb28ce7ccc0 --- /dev/null +++ b/trellis2/modules/sparse/attention/__init__.py @@ -0,0 +1,3 @@ +from .full_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis2/modules/sparse/attention/full_attn.py b/trellis2/modules/sparse/attention/full_attn.py new file 
mode 100755 index 0000000000000000000000000000000000000000..37d70b78368db1c9ed6d1ae2f360f379ad72587f --- /dev/null +++ b/trellis2/modules/sparse/attention/full_attn.py @@ -0,0 +1,214 @@ +from typing import * +import torch +from .. import VarLenTensor +from .. import config + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs. + kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... 
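# --- Illustrative sketch (hypothetical usage, not from the TRELLIS2 sources) --
# How the packed-QKV overload above is typically exercised: two sequences of
# 5 and 7 tokens, 4 heads, head dim 16. The flash_attn / xformers backends
# additionally require a CUDA device and half precision, so treat this as a
# sketch rather than a portable test.
def _example_varlen_self_attention():
    import torch
    seqs = [torch.randn(n, 3, 4, 16, dtype=torch.float16, device="cuda") for n in (5, 7)]
    qkv = VarLenTensor.from_tensor_list(seqs)        # feats: [12, 3, 4, 16], layout: [0:5, 5:12]
    out = sparse_scaled_dot_product_attention(qkv)   # VarLenTensor with feats [12, 4, 16]
    return out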
+ +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # 
[T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif config.ATTN == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn_3.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn_3.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {config.ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis2/modules/sparse/attention/modules.py b/trellis2/modules/sparse/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..d762b4b2e01335d26d856b3fe82eca31a2735123 --- /dev/null +++ b/trellis2/modules/sparse/attention/modules.py @@ -0,0 +1,141 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. 
import VarLenTensor, SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from .rope import SparseRotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + if attn_mode == 'double_windowed': + assert window_size % 2 == 0, "Window size must be even for double windowed attention" + assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention" + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if 
isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis2/modules/sparse/attention/rope.py b/trellis2/modules/sparse/attention/rope.py new file mode 100755 index 0000000000000000000000000000000000000000..fb877291f3430f2cff4329a67b3592e0e3c3f137 --- /dev/null +++ b/trellis2/modules/sparse/attention/rope.py @@ -0,0 +1,58 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor + + +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = 
q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: + return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed \ No newline at end of file diff --git a/trellis2/modules/sparse/attention/windowed_attn.py b/trellis2/modules/sparse/attention/windowed_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..043078899833e52ebf29d3ba1898cbd90e03f1d4 --- /dev/null +++ b/trellis2/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,190 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import config + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', + 'sparse_windowed_scaled_dot_product_cross_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (torch.Tensor): Sequence lengths. + (dict): Attn func args. 
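# --- Annotation (not part of the patch): how the partition works -------------
# Coordinates are shifted, integer-divided by the window size, and the
# resulting per-axis window indices (together with the batch index) are
# flattened into a single window id. argsort over these ids yields fwd_indices
# that group all voxels of a window contiguously; bwd_indices inverts the
# permutation, and the per-window counts (seq_lens) are handed to the varlen
# attention kernels, either as a BlockDiagonalMask (xformers) or as
# cu_seqlens / max_seqlen (flash_attn).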
+ """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif config.ATTN == 'flash_attn': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if config.DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if config.DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) + + +def sparse_windowed_scaled_dot_product_cross_attention( + q: SparseTensor, + kv: SparseTensor, + q_window_size: int, + kv_window_size: int, + q_shift_window: Tuple[int, int, int] = (0, 0, 0), + kv_shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> SparseTensor: + """ + Apply windowed scaled dot product cross attention to two sparse tensors. + + Args: + q (SparseTensor): [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs. + q_window_size (int): The window size to use for Qs. + kv_window_size (int): The window size to use for Ks and Vs. + q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs. + kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. 
+ """ + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + + q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}' + q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name) + if q_serialization_spatial_cache is None: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window) + q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args)) + else: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache + kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}' + kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name) + if kv_serialization_spatial_cache is None: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window) + kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args)) + else: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache + + assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match" + + q_feats = q.feats[q_fwd_indices] # [M, H, C] + kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + k, v = kv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens) + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats, + cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'], + max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'], + ) # [M, H, C] + + out = out[q_bwd_indices] # [T, H, C] + + return q.replace(out) diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py new file mode 100755 index 0000000000000000000000000000000000000000..880973b8dd6bdafcfca4ca7c529308d2ef2ad266 --- /dev/null +++ b/trellis2/modules/sparse/basic.py @@ -0,0 +1,836 @@ +from typing import * +from fractions import Fraction +import torch +from . import config + + +__all__ = [ + 'VarLenTensor', + 'varlen_cat', + 'varlen_unbind', + 'SparseTensor', + 'sparse_cat', + 'sparse_unbind', +] + + +class VarLenTensor: + """ + Sequential tensor with variable length. + + Args: + feats (torch.Tensor): Features of the varlen tensor. + layout (List[slice]): Layout of the varlen tensor for each batch + """ + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. 
+ """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def to_tensor_list(self) -> List[torch.Tensor]: + """ + Convert a VarLenTensor to a list of tensors. + """ + tensor_list = [] + for s in self.layout: + tensor_list.append(self.feats[s]) + return tensor_list + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... 
+ + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + """ + Convert a VarLenTensor to a dense representation without for-loop. 
+ + Returns: + dense (torch.Tensor): (N, L, C) dense tensor + mask (torch.BoolTensor): (N, L) mask indicating valid positions + """ + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + 
raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + +def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: + """ + Concatenate a list of varlen tensors. + + Args: + inputs (List[VarLenTensor]): List of varlen tensors to concatenate. + """ + if dim == 0: + new_feats = torch.cat([input.feats for input in inputs], dim=0) + start = 0 + new_layout = [] + for input in inputs: + for l in input.layout: + new_layout.append(slice(start, start + l.stop - l.start)) + start += l.stop - l.start + output = VarLenTensor(feats=new_feats, layout=new_layout) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + """ + Unbind a varlen tensor along a dimension. + + Args: + input (VarLenTensor): Varlen tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... 
+ + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if config.DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. 
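+
+        Returns:
+            Tuple[List[torch.Tensor], List[torch.Tensor]]: Per-batch feature tensors and the
+            corresponding coordinate tensors, split according to `self.layout`.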
+ """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. 
+ """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + 
new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def 
clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py new file mode 100755 index 0000000000000000000000000000000000000000..a5f4d532316efcf989ab70b9b5218ba962bafe07 --- /dev/null +++ b/trellis2/modules/sparse/config.py @@ -0,0 +1,43 @@ +from typing import * + +CONV = 'flex_gemm' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global CONV + global DEBUG + global ATTN + + env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn_backend is None: + env_sparse_attn_backend = os.environ.get('ATTN_BACKEND') + + if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']: + CONV = env_sparse_conv_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3']: + ATTN = env_sparse_attn_backend + + print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}") + + +__from_env() + + +def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']): + global CONV + CONV = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn_backend(backend: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = backend diff --git a/trellis2/modules/sparse/conv/__init__.py b/trellis2/modules/sparse/conv/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a7f5911f2cb266f93e52cfcdea9f63f39be172c6 --- /dev/null 
+++ b/trellis2/modules/sparse/conv/__init__.py @@ -0,0 +1,2 @@ +from .conv import SparseConv3d, SparseInverseConv3d +from . import config diff --git a/trellis2/modules/sparse/conv/config.py b/trellis2/modules/sparse/conv/config.py new file mode 100755 index 0000000000000000000000000000000000000000..ac0848906703e7811300235e32d14e50ad5aac51 --- /dev/null +++ b/trellis2/modules/sparse/conv/config.py @@ -0,0 +1,3 @@ +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' +FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk' +FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size diff --git a/trellis2/modules/sparse/conv/conv.py b/trellis2/modules/sparse/conv/conv.py new file mode 100755 index 0000000000000000000000000000000000000000..4c7d40707a24e26ba2a11a90c41dbc9eb11e7ab2 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv.py @@ -0,0 +1,30 @@ +from .. import config +import importlib +import torch +import torch.nn as nn +from .. import SparseTensor + + +_backends = {} + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_conv3d_forward(self, x) + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x) diff --git a/trellis2/modules/sparse/conv/conv_flex_gemm.py b/trellis2/modules/sparse/conv/conv_flex_gemm.py new file mode 100755 index 0000000000000000000000000000000000000000..d25619475e4bf39307f97d47b3828b19c48cd7da --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_flex_gemm.py @@ -0,0 +1,68 @@ +import math +import torch +import torch.nn as nn +from .. import SparseTensor +from . 
import config +import flex_gemm +from flex_gemm.ops.spconv import sparse_submanifold_conv3d + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + assert stride == 1 and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + # initialize parameters + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO) + flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO) + + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + + +def sparse_inverse_conv3d_init(self, *args, **kwargs): + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') diff --git a/trellis2/modules/sparse/conv/conv_spconv.py b/trellis2/modules/sparse/conv/conv_spconv.py new file mode 100755 index 0000000000000000000000000000000000000000..f709708d4f8ee3bb98930c68b28e7f3b897fc591 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from . 
import config +import spconv.pytorch as spconv + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + algo = None + if config.SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif config.SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis2/modules/sparse/conv/conv_torchsparse.py b/trellis2/modules/sparse/conv/conv_torchsparse.py new file mode 100755 index 0000000000000000000000000000000000000000..5234bd15553aa8d71df280672475a898ffe56af7 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from .. 
import SparseTensor +import torchsparse + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)]) + return out diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py new file mode 100755 index 0000000000000000000000000000000000000000..44317709ab16b17fa0132fd48e48519ea0ef9ea9 --- /dev/null +++ b/trellis2/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis2/modules/sparse/nonlinearity.py b/trellis2/modules/sparse/nonlinearity.py new file mode 100755 index 0000000000000000000000000000000000000000..950e5c03c997905162e39fec701db49a2640700c --- /dev/null +++ b/trellis2/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis2/modules/sparse/norm.py b/trellis2/modules/sparse/norm.py new file mode 100755 index 0000000000000000000000000000000000000000..95711203f0adbae1c7ea845e2500a3823997d652 --- /dev/null +++ b/trellis2/modules/sparse/norm.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +from ..utils import manual_cast +from . import VarLenTensor +from . 
import config + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) diff --git a/trellis2/modules/sparse/spatial/__init__.py b/trellis2/modules/sparse/spatial/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e27425f165d271fd16a2c6f7b7684d4a81202ebd --- /dev/null +++ b/trellis2/modules/sparse/spatial/__init__.py @@ -0,0 +1,2 @@ +from .basic import * +from .spatial2channel import * diff --git a/trellis2/modules/sparse/spatial/basic.py b/trellis2/modules/sparse/spatial/basic.py new file mode 100755 index 0000000000000000000000000000000000000000..eaeb8afefd889d8a94812c579383e528b8b56699 --- /dev/null +++ b/trellis2/modules/sparse/spatial/basic.py @@ -0,0 +1,109 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. 
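+    Setting `mode='max'` switches from the default mean pooling to max pooling.
+    The block-to-voxel index mapping is cached on the input and output tensors so that
+    repeated calls and a paired `SparseUpsample` can reuse it.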
+ """ + def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'): + super(SparseDownsample, self).__init__() + self.factor = factor + self.mode = mode + assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}' + + def forward(self, x: SparseTensor) -> SparseTensor: + cache = x.get_spatial_cache(f'downsample_{self.factor}') + if cache is None: + DIM = x.coords.shape[-1] - 1 + + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx = cache + + new_feats = torch.scatter_reduce( + torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]), + src=x.feats, + reduce=self.mode, + include_self=False, + ) + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx)) + out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__( + self, factor: int + ): + super(SparseUpsample, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'upsample_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. 
Provide subdivision tensor or pair SparseUpsample with SparseDownsample.') + else: + sub = subdivision.feats + N_leaf = sub.sum(dim=-1) + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx = cache + + new_feats = x.feats[idx] + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + + return out + \ No newline at end of file diff --git a/trellis2/modules/sparse/spatial/spatial2channel.py b/trellis2/modules/sparse/spatial/spatial2channel.py new file mode 100755 index 0000000000000000000000000000000000000000..577f36d208726f64422f8774c3556a1d643f1e2d --- /dev/null +++ b/trellis2/modules/sparse/spatial/spatial2channel.py @@ -0,0 +1,93 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseSpatial2Channel(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from spatial to channel. + """ + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseChannel2Spatial(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from channel to spatial. 
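+    Inverts `SparseSpatial2Channel`: it reuses the coordinate mapping cached by a preceding
+    `SparseSpatial2Channel`, or, when no cache is available, expands coordinates from an
+    explicit `subdivision` occupancy mask passed to `forward`.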
+ """ + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out diff --git a/trellis2/modules/sparse/transformer/__init__.py b/trellis2/modules/sparse/transformer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis2/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/sparse/transformer/blocks.py b/trellis2/modules/sparse/transformer/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..9d1ec600404fba490894872109d44be6b6477186 --- /dev/null +++ b/trellis2/modules/sparse/transformer/blocks.py @@ -0,0 +1,145 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis2/modules/sparse/transformer/modulated.py b/trellis2/modules/sparse/transformer/modulated.py new file mode 100755 index 0000000000000000000000000000000000000000..e616932ef7714f0ec8ae3add655411971978eb54 --- 
/dev/null +++ b/trellis2/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..attention import SparseMultiHeadAttention +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
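+    The conditioning vector `mod` provides shift/scale/gate terms (adaLN) for the
+    self-attention and feed-forward branches; the cross-attention branch over `context` is
+    left unmodulated. With `share_mod=True`, a learned per-block parameter added to `mod`
+    replaces the per-block adaLN projection.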
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis2/modules/spatial.py b/trellis2/modules/spatial.py new file mode 100755 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/trellis2/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. 
+ + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis2/modules/transformer/__init__.py b/trellis2/modules/transformer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis2/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/transformer/blocks.py b/trellis2/modules/transformer/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..fb6f5eb5462fec62aa5edc062104f643fca03bfa --- /dev/null +++ b/trellis2/modules/transformer/blocks.py @@ -0,0 +1,186 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. 
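+            Here D equals `2 * self.freq_dim` (concatenated sine and cosine components).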
+ """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = True, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False) + else: + return self._forward(x, phases) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.self_attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False) + else: + return self._forward(x, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/transformer/modulated.py b/trellis2/modules/transformer/modulated.py new file mode 100755 index 0000000000000000000000000000000000000000..0d71e584a4e137e88010023981271b2906efa30f --- /dev/null +++ b/trellis2/modules/transformer/modulated.py @@ -0,0 +1,165 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False) + else: + return self._forward(x, mod, phases) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
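The adaptive-layer-norm path in these blocks is easy to see in isolation: a single conditioning vector is projected to six per-channel terms (shift/scale before attention and before the MLP, plus two residual gates). A standalone sketch with made-up shapes, independent of the attention module:

```python
import torch
import torch.nn as nn

channels = 64
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels))

x = torch.randn(2, 128, channels)       # (N, L, C) tokens
mod = torch.randn(2, channels)          # conditioning vector, e.g. a timestep embedding
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(mod).chunk(6, dim=1)

h = nn.functional.layer_norm(x, (channels,))
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)   # modulate the normed tokens
# h = attention(h)                                              # attention would run here
x = x + h * gate_msa.unsqueeze(1)                               # gated residual, mirroring _forward above
```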
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/utils.py b/trellis2/modules/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..21b75a431379e55547a6ab7e19cefcdd1e70ab31 --- /dev/null +++ b/trellis2/modules/utils.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from ..modules import sparse as sp + +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). 
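These converters are meant to be passed to `nn.Module.apply`, which visits every submodule; only the primitive types in `MIX_PRECISION_MODULES` are cast, so norms and embeddings keep full precision. A small sketch, assuming the `trellis2` package (and its sparse backend) imports cleanly:

```python
import torch
import torch.nn as nn
from trellis2.modules.utils import convert_module_to_f16, convert_module_to_f32

net = nn.Sequential(nn.Linear(64, 128), nn.LayerNorm(128), nn.Linear(128, 8))
net.apply(convert_module_to_f16)
assert net[0].weight.dtype == torch.float16   # Linear layers are cast
assert net[1].weight.dtype == torch.float32   # LayerNorm is not in MIX_PRECISION_MODULES
net.apply(convert_module_to_f32)              # undo
```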
+ """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def convert_module_to(l, dtype): + """ + Convert primitive modules to the given dtype. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def manual_cast(tensor, dtype): + """ + Cast if autocast is not enabled. + """ + if not torch.is_autocast_enabled(): + return tensor.type(dtype) + return tensor diff --git a/trellis2/pipelines/__init__.py b/trellis2/pipelines/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..53d8917ecad2d7c814a62447587ff31936a9c339 --- /dev/null +++ b/trellis2/pipelines/__init__.py @@ -0,0 +1,55 @@ +import importlib + +__attributes = { + "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d", + "Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade", + "Trellis2ImageToTexturePipeline": "trellis2_image_to_tex", +} + +__submodules = ['samplers', 'rembg'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) + + +# For PyLance +if __name__ == '__main__': + from . import samplers, rembg + from .trellis_image_to_3d import TrellisImageTo3DPipeline + from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline + from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline + from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline diff --git a/trellis2/pipelines/base.py b/trellis2/pipelines/base.py new file mode 100755 index 0000000000000000000000000000000000000000..d897825c5f4f64dbbdca54c6a1af9001e8b8f56b --- /dev/null +++ b/trellis2/pipelines/base.py @@ -0,0 +1,70 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. 
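Loading is driven by `pipeline.json`: its `name` field selects the pipeline class, which then restores its own models and samplers. A usage sketch (the path is a placeholder for a local folder containing `pipeline.json` or a Hugging Face repo id):

```python
from trellis2 import pipelines

pipe = pipelines.from_pretrained("path/to/model")   # placeholder path or HF repo id
pipe.cuda()
print(type(pipe).__name__)                          # e.g. Trellis2ImageTo3DPipeline, per pipeline.json
```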
+ """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = {} + for k, v in args['models'].items(): + try: + _models[k] = models.from_pretrained(f"{path}/{v}") + except Exception as e: + _models[k] = models.from_pretrained(v) + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + if hasattr(self, '_device'): + return self._device + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) \ No newline at end of file diff --git a/trellis2/pipelines/rembg/BiRefNet.py b/trellis2/pipelines/rembg/BiRefNet.py new file mode 100755 index 0000000000000000000000000000000000000000..c71a99274823aefe6f18ab921a5beb074177de18 --- /dev/null +++ b/trellis2/pipelines/rembg/BiRefNet.py @@ -0,0 +1,42 @@ +from typing import * +from transformers import AutoModelForImageSegmentation +import torch +from torchvision import transforms +from PIL import Image + + +class BiRefNet: + def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"): + self.model = AutoModelForImageSegmentation.from_pretrained( + model_name, trust_remote_code=True + ) + self.model.eval() + self.transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + def to(self, device: str): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def __call__(self, image: Image.Image) -> Image.Image: + image_size = image.size + input_images = self.transform_image(image).unsqueeze(0).to("cuda") + # Prediction + with torch.no_grad(): + preds = self.model(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return image + \ No newline at end of file diff --git a/trellis2/pipelines/rembg/__init__.py b/trellis2/pipelines/rembg/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..fc1eed1ba0c962478fc48132432f2e649fe03411 --- /dev/null +++ b/trellis2/pipelines/rembg/__init__.py @@ -0,0 +1 @@ +from .BiRefNet import * diff --git a/trellis2/pipelines/samplers/__init__.py b/trellis2/pipelines/samplers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..4a69b95e5963c6d3a837518ed9c3cf6972235d60 --- /dev/null +++ b/trellis2/pipelines/samplers/__init__.py @@ -0,0 +1,6 @@ +from .base import Sampler +from .flow_euler import ( + FlowEulerSampler, + FlowEulerCfgSampler, + FlowEulerGuidanceIntervalSampler, +) \ No newline at end of file diff --git a/trellis2/pipelines/samplers/base.py b/trellis2/pipelines/samplers/base.py new file mode 100755 index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf --- /dev/null +++ 
b/trellis2/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100755 index 0000000000000000000000000000000000000000..8c7a4da4c1324bed5319f89916b532a159471d84 --- /dev/null +++ b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,29 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs): + if guidance_strength == 1: + return super()._inference_model(model, x_t, t, cond, **kwargs) + elif guidance_strength == 0: + return super()._inference_model(model, x_t, t, neg_cond, **kwargs) + else: + pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs) + pred_neg = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg + + # CFG rescale + if guidance_rescale > 0: + x_0_pos = self._pred_to_xstart(x_t, t, pred_pos) + x_0_cfg = self._pred_to_xstart(x_t, t, pred) + std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True) + std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True) + x_0_rescaled = x_0_cfg * (std_pos / std_cfg) + x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg + pred = self._xstart_to_pred(x_t, t, x_0) + + return pred diff --git a/trellis2/pipelines/samplers/flow_euler.py b/trellis2/pipelines/samplers/flow_euler.py new file mode 100755 index 0000000000000000000000000000000000000000..5ff72b84221f210f6cb06684bdf13aedc6cf20c3 --- /dev/null +++ b/trellis2/pipelines/samplers/flow_euler.py @@ -0,0 +1,208 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. 
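For readers cross-checking the conversion helpers that follow: everything is derived from the interpolation path the sampler assumes, with the model predicting the velocity v (writing the code's `sigma_min` as σ_min):

```latex
x_t = (1 - t)\,x_0 + \bigl(\sigma_{\min} + (1 - \sigma_{\min})\,t\bigr)\,\epsilon,
\qquad
v = \frac{\mathrm{d}x_t}{\mathrm{d}t} = (1 - \sigma_{\min})\,\epsilon - x_0
```

Solving for the two endpoints gives ε = x_t + (1 - t) v and x_0 = (1 - σ_min) x_t - (σ_min + (1 - σ_min) t) v, exactly as in `_v_to_xstart_eps`, and each `sample_once` call is a single Euler step x_{t'} = x_t - (t - t') v along this path.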
+ """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _pred_to_xstart(self, x_t, t, pred): + return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred + + def _xstart_to_pred(self, x_t, t, x_0): + return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + tqdm_desc: str = "Sampling", + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + tqdm_desc: A customized tqdm desc. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. 
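The timestep schedule built at the top of the method body below is easier to read with numbers: `rescale_t > 1` spends more of the step budget near t = 1 (high noise) and takes larger steps near t = 0. A tiny standalone check:

```python
import numpy as np

steps, rescale_t = 4, 3.0
t_seq = np.linspace(1, 0, steps + 1)
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
print(t_seq.tolist())   # [1.0, 0.9, 0.75, 0.5, 0.0] -- dense near t=1, coarse near t=0
```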
+ """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_seq = t_seq.tolist() + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc=tqdm_desc, disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + guidance_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + guidance_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. 
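An end-to-end call of the classifier-free-guidance variant with a toy velocity model, just to show the expected signatures (the model below is made up; real flow models receive `t` scaled by 1000 plus whatever conditioning the pipeline passes in):

```python
import torch
from trellis2.pipelines.samplers import FlowEulerCfgSampler

class ToyVelocity(torch.nn.Module):
    """Stand-in flow model: predicts a velocity field from (x_t, t, cond)."""
    def forward(self, x_t, t, cond):
        return x_t - cond

sampler = FlowEulerCfgSampler(sigma_min=1e-5)
noise = torch.randn(2, 4, 8, 8)
cond, neg_cond = torch.zeros_like(noise), torch.ones_like(noise)
out = sampler.sample(ToyVelocity(), noise, cond, neg_cond,
                     steps=10, guidance_strength=3.0, verbose=False)
print(out.samples.shape)   # torch.Size([2, 4, 8, 8]); out.pred_x_0 holds per-step x_0 estimates
```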
+ """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs) diff --git a/trellis2/pipelines/samplers/guidance_interval_mixin.py b/trellis2/pipelines/samplers/guidance_interval_mixin.py new file mode 100755 index 0000000000000000000000000000000000000000..3f57869a17d1626f5b2c58eb3c477127bf464abf --- /dev/null +++ b/trellis2/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,13 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs): + if guidance_interval[0] <= t <= guidance_interval[1]: + return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs) + else: + return super()._inference_model(model, x_t, t, cond, guidance_strength=1, **kwargs) diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py new file mode 100755 index 0000000000000000000000000000000000000000..8f276389dd4ddaa9e5dac2f7034e2c53854e0fb5 --- /dev/null +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -0,0 +1,587 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +from .base import Pipeline +from . import samplers, rembg +from .. import trainers +from ..modules import sparse as sp +from ..representations import Mesh, MeshWithVoxel + + +class Trellis2ImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis2 image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + shape_slat_sampler (samplers.Sampler): The sampler for the structured latent. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): The parameters for the structured latent sampler. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (trainers.Trainer): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. 
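Typical end-to-end usage of this pipeline is sketched below; the checkpoint path and input image are placeholders, and a CUDA device, the compiled extensions, and downloaded weights are all assumed:

```python
from PIL import Image
from trellis2.pipelines import Trellis2ImageTo3DPipeline

pipe = Trellis2ImageTo3DPipeline.from_pretrained("path/to/checkpoint")   # placeholder path or HF repo id
pipe.cuda()

image = Image.open("chair.png")                                          # placeholder input
meshes = pipe.run(image, seed=42, pipeline_type="512->1024")             # -> List[MeshWithVoxel]
print(meshes[0].vertices.shape, meshes[0].faces.shape)
```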
+ """ + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + shape_slat_sampler: samplers.Sampler = None, + tex_slat_sampler: samplers.Sampler = None, + sparse_structure_sampler_params: dict = None, + shape_slat_sampler_params: dict = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + default_pipeline_type: str = '512->1024', + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.shape_slat_sampler = shape_slat_sampler + self.tex_slat_sampler = tex_slat_sampler + self.sparse_structure_sampler_params = sparse_structure_sampler_params + self.shape_slat_sampler_params = shape_slat_sampler_params + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.default_pipeline_type = default_pipeline_type + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @staticmethod + def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path) + new_pipeline = Trellis2ImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + + new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + new_pipeline.shape_slat_normalization = args['shape_slat_normalization'] + new_pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + new_pipeline.image_cond_model = getattr(trainers, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + new_pipeline.low_vram = args.get('low_vram', True) + new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '512->1024') + new_pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + new_pipeline._device = 'cpu' + + return new_pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + self.rembg_model.to(device) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. 
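Two input paths exist here: an image whose alpha channel already carries a matte skips background removal, while plain RGB goes through BiRefNet before cropping. Continuing the sketch above with a loaded `pipe` (file names are placeholders):

```python
from PIL import Image

rgba = Image.open("asset_with_alpha.png")       # RGBA with a real matte -> rembg is skipped
rgb = Image.open("photo.jpg").convert("RGB")    # plain RGB -> BiRefNet removes the background

for img in (rgba, rgb):
    out = pipe.preprocess_image(img)            # cropped to the object, composited on black, long side <= 1024
    print(out.size, out.mode)                   # mode is 'RGB'
```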
+ """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + resolution: int, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + resolution (int): The resolution of the sparse structure. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample sparse structure latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + in_channels = flow_model.in_channels + noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling sparse structure", + ).samples + if self.low_vram: + flow_model.cpu() + + # Decode sparse structure latent + decoder = self.models['sparse_structure_decoder'] + if self.low_vram: + decoder.to(self.device) + decoded = decoder(z_s)>0 + if self.low_vram: + decoder.cpu() + if resolution != decoded.shape[2]: + ratio = decoded.shape[2] // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + + return coords + + def sample_shape_slat( + self, + cond: dict, + flow_model, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def sample_shape_slat_cascade( + self, + lr_cond: dict, + cond: dict, + flow_model_lr, + flow_model, + lr_resolution: int, + resolution: int, + coords: torch.Tensor, + sampler_params: dict = {}, + max_num_tokens: int = 49152, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # LR + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model_lr.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model_lr, + noise, + **lr_cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model_lr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + # Upsample + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + hr_resolution = resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + coords = quant_coords.unique(dim=0) + num_tokens = coords.shape[0] + if num_tokens < max_num_tokens or hr_resolution == 1024: + if hr_resolution != resolution: + print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.") + break + hr_resolution -= 128 + + # Sample structured latent + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat, hr_resolution + + def decode_shape_slat( + self, + slat: sp.SparseTensor, + resolution: int, + ) -> Tuple[List[Mesh], List[sp.SparseTensor]]: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + List[Mesh]: The decoded meshes. + List[sp.SparseTensor]: The decoded substructures. + """ + self.models['shape_slat_decoder'].set_resolution(resolution) + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + ret = self.models['shape_slat_decoder'](slat, return_subs=True) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + return ret + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: sp.SparseTensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (sp.SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: sp.SparseTensor, + subs: List[sp.SparseTensor], + ) -> sp.SparseTensor: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + List[sp.SparseTensor]: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + @torch.no_grad() + def decode_latent( + self, + shape_slat: sp.SparseTensor, + tex_slat: sp.SparseTensor, + resolution: int, + ) -> List[MeshWithVoxel]: + """ + Decode the latent codes. + + Args: + shape_slat (sp.SparseTensor): The structured latent for shape. + tex_slat (sp.SparseTensor): The structured latent for texture. + resolution (int): The resolution of the output. + """ + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + for m, v in zip(meshes, tex_voxels): + m.fill_holes() + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + ) -> List[MeshWithVoxel]: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '512', '1024', '512->1024', '512->1536'. 
+ max_num_tokens (int): The maximum number of tokens to use. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '512': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found." + elif pipeline_type == '1024': + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '512->1024': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '512->1536': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type}") + + if preprocess_image: + image = self.preprocess_image(image) + torch.manual_seed(seed) + cond_512 = self.get_cond([image], 512) + cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None + ss_res = {'512': 32, '1024': 64, '512->1024': 32, '512->1536': 32}[pipeline_type] + coords = self.sample_sparse_structure( + cond_512, ss_res, + num_samples, sparse_structure_sampler_params + ) + if pipeline_type == '512': + shape_slat = self.sample_shape_slat( + cond_512, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_512, self.models['tex_slat_flow_model_512'], + shape_slat, tex_slat_sampler_params + ) + res = 512 + elif pipeline_type == '1024': + shape_slat = self.sample_shape_slat( + cond_1024, self.models['shape_slat_flow_model_1024'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + res = 1024 + elif pipeline_type == '512->1024': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1024, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + elif pipeline_type == '512->1536': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1536, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + torch.cuda.empty_cache() + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh diff --git a/trellis2/pipelines/trellis2_image_to_tex.py b/trellis2/pipelines/trellis2_image_to_tex.py new file mode 100755 index 
0000000000000000000000000000000000000000..ab8b777656e87ff93cd55febcd797584adab26f0 --- /dev/null +++ b/trellis2/pipelines/trellis2_image_to_tex.py @@ -0,0 +1,271 @@ +from typing import * +import numpy as np +import torch +from .. import _C +from flex_gemm.kernels import cuda as flexgemm_kernels + +__all__ = [ + "mesh_to_flexible_dual_grid", + "flexible_dual_grid_to_mesh", +] + +@torch.no_grad() +def mesh_to_flexible_dual_grid( + vertices: torch.Tensor, + faces: torch.Tensor, + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + aabb: Union[list, tuple, np.ndarray, torch.Tensor] = None, + face_weight: float = 1.0, + boundary_weight: float = 1.0, + regularization_weight: float = 0.1, + timing: bool = False, +) -> Union[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Voxelize a mesh into a sparse voxel grid. + + Args: + vertices (torch.Tensor): The vertices of the mesh. + faces (torch.Tensor): The faces of the mesh. + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + If not provided, it will be computed automatically. + face_weight (float): The weight of the face term in the dual contouring algorithm. + boundary_weight (float): The weight of the boundary term in the dual contouring algorithm. + regularization_weight (float): The weight of the regularization term in the dual contouring algorithm. + timing (bool): Whether to time the voxelization process. + + Returns: + torch.Tensor: The indices of the voxels that are occupied by the mesh. + The shape of the tensor is (N, 3), where N is the number of occupied voxels. + torch.Tensor: The dual vertices of the mesh. + torch.Tensor: The intersected flag of each voxel. 
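A round trip through the two helpers, on a toy cube, looks like the sketch below. This is heavily hedged: the import path simply follows the diff header for this file, the three return values are unpacked per the docstring, the compiled `_C`/`flex_gemm` extensions and a CUDA device are required, and device placement of the inputs may need adjusting:

```python
import torch
# Import path taken from the diff header above; adjust if these helpers live elsewhere.
from trellis2.pipelines.trellis2_image_to_tex import (
    mesh_to_flexible_dual_grid, flexible_dual_grid_to_mesh,
)

# Toy closed cube (winding not carefully oriented; this is only a shape/API check).
vertices = torch.tensor(
    [[x, y, z] for x in (-0.4, 0.4) for y in (-0.4, 0.4) for z in (-0.4, 0.4)],
    dtype=torch.float32,
)
faces = torch.tensor([
    [0, 1, 3], [0, 3, 2], [4, 6, 7], [4, 7, 5],
    [0, 4, 5], [0, 5, 1], [2, 3, 7], [2, 7, 6],
    [0, 2, 6], [0, 6, 4], [1, 5, 7], [1, 7, 3],
], dtype=torch.int32)

aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]]
coords, dual_vertices, intersected = mesh_to_flexible_dual_grid(
    vertices, faces, grid_size=64, aabb=aabb,
)
verts, tris = flexible_dual_grid_to_mesh(
    coords, dual_vertices, intersected, split_weight=None,
    aabb=aabb, grid_size=64,
)
```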
+ """ + + # Load mesh + vertices = vertices.float() + faces = faces.int() + + # Voxelize settings + assert voxel_size is not None or grid_size is not None, "Either voxel_size or grid_size must be provided" + + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + + if grid_size is not None: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32) + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got {type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + if aabb is not None: + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Auto adjust aabb + if aabb is None: + min_xyz = vertices.min(dim=0).values + max_xyz = vertices.max(dim=0).values + + if voxel_size is not None: + padding = torch.ceil((max_xyz - min_xyz) / voxel_size) * voxel_size - (max_xyz - min_xyz) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + if grid_size is not None: + padding = (max_xyz - min_xyz) / (grid_size - 1) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + + aabb = torch.stack([min_xyz, max_xyz], dim=0).float().cuda() + + # Fill voxel size or grid size + if voxel_size is None: + voxel_size = (aabb[1] - aabb[0]) / grid_size + if grid_size is None: + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + + # subdivide mesh + vertices = vertices - aabb[0].reshape(1, 3) + grid_range = torch.stack([torch.zeros_like(grid_size), grid_size], dim=0).int() + + ret = _C.mesh_to_flexible_dual_grid_cpu( + vertices, + faces, + voxel_size, + grid_range, + face_weight, + boundary_weight, + regularization_weight, + timing, + ) + + return ret + + +def flexible_dual_grid_to_mesh( + coords: torch.Tensor, + dual_vertices: torch.Tensor, + intersected_flag: torch.Tensor, + split_weight: Union[torch.Tensor, None], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + train: bool = False, +): + """ + Extract mesh from sparse voxel structures using flexible dual grid. + + Args: + coords (torch.Tensor): The coordinates of the voxels. 
+ dual_vertices (torch.Tensor): The dual vertices. + intersected_flag (torch.Tensor): The intersected flag. + split_weight (torch.Tensor): The split weight of each dual quad. If None, the algorithm + will split based on minimum angle. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + train (bool): Whether to use training mode. + + Returns: + vertices (torch.Tensor): The vertices of the mesh. + faces (torch.Tensor): The faces of the mesh. + """ + # Static variables + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=coords.device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"): + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + + # AABB + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + assert grid_size is not None, "Either voxel_size or grid_size must be provided" + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got 
{type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + # Extract mesh + N = dual_vertices.shape[0] + mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 + + # Store active voxels into hashmap + hashmap = torch.full((2 * int(2 * N),), 0xffffffff, dtype=torch.uint32, device=coords.device) + flexgemm_kernels.hashmap_insert_3d_idx_as_val_cuda(hashmap, torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1), *grid_size.tolist()) + + # Find connected voxels + edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) + connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + M = connected_voxel.shape[0] + connected_voxel_hash_key = torch.cat([ + torch.zeros((M * 4, 1), dtype=torch.int, device=coords.device), + connected_voxel.reshape(-1, 3) + ], dim=1) + connected_voxel_indices = flexgemm_kernels.hashmap_lookup_3d_cuda(hashmap, connected_voxel_hash_key, *grid_size.tolist()).reshape(M, 4).int() + connected_voxel_valid = (connected_voxel_indices != 0xffffffff).all(dim=1) + quad_indices = connected_voxel_indices[connected_voxel_valid].int() # (L, 4) + L = quad_indices.shape[0] + + # Construct triangles + if not train: + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + if split_weight is None: + # if split 1 + atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]], dim=1) + normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]], dim=1) + normals0 = normals0 / torch.norm(normals0, dim=1, keepdim=True) + normals1 = normals1 / torch.norm(normals1, dim=1, keepdim=True) + align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # if split 2 + atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]], dim=1) + normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]], dim=1) + normals0 = normals0 / torch.norm(normals0, dim=1, keepdim=True) + normals1 = normals1 / torch.norm(normals1, dim=1, keepdim=True) + align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # select split + mesh_triangles = torch.where(align0 > align1, atempt_triangles_0, atempt_triangles_1).reshape(-1, 3) + else: + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mesh_triangles = torch.where( + split_weight_ws_02 > split_weight_ws_13, + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + ).reshape(-1, 3) + else: + assert split_weight is not None, "split_weight must be provided in training mode" + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + 
aabb[0].reshape(1, 3) + quad_vs = mesh_vertices[quad_indices] + mean_v02 = (quad_vs[:, 0] + quad_vs[:, 2]) / 2 + mean_v13 = (quad_vs[:, 1] + quad_vs[:, 3]) / 2 + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mid_vertices = ( + split_weight_ws_02 * mean_v02 + + split_weight_ws_13 * mean_v13 + ) / (split_weight_ws_02 + split_weight_ws_13) + mesh_vertices = torch.cat([mesh_vertices, mid_vertices], dim=0) + quad_indices = torch.cat([quad_indices, torch.arange(N, N + L, device='cuda').unsqueeze(1)], dim=1) + mesh_triangles = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_train].reshape(-1, 3) + + return mesh_vertices, mesh_triangles diff --git a/trellis2/renderers/__init__.py b/trellis2/renderers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..de3203d1bb16065f912bff039e431f609911782d --- /dev/null +++ b/trellis2/renderers/__init__.py @@ -0,0 +1,33 @@ +import importlib + +__attributes = { + 'MeshRenderer': 'mesh_renderer', + 'VoxelRenderer': 'voxel_renderer', + 'PbrMeshRenderer': 'pbr_mesh_renderer', + 'EnvMap': 'pbr_mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh_renderer import MeshRenderer + from .voxel_renderer import VoxelRenderer + from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap + \ No newline at end of file diff --git a/trellis2/renderers/mesh_renderer.py b/trellis2/renderers/mesh_renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..7dce68c8b4d141f5d2e8766a5cd4346bd84695d7 --- /dev/null +++ b/trellis2/renderers/mesh_renderer.py @@ -0,0 +1,415 @@ +from typing import * +import torch +from easydict import EasyDict as edict +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. 
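The intrinsics_to_projection helper above builds an OpenGL-style projection from a pinhole intrinsics matrix; the 2*fx and 2*cx - 1 terms suggest the intrinsics are expected to be normalized by image width and height. A minimal illustrative sketch, with made-up focal length and clip planes:

import torch

fx = fy = 1.2                                   # hypothetical focal lengths, normalized by image size
cx = cy = 0.5                                   # principal point at the image center
intrinsics = torch.tensor([[fx, 0., cx],
                           [0., fy, cy],
                           [0., 0., 1.]])
proj = intrinsics_to_projection(intrinsics, near=0.1, far=10.0)   # [4, 4] OpenGL projection
point_cam = torch.tensor([0.0, 0.0, 2.0, 1.0])                    # camera-space point 2 units ahead
point_clip = proj @ point_cam
point_ndc = point_clip[:3] / point_clip[3]                        # clip -> NDC via perspective divide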
+ """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "chunk_size": None, + "antialias": True, + "clamp_barycentric_coords": False, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"], + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal" + + Returns: + edict based on return_types containing: + attr (torch.Tensor): [C, H, W] rendered attr image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + mask (torch.Tensor): [H, W] rendered mask image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + chunk_size = self.rendering_options["chunk_size"] + antialias = self.rendering_options["antialias"] + clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + ret_dict = edict() + for type in return_types: + if type == "mask" : + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "depth": + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "normal": + ret_dict[type] = torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device) + elif type == "coord": + ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + else: + ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + if 'normal' in return_types: + v0 = vertices_camera[0, mesh.faces[:, 0], :3] + v1 = vertices_camera[0, mesh.faces[:, 1], :3] + v2 = vertices_camera[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, 
-face_normal) + + out_dict = edict() + if chunk_size is None: + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa) + ) + if clamp_barycentric_coords: + rast[..., :2] = torch.clamp(rast[..., :2], 0, 1) + rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2])) + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "normal" : + img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + 
mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + + out_dict[type] = img + else: + z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32) + for i in range(0, faces.shape[0], chunk_size): + faces_chunk = faces[i:i+chunk_size] + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa) + ) + z_filter = torch.logical_and( + rast[..., 3] != 0, + rast[..., 2] < z_buffer + ) + z_buffer[z_filter] = rast[z_filter][..., 2] + + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0] + elif type == "normal" : + face_normal_chunk = face_normal[i:i+chunk_size] + img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces_chunk)[0] + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces_chunk)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, 
resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > 
mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0] + + if type not in out_dict: + out_dict[type] = img + else: + out_dict[type][z_filter] = img[z_filter] + + for type in return_types: + img = out_dict[type] + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types: + for k, s in mesh.layout.items(): + out_dict[k] = out_dict['attr'][s] + del out_dict['attr'] + + return out_dict diff --git a/trellis2/renderers/pbr_mesh_renderer.py b/trellis2/renderers/pbr_mesh_renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..047cb67df4d2f85738a1462dc117865ce8b96fc1 --- /dev/null +++ b/trellis2/renderers/pbr_mesh_renderer.py @@ -0,0 +1,374 @@ +from typing import * +import torch +from easydict import EasyDict as edict +import numpy as np +import utils3d +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y + elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y + elif s == 2: rx, ry, rz = x, y, torch.ones_like(x) + elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x) + elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y + elif s == 5: rx, ry, rz = -x, -torch.ones_like(x), -y + return torch.stack((rx, ry, rz), dim=-1) + + +def latlong_to_cubemap(latlong_map, res): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = F.normalize(cube_to_dir(s, gx, gy), dim=-1) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] 
= dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + + +class EnvMap: + def __init__(self, image: torch.Tensor): + if 'EnvironmentLight' not in globals(): + from nvdiffrec_render.light import EnvironmentLight + cubemap = latlong_to_cubemap(image, [512, 512]) + self._backend = EnvironmentLight(cubemap) + self._backend.build_mips() + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular) + + def sample(self, directions: torch.Tensor): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + return dr.texture( + self._backend.base.unsqueeze(0), + directions.unsqueeze(0), + boundary_mode='cube', + )[0] + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def aces_tonemapping(x: torch.Tensor) -> torch.Tensor: + """ + Applies ACES tone mapping curve to an HDR image tensor. + Input: x - HDR tensor, shape (..., 3), range [0, +inf) + Output: LDR tensor, same shape, range [0, 1] + """ + a = 2.51 + b = 0.03 + c = 2.43 + d = 0.59 + e = 0.14 + + # Apply the ACES fitted curve + mapped = (x * (a * x + b)) / (x * (c * x + d) + e) + + # Clamp to [0, 1] for display or saving + return torch.clamp(mapped, 0.0, 1.0) + + +def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor: + """ + Applies gamma correction to an HDR image tensor. + """ + return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0) + + +class PbrMeshRenderer: + """ + Renderer for the PBR mesh. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. + """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "peel_layers": 8, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeGLContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + envmap : EnvMap, + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. 
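The aces_tonemapping and gamma_correction helpers defined above are what render() applies to the HDR shaded buffer as its final post-processing step; a small sketch on a made-up HDR value:

import torch

hdr = torch.tensor([[0.05, 0.5, 4.0]])          # hypothetical linear HDR radiance, shape (..., 3)
ldr = aces_tonemapping(hdr)                     # ACES fitted curve, clamped to [0, 1]
srgb = gamma_correction(ldr, gamma=2.2)         # approximate linear -> display transfer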
+ + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + envmap : EnvMap + transformation (torch.Tensor): (4, 4) transformation matrix + + Returns: + edict based on return_types containing: + shaded (torch.Tensor): [3, H, W] shaded color image + normal (torch.Tensor): [3, H, W] normal image + base_color (torch.Tensor): [3, H, W] base color image + metallic (torch.Tensor): [H, W] metallic image + roughness (torch.Tensor): [H, W] roughness image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + return edict( + shaded=torch.full((4, resolution, resolution), 0.5, dtype=torch.float32, device=self.device), + ) + + rays_o, rays_d = utils3d.torch.get_image_rays( + extrinsics, intrinsics, resolution * ssaa, resolution * ssaa + ) + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_orig = vertices.clone() + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + v0 = vertices[0, mesh.faces[:, 0], :3] + v1 = vertices[0, mesh.faces[:, 1], :3] + v2 = vertices[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + + out_dict = edict() + shaded = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler: + for _ in range(self.rendering_options["peel_layers"]): + rast, rast_db = peeler.rasterize_next_layer() + + # Pos + pos = dr.interpolate(vertices, rast, faces)[0][0] + + # Normal + gb_normal = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0][0] + gb_normal = torch.where( + torch.sum(gb_normal * (pos - rays_o), dim=-1, keepdim=True) > 0, + -gb_normal, + gb_normal + ) + if _ == 0: + cam_normal = extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1) + cam_normal = -cam_normal.squeeze(-1) * 0.5 + 0.5 + out_dict.normal = cam_normal + mask = (rast[0, ..., -1:] > 0).float() + out_dict.mask = mask + + # PBR attributes + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices_orig, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + gb_basecolor = img[0, ..., 
mesh.layout['base_color']] + gb_metallic = img[0, ..., mesh.layout['metallic']] + gb_roughness = img[0, ..., mesh.layout['roughness']] + gb_alpha = img[0, ..., mesh.layout['alpha']] + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + bc = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_basecolor += bc * mat.base_color_factor * mat_mask + else: + gb_basecolor += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + m = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_metallic += m * mat.metallic_factor * mat_mask + else: + gb_metallic += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + r = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_roughness += r * mat.roughness_factor * mat_mask + else: + gb_roughness += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + gb_alpha += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + a = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += a * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (mat.alpha_factor > 
mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += mat.alpha_factor * mat_mask + if _ == 0: + out_dict.base_color = gb_basecolor + out_dict.metallic = gb_metallic + out_dict.roughness = gb_roughness + out_dict.alpha = gb_alpha + + # Shading + gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2 + gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0) + gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0) + gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0) + gb_orm = torch.cat([ + torch.zeros_like(gb_metallic), + gb_roughness, + gb_metallic, + ], dim=-1) + gb_shaded = envmap.shade( + pos.unsqueeze(0), + gb_normal.unsqueeze(0), + gb_basecolor.unsqueeze(0), + gb_orm.unsqueeze(0), + rays_o, + specular=True, + )[0] + + # Alpha blend + w = (1 - alpha) * gb_alpha + shaded += w * gb_shaded + alpha += w + + # Background + bg = envmap.sample(rays_d) + shaded += (1 - alpha) * bg + + out_dict.shaded = shaded + + # SSAA + for k in out_dict.keys(): + if ssaa > 1: + out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + else: + out_dict[k] = out_dict[k].permute(2, 0, 1) + out_dict[k] = out_dict[k].squeeze() + + # Post processing + out_dict.shaded = aces_tonemapping(out_dict.shaded) + out_dict.shaded = gamma_correction(out_dict.shaded) + + return out_dict diff --git a/trellis2/renderers/voxel_renderer.py b/trellis2/renderers/voxel_renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..dfe28ad8d341ed62c1d7a5ab739fed6cace30a5f --- /dev/null +++ b/trellis2/renderers/voxel_renderer.py @@ -0,0 +1,68 @@ +import torch +from easydict import EasyDict as edict +from ..representations import Voxel +from easydict import EasyDict as edict + + +class VoxelRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.rendering_options = edict({ + "resolution": None, + "near": 0.1, + "far": 10.0, + "ssaa": 1, + }) + self.rendering_options.update(rendering_options) + + def render( + self, + voxel: Voxel, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + voxel (Voxel): Voxel representation. + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + ... 
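These renderers unpack stacked attribute channels with a layout dict mapping names to channel slices (the base_color/metallic/roughness/alpha layout used by the PBR meshes); a small sketch of that convention on a made-up [C, H, W] image:

import torch

attr = torch.rand(6, 64, 64)                    # hypothetical packed attribute image
layout = {'base_color': slice(0, 3), 'metallic': slice(3, 4),
          'roughness': slice(4, 5), 'alpha': slice(5, 6)}
maps = {name: attr[s] for name, s in layout.items()}   # e.g. maps['base_color'] has shape [3, 64, 64]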
+ """ + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options) + positions = voxel.position + attrs = voxel.attrs if colors_overwrite is None else colors_overwrite + voxel_size = voxel.voxel_size + + # Render + render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics) + + ret = { + 'depth': render_ret['depth'], + 'alpha': render_ret['alpha'], + } + if colors_overwrite is not None: + ret['color'] = render_ret['attr'] + else: + for k, s in voxel.layout.items(): + ret[k] = render_ret['attr'][s] + + return ret diff --git a/trellis2/representations/__init__.py b/trellis2/representations/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0e7d9299f866c344e81e27f748f837e5ce81ed8b --- /dev/null +++ b/trellis2/representations/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'Mesh': 'mesh', + 'Voxel': 'voxel', + 'MeshWithVoxel': 'mesh', + 'MeshWithPbrMaterial': 'mesh', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial + from .voxel import Voxel diff --git a/trellis2/representations/mesh/__init__.py b/trellis2/representations/mesh/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..aff4c99a97f764a7c695d87d6a60bd03d61e2106 --- /dev/null +++ b/trellis2/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture diff --git a/trellis2/representations/mesh/base.py b/trellis2/representations/mesh/base.py new file mode 100755 index 0000000000000000000000000000000000000000..b70e4cca19d65b47d07f297894b831d3506be5ff --- /dev/null +++ b/trellis2/representations/mesh/base.py @@ -0,0 +1,234 @@ +from typing import * +import torch +from ..voxel import Voxel +import cumesh +from flex_gemm.ops.grid_sample import grid_sample_3d + + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') + + def fill_holes(self, max_hole_perimeter=3e-2): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.get_edges() + mesh.get_boundary_info() + if mesh.num_boundaries == 0: + return + mesh.get_vertex_edge_adjacency() + mesh.get_vertex_boundary_adjacency() + mesh.get_manifold_boundary_adjacency() + 
mesh.read_manifold_boundary_adjacency() + mesh.get_boundary_connected_components() + mesh.get_boundary_loops() + if mesh.num_boundary_loops == 0: + return + mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def remove_faces(self, face_mask: torch.Tensor): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.remove_faces(face_mask) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def simplify(self, target=1000000, verbose: bool=False, options: dict={}): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.simplify(target, verbose=verbose, options=options) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + +class TextureFilterMode: + CLOSEST = 0 + LINEAR = 1 + + +class TextureWrapMode: + CLAMP_TO_EDGE = 0 + REPEAT = 1 + MIRRORED_REPEAT = 2 + + +class AlphaMode: + OPAQUE = 0 + MASK = 1 + BLEND = 2 + + +class Texture: + def __init__( + self, + image: torch.Tensor, + filter_mode: TextureFilterMode = TextureFilterMode.LINEAR, + wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT + ): + self.image = image + self.filter_mode = filter_mode + self.wrap_mode = wrap_mode + + def to(self, device, non_blocking=False): + return Texture( + self.image.to(device, non_blocking=non_blocking), + self.filter_mode, + self.wrap_mode, + ) + + +class PbrMaterial: + def __init__( + self, + base_color_texture: Optional[Texture] = None, + base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0], + metallic_texture: Optional[Texture] = None, + metallic_factor: float = 1.0, + roughness_texture: Optional[Texture] = None, + roughness_factor: float = 1.0, + alpha_texture: Optional[Texture] = None, + alpha_factor: float = 1.0, + alpha_mode: AlphaMode = AlphaMode.OPAQUE, + alpha_cutoff: float = 0.5, + ): + self.base_color_texture = base_color_texture + self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3] + self.metallic_texture = metallic_texture + self.metallic_factor = metallic_factor + self.roughness_texture = roughness_texture + self.roughness_factor = roughness_factor + self.alpha_texture = alpha_texture + self.alpha_factor = alpha_factor + self.alpha_mode = alpha_mode + self.alpha_cutoff = alpha_cutoff + + def to(self, device, non_blocking=False): + return PbrMaterial( + base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None, + base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking), + metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None, + metallic_factor=self.metallic_factor, + roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None, + roughness_factor=self.roughness_factor, + alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None, + alpha_factor=self.alpha_factor, + alpha_mode=self.alpha_mode, + alpha_cutoff=self.alpha_cutoff, + ) + + +class MeshWithPbrMaterial(Mesh): + def __init__(self, + vertices, + faces, + material_ids, + uv_coords, 
+ materials: List[PbrMaterial], + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.material_ids = material_ids # [M] + self.uv_coords = uv_coords # [M, 3, 2] + self.materials = materials + self.layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + + def to(self, device, non_blocking=False): + return MeshWithPbrMaterial( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.material_ids.to(device, non_blocking=non_blocking), + self.uv_coords.to(device, non_blocking=non_blocking), + [material.to(device, non_blocking=non_blocking) for material in self.materials], + ) + + +class MeshWithVoxel(Mesh, Voxel): + def __init__(self, + vertices: torch.Tensor, + faces: torch.Tensor, + origin: list, + voxel_size: float, + coords: torch.Tensor, + attrs: torch.Tensor, + voxel_shape: torch.Size, + layout: Dict = {}, + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + def to(self, device, non_blocking=False): + return MeshWithVoxel( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.origin.tolist(), + self.voxel_size, + self.coords.to(device, non_blocking=non_blocking), + self.attrs.to(device, non_blocking=non_blocking), + self.voxel_shape, + self.layout, + ) + + def query_attrs(self, xyz): + grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3) + vertex_attrs = grid_sample_3d( + self.attrs, + torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1), + self.voxel_shape, + grid, + mode='trilinear' + )[0] + return vertex_attrs + + def query_vertex_attrs(self): + return self.query_attrs(self.vertices) diff --git a/trellis2/representations/voxel/__init__.py b/trellis2/representations/voxel/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b5792ea14b2371a96c4130371eb976f0aff4b5dd --- /dev/null +++ b/trellis2/representations/voxel/__init__.py @@ -0,0 +1 @@ +from .voxel_model import Voxel \ No newline at end of file diff --git a/trellis2/representations/voxel/voxel_model.py b/trellis2/representations/voxel/voxel_model.py new file mode 100755 index 0000000000000000000000000000000000000000..9317ab22db61c5da96514ca46a780378392f3cc8 --- /dev/null +++ b/trellis2/representations/voxel/voxel_model.py @@ -0,0 +1,54 @@ +from typing import Dict +import torch + + +class Voxel: + def __init__( + self, + origin: list, + voxel_size: float, + coords: torch.Tensor = None, + attrs: torch.Tensor = None, + layout: Dict = {}, + device: torch.device = 'cuda' + ): + self.origin = torch.tensor(origin, dtype=torch.float32, device=device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.layout = layout + self.device = device + + @property + def position(self): + return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] + + def split_attrs(self): + return { + k: self.attrs[:, self.layout[k]] + for k in self.layout + } + + def save(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + o_voxel.io.write( + path, + self.coords, + self.split_attrs(), + ) + + def load(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + coord, attrs = 
o_voxel.io.read(path) + self.coords = coord.int().to(self.device) + self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device) + # build layout + start = 0 + self.layout = {} + for k in attrs: + self.layout[k] = slice(start, start + attrs[k].shape[1]) + start += attrs[k].shape[1] diff --git a/trellis2/utils/__init__.py b/trellis2/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..805b6cc118106857e7ef767ab4bfd133dbd78e6f --- /dev/null +++ b/trellis2/utils/data_utils.py @@ -0,0 +1,226 @@ +from typing import * +import math +import torch +import numpy as np +from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler +import torch.distributed as dist + + +def recursive_to_device( + data: Any, + device: torch.device, + non_blocking: bool = False, +) -> Any: + """ + Recursively move all tensors in a data structure to a device. + """ + if hasattr(data, "to"): + return data.to(device, non_blocking=non_blocking) + elif isinstance(data, (list, tuple)): + return type(data)(recursive_to_device(d, device, non_blocking) for d in data) + elif isinstance(data, dict): + return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()} + else: + return data + + +def load_balanced_group_indices( + load: List[int], + num_groups: int, + equal_size: bool = False, +) -> List[List[int]]: + """ + Split indices into groups with balanced load. + """ + if equal_size: + group_size = len(load) // num_groups + indices = np.argsort(load)[::-1] + groups = [[] for _ in range(num_groups)] + group_load = np.zeros(num_groups) + for idx in indices: + min_group_idx = np.argmin(group_load) + groups[min_group_idx].append(idx) + if equal_size and len(groups[min_group_idx]) == group_size: + group_load[min_group_idx] = float('inf') + else: + group_load[min_group_idx] += load[idx] + return groups + + +def cycle(data_loader: DataLoader) -> Iterator: + while True: + for data in data_loader: + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined] + yield data + if isinstance(data_loader.sampler, DistributedSampler): + data_loader.sampler.epoch += 1 + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.epoch += 1 + data_loader.sampler.idx = 0 + + +class ResumableSampler(Sampler): + """ + Distributed sampler that is resumable. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. 
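The load_balanced_group_indices helper above assigns indices greedily, always giving the next-heaviest item to the currently lightest group; a worked trace with made-up loads:

loads = [10, 7, 5, 3, 2, 1]
groups = load_balanced_group_indices(loads, num_groups=2)
# Greedy order (descending load): idx 0 -> group 0, idx 1 -> group 1, idx 2 -> group 1,
# idx 3 -> group 0, idx 4 -> group 1, idx 5 -> group 0,
# i.e. groups [[0, 3, 5], [1, 2, 4]] with total loads 14 and 14.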
+ """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.epoch = 0 + self.idx = 0 + self.drop_last = drop_last + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.rank = dist.get_rank() if dist.is_initialized() else 0 + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type] + self.total_size = self.num_samples * self.world_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + + # resume from previous state + indices = indices[self.idx:] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def state_dict(self) -> dict[str, int]: + return { + 'epoch': self.epoch, + 'idx': self.idx, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict['epoch'] + self.idx = state_dict['idx'] + + +class BalancedResumableSampler(ResumableSampler): + """ + Distributed sampler that is resumable and balances the load among the processes. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. 
+ """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 1, + ) -> None: + assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler' + super().__init__(dataset, shuffle, seed, drop_last) + self.batch_size = batch_size + self.loads = dataset.loads + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # balance load among processes + num_batches = len(indices) // (self.batch_size * self.world_size) + balanced_indices = [] + for i in range(num_batches): + start_idx = i * self.batch_size * self.world_size + end_idx = (i + 1) * self.batch_size * self.world_size + batch_indices = indices[start_idx:end_idx] + batch_loads = [self.loads[idx] for idx in batch_indices] + groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True) + balanced_indices.extend([batch_indices[j] for j in groups[self.rank]]) + + # resume from previous state + indices = balanced_indices[self.idx:] + + return iter(indices) diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..348799c064804ed3ae1a98144249bd9b6bfb9915 --- /dev/null +++ b/trellis2/utils/dist_utils.py @@ -0,0 +1,93 @@ +import os +import io +from contextlib import contextmanager +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +def setup_dist(rank, local_rank, world_size, master_addr, master_port): + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(local_rank) + torch.cuda.set_device(local_rank) + dist.init_process_group('nccl', rank=rank, world_size=world_size) + + +def read_file_dist(path): + """ + Read the binary file distributedly. + File is only read once by the rank 0 process and broadcasted to other processes. + + Returns: + data (io.BytesIO): The binary data read from the file. 
+ """ + if dist.is_initialized() and dist.get_world_size() > 1: + # read file + size = torch.LongTensor(1).cuda() + if dist.get_rank() == 0: + with open(path, 'rb') as f: + data = f.read() + data = torch.ByteTensor( + torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) + ).cuda() + size[0] = data.shape[0] + # broadcast size + dist.broadcast(size, src=0) + if dist.get_rank() != 0: + data = torch.ByteTensor(size[0].item()).cuda() + # broadcast data + dist.broadcast(data, src=0) + # convert to io.BytesIO + data = data.cpu().numpy().tobytes() + data = io.BytesIO(data) + return data + else: + with open(path, 'rb') as f: + data = f.read() + data = io.BytesIO(data) + return data + + +def unwrap_dist(model): + """ + Unwrap the model from distributed training. + """ + if isinstance(model, DDP): + return model.module + return model + + +@contextmanager +def master_first(): + """ + A context manager that ensures master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def local_master_first(): + """ + A context manager that ensures local master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() % torch.cuda.device_count() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + \ No newline at end of file diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..cba3cf83836e5b58f5bc3333e809ffc932375a04 --- /dev/null +++ b/trellis2/utils/elastic_utils.py @@ -0,0 +1,228 @@ +from abc import abstractmethod +from contextlib import contextmanager +from typing import Tuple +import torch +import torch.nn as nn +import numpy as np + + +class MemoryController: + """ + Base class for memory management during training. + """ + + _last_input_size = None + _last_mem_ratio = [] + + @contextmanager + def record(self): + pass + + def update_run_states(self, input_size=None, mem_ratio=None): + if self._last_input_size is None: + self._last_input_size = input_size + elif self._last_input_size!= input_size: + raise ValueError(f'Input size should not change for different ElasticModules.') + self._last_mem_ratio.append(mem_ratio) + + @abstractmethod + def get_mem_ratio(self, input_size): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def log(self): + pass + + +class LinearMemoryController(MemoryController): + """ + A simple controller for memory management during training. 
+ The memory usage is modeled as a linear function of: + - the number of input parameters + - the ratio of memory the model use compared to the maximum usage (with no checkpointing) + memory_usage = k * input_size * mem_ratio + b + The controller keeps track of the memory usage and gives the + expected memory ratio to keep the memory usage under a target + """ + def __init__( + self, + buffer_size=1000, + update_every=500, + target_ratio=0.8, + available_memory=None, + max_mem_ratio_start=0.1, + params=None, + device=None + ): + self.buffer_size = buffer_size + self.update_every = update_every + self.target_ratio = target_ratio + self.device = device or torch.cuda.current_device() + self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3 + + self._memory = np.zeros(buffer_size, dtype=np.float32) + self._input_size = np.zeros(buffer_size, dtype=np.float32) + self._mem_ratio = np.zeros(buffer_size, dtype=np.float32) + self._buffer_ptr = 0 + self._buffer_length = 0 + self._params = tuple(params) if params is not None else (0.0, 0.0) + self._max_mem_ratio = max_mem_ratio_start + self.step = 0 + + def __repr__(self): + return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})' + + def _add_sample(self, memory, input_size, mem_ratio): + self._memory[self._buffer_ptr] = memory + self._input_size[self._buffer_ptr] = input_size + self._mem_ratio[self._buffer_ptr] = mem_ratio + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + + @contextmanager + def record(self): + torch.cuda.reset_peak_memory_stats(self.device) + self._last_input_size = None + self._last_mem_ratio = [] + yield + self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3 + self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio) + self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio) + self.step += 1 + if self.step % self.update_every == 0: + self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1) + self._fit_params() + + def _fit_params(self): + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + + x = input_size * mem_ratio + y = memory_usage + k, b = np.polyfit(x, y, 1) + self._params = (k, b) + # self._visualize() + + def _visualize(self): + import matplotlib.pyplot as plt + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + k, b = self._params + + plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis') + x = np.array([0.0, 20000.0]) + plt.plot(x, k * x + b, c='r') + plt.savefig(f'linear_memory_controller_{self.step}.png') + plt.cla() + + def get_mem_ratio(self, input_size): + k, b = self._params + if k == 0: return np.random.rand() * self._max_mem_ratio + pred = (self.available_memory * self.target_ratio - b) / (k * input_size) + return min(self._max_mem_ratio, max(0.0, pred)) + + def state_dict(self): + return { + 'params': self._params, + } + + def load_state_dict(self, state_dict): + self._params = tuple(state_dict['params']) + + def log(self): + return { + 'params/k': self._params[0], + 'params/b': self._params[1], + 'memory': self._last_memory, + 'input_size': self._last_input_size, + 'mem_ratio': self._last_mem_ratio, + } + + +class 
ElasticModule(nn.Module): + """ + Module for training with elastic memory management. + """ + def __init__(self): + super().__init__() + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]: + """ + Forward with a given memory ratio. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + _, ret = self._forward_with_mem_ratio(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs) + self._memory_controller.update_run_states(input_size, mem_ratio) + return ret + + +class ElasticModuleMixin: + """ + Mixin for training with elastic memory management. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0) -> float: + """ + Context manager for training with a reduced memory ratio compared to the full memory usage. + + Returns: + float: The exact memory ratio used during the forward pass. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + ret = super().forward(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + with self.with_mem_ratio(mem_ratio) as exact_mem_ratio: + ret = super().forward(*args, **kwargs) + self._memory_controller.update_run_states(input_size, exact_mem_ratio) + return ret diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..589c103de8a777aea9994a899f97431cbab5447a --- /dev/null +++ b/trellis2/utils/general_utils.py @@ -0,0 +1,373 @@ +import re +import numpy as np +import cv2 +import torch +import contextlib + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. 
Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. 
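# Illustrative sketch (hypothetical values): how the dictionary helpers above
# compose. dict_merge refuses to silently overwrite duplicate leaf keys, and
# dict_reduce reduces matching leaves across a list of dictionaries.
import numpy as np
from trellis2.utils.general_utils import dict_merge, dict_reduce, dict_any

base_cfg = {'optimizer': {'lr': 1e-4}, 'steps': 1000}
extra_cfg = {'optimizer': {'weight_decay': 0.01}}
cfg = dict_merge(base_cfg, extra_cfg)
# -> {'optimizer': {'lr': 0.0001, 'weight_decay': 0.01}, 'steps': 1000}

step_logs = [
    {'loss': {'l1': 0.50, 'ssim': 0.20}},
    {'loss': {'l1': 0.40, 'ssim': 0.30}},
]
mean_logs = dict_reduce(step_logs, np.mean)
# -> {'loss': {'l1': 0.45, 'ssim': 0.25}}

has_nan = dict_any(mean_logs, lambda v: np.isnan(v))  # False for the values above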
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +# Context utils +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx()) + yield + + +# Image utils +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + if images[0].ndim == 2: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) + else: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + + +def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"): + """ + Draw text on an image of the given resolution. The text is automatically wrapped + and scaled so that it fits completely within the image while preserving any explicit + line breaks and original spacing. Horizontal and vertical alignment can be controlled + via flags. + + Parameters: + text (str): The input text. Newline characters and spacing are preserved. + resolution (tuple): The image resolution as (width, height). + max_size (float): The maximum font size. + h_align (str): Horizontal alignment. Options: "left", "center", "right". + v_align (str): Vertical alignment. Options: "top", "center", "bottom". + + Returns: + numpy.ndarray: The resulting image (BGR format) with the text drawn. + """ + width, height = resolution + # Create a white background image + img = np.full((height, width, 3), 255, dtype=np.uint8) + + # Set margins and compute available drawing area + margin = 10 + avail_width = width - 2 * margin + avail_height = height - 2 * margin + + # Choose OpenCV font and text thickness + font = cv2.FONT_HERSHEY_SIMPLEX + thickness = 1 + # Ratio for additional spacing between lines (relative to the height of "A") + line_spacing_ratio = 0.5 + + def wrap_line(line, max_width, font, thickness, scale): + """ + Wrap a single line of text into multiple lines such that each line's + width (measured at the given scale) does not exceed max_width. + This function preserves the original spacing by splitting the line into tokens + (words and whitespace) using a regular expression. 
+ + Parameters: + line (str): The input text line. + max_width (int): Maximum allowed width in pixels. + font (int): OpenCV font identifier. + thickness (int): Text thickness. + scale (float): The current font scale. + + Returns: + List[str]: A list of wrapped lines. + """ + # Split the line into tokens (words and whitespace), preserving spacing + tokens = re.split(r'(\s+)', line) + if not tokens: + return [''] + + wrapped_lines = [] + current_line = "" + for token in tokens: + candidate = current_line + token + candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0] + if candidate_width <= max_width: + current_line = candidate + else: + # If current_line is empty, the token itself is too wide; + # break the token character by character. + if current_line == "": + sub_token = "" + for char in token: + candidate_char = sub_token + char + if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width: + sub_token = candidate_char + else: + if sub_token: + wrapped_lines.append(sub_token) + sub_token = char + current_line = sub_token + else: + wrapped_lines.append(current_line) + current_line = token + if current_line: + wrapped_lines.append(current_line) + return wrapped_lines + + def compute_text_block(scale): + """ + Wrap the entire text (splitting at explicit newline characters) using the + provided scale, and then compute the overall width and height of the text block. + + Returns: + wrapped_lines (List[str]): The list of wrapped lines. + block_width (int): Maximum width among the wrapped lines. + block_height (int): Total height of the text block including spacing. + sizes (List[tuple]): A list of (width, height) for each wrapped line. + spacing (int): The spacing between lines (computed from the scaled "A" height). 
+ """ + # Split text by explicit newlines + input_lines = text.splitlines() if text else [''] + wrapped_lines = [] + for line in input_lines: + wrapped = wrap_line(line, avail_width, font, thickness, scale) + wrapped_lines.extend(wrapped) + + sizes = [] + for line in wrapped_lines: + (text_size, _) = cv2.getTextSize(line, font, scale, thickness) + sizes.append(text_size) # (width, height) + + block_width = max((w for w, h in sizes), default=0) + # Use the height of "A" (at the current scale) to compute line spacing + base_height = cv2.getTextSize("A", font, scale, thickness)[0][1] + spacing = int(line_spacing_ratio * base_height) + block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0 + + return wrapped_lines, block_width, block_height, sizes, spacing + + # Use binary search to find the maximum scale that allows the text block to fit + lo = 0.001 + hi = max_size + eps = 0.001 # convergence threshold + best_scale = lo + best_result = None + + while hi - lo > eps: + mid = (lo + hi) / 2 + wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid) + # Ensure that both width and height constraints are met + if block_width <= avail_width and block_height <= avail_height: + best_scale = mid + best_result = (wrapped_lines, block_width, block_height, sizes, spacing) + lo = mid # try a larger scale + else: + hi = mid # reduce the scale + + if best_result is None: + best_scale = 0.5 + best_result = compute_text_block(best_scale) + + wrapped_lines, block_width, block_height, sizes, spacing = best_result + + # Compute starting y-coordinate based on vertical alignment flag + if v_align == "top": + y_top = margin + elif v_align == "center": + y_top = margin + (avail_height - block_height) // 2 + elif v_align == "bottom": + y_top = margin + (avail_height - block_height) + else: + y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag + + # For cv2.putText, the y coordinate represents the text baseline; + # so for the first line add its height. + y = y_top + (sizes[0][1] if sizes else 0) + + # Draw each line with horizontal alignment based on the flag + for i, line in enumerate(wrapped_lines): + line_width, line_height = sizes[i] + if h_align == "left": + x = margin + elif h_align == "center": + x = margin + (avail_width - line_width) // 2 + elif h_align == "right": + x = margin + (avail_width - line_width) + else: + x = margin # default to left if invalid flag + + cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA) + y += line_height + spacing + + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. 
+ """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis2/utils/grad_clip_utils.py b/trellis2/utils/grad_clip_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..990a4352e24fc73bf732d8eb0f8ca9a07365b49e --- /dev/null +++ b/trellis2/utils/grad_clip_utils.py @@ -0,0 +1,81 @@ +from typing import * +import torch +import numpy as np +import torch.utils + + +class AdaptiveGradClipper: + """ + Adaptive gradient clipping for training. + """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + + def __repr__(self): + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + + def log(self): + return { + 'max_norm': self._max_norm, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). 
+ """ + max_norm = self._max_norm if self._max_norm is not None else float('inf') + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) + + if torch.isfinite(grad_norm): + self._grad_norm[self._buffer_ptr] = grad_norm + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + if self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + + return grad_norm \ No newline at end of file diff --git a/trellis2/utils/loss_utils.py b/trellis2/utils/loss_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..52049f69543f2700bc5525b09cbf2fb25c08aa9e --- /dev/null +++ b/trellis2/utils/loss_utils.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpips import LPIPS + + +def smooth_l1_loss(pred, target, beta=1.0): + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta) + return loss.mean() + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def psnr(img1, img2, max_val=1.0): + mse = F.mse_loss(img1, img2) + return 20 * torch.log10(max_val / torch.sqrt(mse)) + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +loss_fn_vgg = None +def lpips(img1, img2, value_range=(0, 1)): + global loss_fn_vgg + if loss_fn_vgg is None: + loss_fn_vgg = LPIPS(net='vgg').cuda().eval() + # normalize to [-1, 1] + img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + return loss_fn_vgg(img1, img2).mean() + + 
+def normal_angle(pred, gt): + pred = pred * 2.0 - 1.0 + gt = gt * 2.0 - 1.0 + norms = pred.norm(dim=-1) * gt.norm(dim=-1) + cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() + if ang.isnan(): + return -1 + return ang diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..a9f1451ebd8b89879eee79cc61a6f4161136f245 --- /dev/null +++ b/trellis2/utils/mesh_utils.py @@ -0,0 +1,268 @@ +from typing import Tuple, Dict +import numpy as np +from trimesh import grouping, util, remesh +import struct +import re +from plyfile import PlyData, PlyElement + + +def read_ply(filename): + """ + Read a PLY file and return vertices, triangle faces, and quad faces. + + Args: + filename (str): The file path to read from. + + Returns: + vertices (np.ndarray): Array of shape [N, 3] containing vertex positions. + tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none). + quads (np.ndarray): Array of shape [K, 4] containing quad face indices (empty if none). + """ + with open(filename, 'rb') as f: + # Read the header until 'end_header' is encountered + header_bytes = b"" + while True: + line = f.readline() + if not line: + raise ValueError("PLY header not found") + header_bytes += line + if b"end_header" in line: + break + header = header_bytes.decode('utf-8') + + # Determine if the file is in ASCII or binary format + is_ascii = "ascii" in header + + # Extract the number of vertices and faces from the header using regex + vertex_match = re.search(r'element vertex (\d+)', header) + if vertex_match: + num_vertices = int(vertex_match.group(1)) + else: + raise ValueError("Vertex count not found in header") + + face_match = re.search(r'element face (\d+)', header) + if face_match: + num_faces = int(face_match.group(1)) + else: + raise ValueError("Face count not found in header") + + vertices = [] + tris = [] + quads = [] + + if is_ascii: + # For ASCII format, read each line of vertex data (each line contains 3 floats) + for _ in range(num_vertices): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + vertices.append([float(parts[0]), float(parts[1]), float(parts[2])]) + + # Read face data, where the first number indicates the number of vertices for the face + for _ in range(num_faces): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + count = int(parts[0]) + indices = list(map(int, parts[1:])) + if count == 3: + tris.append(indices) + elif count == 4: + quads.append(indices) + else: + # Skip faces with other numbers of vertices (can be extended as needed) + pass + else: + # For binary format: read directly from the binary stream + # Each vertex consists of 3 floats (12 bytes per vertex) + for _ in range(num_vertices): + data = f.read(12) + if len(data) < 12: + raise ValueError("Insufficient vertex data") + v = struct.unpack(' 0 else np.empty((0, 3), dtype=np.int32) + quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32) + + return vertices, tris, quads + + +def write_ply( + filename: str, + vertices: np.ndarray, + tris: np.ndarray, + quads: np.ndarray, + vertex_colors: np.ndarray = None, + ascii: bool = False +): + """ + Write a mesh to a PLY file, with the option to save in ASCII or binary format, + and optional per-vertex colors. 
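# Illustrative sketch (hypothetical file name): round-tripping a single quad
# through write_ply / read_ply in ASCII mode. Faces may be triangles, quads,
# or a mix, and either face array may be empty.
import numpy as np
from trellis2.utils.mesh_utils import read_ply, write_ply

vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
tris = np.empty((0, 3), dtype=np.int32)
quads = np.array([[0, 1, 2, 3]], dtype=np.int32)

write_ply('quad.ply', vertices, tris, quads, ascii=True)
verts2, tris2, quads2 = read_ply('quad.ply')   # shapes: (4, 3), (0, 3), (1, 4)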
+ + Args: + filename (str): The filename to write to. + vertices (np.ndarray): [N, 3] The vertex positions. + tris (np.ndarray): [M, 3] The triangle indices. + quads (np.ndarray): [K, 4] The quad indices. + vertex_colors (np.ndarray, optional): [N, 3] or [N, 4] UInt8 colors for each vertex (RGB or RGBA). + ascii (bool): If True, write in ASCII format; otherwise binary little-endian. + """ + import struct + + num_vertices = len(vertices) + num_faces = len(tris) + len(quads) + + # Build header + header_lines = [ + "ply", + f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}", + f"element vertex {num_vertices}", + "property float x", + "property float y", + "property float z", + ] + + # Add vertex color properties if provided + has_color = vertex_colors is not None + if has_color: + # Expect uint8 values 0-255 + header_lines += [ + "property uchar red", + "property uchar green", + "property uchar blue", + ] + # Include alpha if RGBA + if vertex_colors.shape[1] == 4: + header_lines.append("property uchar alpha") + + header_lines += [ + f"element face {num_faces}", + "property list uchar int vertex_index", + "end_header", + "" + ] + header = "\n".join(header_lines) + + mode = 'w' if ascii else 'wb' + with open(filename, mode) as f: + # Write header + if ascii: + f.write(header) + else: + f.write(header.encode('utf-8')) + + # Write vertex data + for i, v in enumerate(vertices): + if ascii: + line = f"{v[0]} {v[1]} {v[2]}" + if has_color: + col = vertex_colors[i] + line += ' ' + ' '.join(str(int(c)) for c in col) + f.write(line + '\n') + else: + # pack xyz as floats + f.write(struct.pack(' 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis2/utils/render_utils.py b/trellis2/utils/render_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..28b0dd26d4faafd06c9da67699270be4573bd5cc --- /dev/null +++ b/trellis2/utils/render_utils.py @@ -0,0 +1,129 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import MeshRenderer, VoxelRenderer, PbrMeshRenderer +from ..representations import Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 
0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def get_renderer(sample, **kwargs): + if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)): + renderer = PbrMeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8) + elif isinstance(sample, Mesh): + renderer = MeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None) + elif isinstance(sample, Voxel): + renderer = VoxelRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 0.1) + renderer.rendering_options.far = kwargs.get('far', 10.0) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + return renderer + + +def render_frames(sample, extrinsics, intrinsics, options={}, verbose=True, **kwargs): + renderer = get_renderer(sample, **options) + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), total=len(extrinsics), desc='Rendering', disable=not verbose): + res = renderer.render(sample, extr, intr, **kwargs) + for k, v in res.items(): + if k not in rets: rets[k] = [] + if v.dim() == 2: v = v[None].repeat(3, 1, 1) + rets[k].append(np.clip(v.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=120, r=2, fov=40, **kwargs): + yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + np.pi/2 + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(4)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def make_pbr_vis_frames(result, 
resolution=1024): + num_frames = len(result['shaded']) + frames = [] + for i in range(num_frames): + shaded = Image.fromarray(result['shaded'][i]) + normal = Image.fromarray(result['normal'][i]) + base_color = Image.fromarray(result['base_color'][i]) + metallic = Image.fromarray(result['metallic'][i]) + roughness = Image.fromarray(result['roughness'][i]) + alpha = Image.fromarray(result['alpha'][i]) + shaded = shaded.resize((resolution, resolution)) + normal = normal.resize((resolution, resolution)) + base_color = base_color.resize((resolution//2, resolution//2)) + metallic = metallic.resize((resolution//2, resolution//2)) + roughness = roughness.resize((resolution//2, resolution//2)) + alpha = alpha.resize((resolution//2, resolution//2)) + row1 = np.concatenate([shaded, normal], axis=1) + row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1) + frame = np.concatenate([row1, row2], axis=0) + frames.append(frame) + return frames diff --git a/trellis2/utils/vis_utils.py b/trellis2/utils/vis_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..0e5f58e564aea50e4d80b0265220ee3fb382cd69 --- /dev/null +++ b/trellis2/utils/vis_utils.py @@ -0,0 +1,44 @@ +from typing import * +import numpy as np +import torch +from ..modules import sparse as sp +from ..representations import Voxel +from .render_utils import render_video + + +def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor: + """ + Apply PCA to the features and return the first three principal components. + """ + feats = feats.detach() + u, s, v = torch.svd(feats) + color = u[:, channels] + color = (color - color.min(dim=0, keepdim=True)[0]) / (color.max(dim=0, keepdim=True)[0] - color.min(dim=0, keepdim=True)[0]) + return color + + +def vis_sparse_tensor( + x: sp.SparseTensor, + num_frames: int = 300, +): + assert x.shape[0] == 1, "Only support batch size 1" + assert x.coords.shape[1] == 4, "Only support 3D coordinates" + + coords = x.coords.cuda().detach()[:, 1:] + feats = x.feats.cuda().detach() + color = pca_color(feats) + + resolution = max(list(x.spatial_shape)) + resolution = int(2**np.ceil(np.log2(resolution))) + + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + + return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']
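# Illustrative sketch (hypothetical helper): turning any supported representation
# (Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel) into a turntable video.
# `sample` is assumed to come from the TRELLIS.2 pipeline elsewhere in this repo;
# imageio is an extra dependency assumed only for this sketch, and a CUDA device
# is required by the renderers.
import imageio
from trellis2.utils import render_utils


def save_turntable(sample, path: str = 'turntable.mp4', fps: int = 30):
    # render_video sweeps yaw over a full circle with a gently oscillating pitch
    # and returns a dict of per-frame uint8 images keyed by render output
    # (e.g. 'color', as used by render_multiview above).
    frames = render_utils.render_video(sample, resolution=512, num_frames=120)
    imageio.mimsave(path, frames['color'], fps=fps)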