JeffreyXiang committed on
Commit 917a889 · 1 Parent(s): cc2d3ad
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +207 -0
  2. README.md +4 -3
  3. app.py +335 -0
  4. requirements.txt +20 -0
  5. trellis2/__init__.py +6 -0
  6. trellis2/models/__init__.py +78 -0
  7. trellis2/models/sc_vaes/fdg_vae.py +110 -0
  8. trellis2/models/sc_vaes/sparse_unet_vae.py +522 -0
  9. trellis2/models/sparse_elastic_mixin.py +24 -0
  10. trellis2/models/sparse_structure_flow.py +248 -0
  11. trellis2/models/sparse_structure_vae.py +306 -0
  12. trellis2/models/structured_latent_flow.py +208 -0
  13. trellis2/modules/attention/__init__.py +3 -0
  14. trellis2/modules/attention/config.py +32 -0
  15. trellis2/modules/attention/full_attn.py +144 -0
  16. trellis2/modules/attention/modules.py +102 -0
  17. trellis2/modules/attention/rope.py +48 -0
  18. trellis2/modules/norm.py +32 -0
  19. trellis2/modules/sparse/__init__.py +69 -0
  20. trellis2/modules/sparse/attention/__init__.py +3 -0
  21. trellis2/modules/sparse/attention/full_attn.py +214 -0
  22. trellis2/modules/sparse/attention/modules.py +141 -0
  23. trellis2/modules/sparse/attention/rope.py +58 -0
  24. trellis2/modules/sparse/attention/windowed_attn.py +190 -0
  25. trellis2/modules/sparse/basic.py +836 -0
  26. trellis2/modules/sparse/config.py +43 -0
  27. trellis2/modules/sparse/conv/__init__.py +2 -0
  28. trellis2/modules/sparse/conv/config.py +3 -0
  29. trellis2/modules/sparse/conv/conv.py +30 -0
  30. trellis2/modules/sparse/conv/conv_flex_gemm.py +68 -0
  31. trellis2/modules/sparse/conv/conv_spconv.py +73 -0
  32. trellis2/modules/sparse/conv/conv_torchsparse.py +30 -0
  33. trellis2/modules/sparse/linear.py +15 -0
  34. trellis2/modules/sparse/nonlinearity.py +35 -0
  35. trellis2/modules/sparse/norm.py +64 -0
  36. trellis2/modules/sparse/spatial/__init__.py +2 -0
  37. trellis2/modules/sparse/spatial/basic.py +109 -0
  38. trellis2/modules/sparse/spatial/spatial2channel.py +93 -0
  39. trellis2/modules/sparse/transformer/__init__.py +2 -0
  40. trellis2/modules/sparse/transformer/blocks.py +145 -0
  41. trellis2/modules/sparse/transformer/modulated.py +166 -0
  42. trellis2/modules/spatial.py +48 -0
  43. trellis2/modules/transformer/__init__.py +2 -0
  44. trellis2/modules/transformer/blocks.py +186 -0
  45. trellis2/modules/transformer/modulated.py +165 -0
  46. trellis2/modules/utils.py +74 -0
  47. trellis2/pipelines/__init__.py +55 -0
  48. trellis2/pipelines/base.py +70 -0
  49. trellis2/pipelines/rembg/BiRefNet.py +42 -0
  50. trellis2/pipelines/rembg/__init__.py +1 -0
.gitignore ADDED
@@ -0,0 +1,207 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
  title: TRELLIS.2
- emoji: 📚
- colorFrom: yellow
- colorTo: pink
+ emoji: 🏢
+ colorFrom: indigo
+ colorTo: blue
  sdk: gradio
  sdk_version: 6.1.0
  app_file: app.py
  pinned: false
  license: mit
+ short_description: High-fidelity 3D Generation from images
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,335 @@
+ import gradio as gr
+ import spaces
+
+ import os
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+ from datetime import datetime
+ import shutil
+ import cv2
+ from typing import *
+ import torch
+ import numpy as np
+ from PIL import Image
+ from trellis2.modules.sparse import SparseTensor
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
+ from trellis2.renderers import EnvMap
+ from trellis2.utils import render_utils
+ import o_voxel
+
+
+ MAX_SEED = np.iinfo(np.int32).max
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+ os.makedirs(TMP_DIR, exist_ok=True)
+
+
+ def start_session(req: gr.Request):
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     os.makedirs(user_dir, exist_ok=True)
+
+
+ def end_session(req: gr.Request):
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     shutil.rmtree(user_dir)
+
+
+ def preprocess_image(image: Image.Image) -> Image.Image:
+     """
+     Preprocess the input image.
+
+     Args:
+         image (Image.Image): The input image.
+
+     Returns:
+         Image.Image: The preprocessed image.
+     """
+     processed_image = pipeline.preprocess_image(image)
+     return processed_image
+
+
+ def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
+     shape_slat, tex_slat, res = latents
+     return {
+         'shape_slat_feats': shape_slat.feats.cpu().numpy(),
+         'tex_slat_feats': tex_slat.feats.cpu().numpy(),
+         'coords': shape_slat.coords.cpu().numpy(),
+         'res': res,
+     }
+
+
+ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
+     shape_slat = SparseTensor(
+         feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
+         coords=torch.from_numpy(state['coords']).cuda(),
+     )
+     tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
+     return shape_slat, tex_slat, state['res']
+
+
+ def get_seed(randomize_seed: bool, seed: int) -> int:
+     """
+     Get the random seed.
+     """
+     return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+ @spaces.GPU(duration=120)
+ def image_to_3d(
+     image: Image.Image,
+     seed: int,
+     resolution: str,
+     ss_guidance_strength: float,
+     ss_guidance_rescale: float,
+     ss_sampling_steps: int,
+     ss_rescale_t: float,
+     shape_slat_guidance_strength: float,
+     shape_slat_guidance_rescale: float,
+     shape_slat_sampling_steps: int,
+     shape_slat_rescale_t: float,
+     tex_slat_guidance_strength: float,
+     tex_slat_guidance_rescale: float,
+     tex_slat_sampling_steps: int,
+     tex_slat_rescale_t: float,
+     req: gr.Request,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> Tuple[dict, List[Image.Image]]:
+     """
+     Convert an image to a 3D model.
+
+     Args:
+         image (Image.Image): The input image.
+         seed (int): The random seed.
+         ss_guidance_strength (float): The guidance strength for sparse structure generation.
+         ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+         shape_slat_guidance_strength (float): The guidance strength for shape slat generation.
+         shape_slat_sampling_steps (int): The number of sampling steps for shape slat generation.
+         tex_slat_guidance_strength (float): The guidance strength for texture slat generation.
+         tex_slat_sampling_steps (int): The number of sampling steps for texture slat generation.
+
+     Returns:
+         dict: The packed latent state, used later for GLB extraction.
+         List[Image.Image]: Preview renderings of the generated 3D asset.
+     """
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     outputs, latents = pipeline.run(
+         image,
+         seed=seed,
+         preprocess_image=False,
+         sparse_structure_sampler_params={
+             "steps": ss_sampling_steps,
+             "guidance_strength": ss_guidance_strength,
+             "guidance_rescale": ss_guidance_rescale,
+             "rescale_t": ss_rescale_t,
+         },
+         shape_slat_sampler_params={
+             "steps": shape_slat_sampling_steps,
+             "guidance_strength": shape_slat_guidance_strength,
+             "guidance_rescale": shape_slat_guidance_rescale,
+             "rescale_t": shape_slat_rescale_t,
+         },
+         tex_slat_sampler_params={
+             "steps": tex_slat_sampling_steps,
+             "guidance_strength": tex_slat_guidance_strength,
+             "guidance_rescale": tex_slat_guidance_rescale,
+             "rescale_t": tex_slat_rescale_t,
+         },
+         pipeline_type={
+             "512": "512",
+             "1024": "512->1024",
+             "1536": "512->1536",
+         }[resolution],
+         return_latent=True,
+     )
+     images = render_utils.make_pbr_vis_frames(
+         render_utils.render_snapshot(outputs[0], resolution=1024, r=2, fov=36, envmap=envmap),
+         resolution=1024
+     )
+     state = pack_state(latents)
+     torch.cuda.empty_cache()
+     return state, [Image.fromarray(image) for image in images]
+
+
+ @spaces.GPU(duration=120)
+ def extract_glb(
+     state: dict,
+     decimation_target: int,
+     texture_size: int,
+     req: gr.Request,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> Tuple[str, str]:
+     """
+     Extract a GLB file from the 3D model.
+
+     Args:
+         state (dict): The state of the generated 3D model.
+         decimation_target (int): The target face count for decimation.
+         texture_size (int): The texture resolution.
+
+     Returns:
+         str: The path to the extracted GLB file (shown in the 3D viewer).
+         str: The same path, wired to the download button.
+     """
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     shape_slat, tex_slat, res = unpack_state(state)
+     mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
+     glb = o_voxel.postprocess.to_glb(
+         vertices=mesh.vertices,
+         faces=mesh.faces,
+         attr_volume=mesh.attrs,
+         coords=mesh.coords,
+         attr_layout=pipeline.pbr_attr_layout,
+         grid_size=res,
+         aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         decimation_target=decimation_target,
+         texture_size=texture_size,
+         use_tqdm=True,
+     )[0]
+     now = datetime.now()
+     timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
+     os.makedirs(user_dir, exist_ok=True)
+     glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
+     glb.export(glb_path)
+     torch.cuda.empty_cache()
+     return glb_path, glb_path
+
+
+ css = """
+ .stepper-wrapper {
+     padding: 0;
+ }
+
+ .stepper-container {
+     padding: 0;
+     align-items: center;
+ }
+
+ .step-button {
+     flex-direction: row;
+ }
+
+ .step-connector {
+     transform: none;
+ }
+
+ .step-number {
+     width: 16px;
+     height: 16px;
+ }
+
+ .step-label {
+     position: relative;
+     bottom: 0;
+ }
+ """
+
+
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
+     gr.Markdown("""
+     ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/trellis.2)
+     * Upload an image and click "Generate" to create a 3D asset.
+     * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1, min_width=360):
+             image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
+
+             resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="512")
+             seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+             randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+             decimation_target = gr.Slider(10000, 500000, label="Decimation Target", value=100000, step=10000)
+             texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
+
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 gr.Markdown("Stage 1: Sparse Structure Generation")
+                 with gr.Row():
+                     ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                     ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
+                     ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                     ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
+                 gr.Markdown("Stage 2: Shape Generation")
+                 with gr.Row():
+                     shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                     shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
+                     shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                     shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
+                 gr.Markdown("Stage 3: Material Generation")
+                 with gr.Row():
+                     tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
+                     tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
+                     tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                     tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
+
+             generate_btn = gr.Button("Generate")
+
+         with gr.Column(scale=10):
+             with gr.Walkthrough(selected=0) as walkthrough:
+                 with gr.Step("Preview", id=0):
+                     preview_output = gr.Gallery(label="3D Asset Preview", height=800, show_label=True, preview=True)
+                     extract_btn = gr.Button("Extract GLB")
+                 with gr.Step("Extract", id=1):
+                     glb_output = gr.Model3D(label="Extracted GLB", height=800, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
+                     download_btn = gr.DownloadButton(label="Download GLB")
+
+         with gr.Column(scale=1, min_width=172):
+             examples = gr.Examples(
+                 examples=[
+                     f'assets/example_image/{image}'
+                     for image in os.listdir("assets/example_image")
+                 ],
+                 inputs=[image_prompt],
+                 fn=preprocess_image,
+                 outputs=[image_prompt],
+                 run_on_click=True,
+                 examples_per_page=18,
+             )
+
+     output_buf = gr.State()
+
+
+     # Handlers
+     demo.load(start_session)
+     demo.unload(end_session)
+
+     image_prompt.upload(
+         preprocess_image,
+         inputs=[image_prompt],
+         outputs=[image_prompt],
+     )
+
+     generate_btn.click(
+         get_seed,
+         inputs=[randomize_seed, seed],
+         outputs=[seed],
+     ).then(
+         lambda: gr.Walkthrough(selected=0), outputs=walkthrough
+     ).then(
+         image_to_3d,
+         inputs=[
+             image_prompt, seed, resolution,
+             ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
+             shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
+             tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
+         ],
+         outputs=[output_buf, preview_output],
+     )
+
+     extract_btn.click(
+         lambda: gr.Walkthrough(selected=1), outputs=walkthrough
+     ).then(
+         extract_glb,
+         inputs=[output_buf, decimation_target, texture_size],
+         outputs=[glb_output, download_btn],
+     )
+
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     pipeline = Trellis2ImageTo3DPipeline.from_pretrained('JeffreyXiang/TRELLIS.2-4B')
+     pipeline.cuda()
+
+     envmap = EnvMap(torch.tensor(
+         cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+         dtype=torch.float32, device='cuda'
+     ))
+
+     demo.launch(css=css, mcp_server=True)
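Note: the following is a minimal headless sketch of the same generate-then-extract flow that app.py wires into Gradio, using only the calls visible above (pipeline.run, pipeline.decode_latent, o_voxel.postprocess.to_glb). It assumes the sampler-parameter dictionaries have usable defaults when omitted; the input and output file names are hypothetical placeholders, not part of the repository.

    from PIL import Image
    import o_voxel
    from trellis2.pipelines import Trellis2ImageTo3DPipeline

    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('JeffreyXiang/TRELLIS.2-4B')
    pipeline.cuda()

    image = pipeline.preprocess_image(Image.open('input.png'))   # hypothetical input image
    outputs, latents = pipeline.run(
        image, seed=0, preprocess_image=False,
        pipeline_type='512', return_latent=True,                 # assumes default sampler params
    )

    shape_slat, tex_slat, res = latents
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs, coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout, grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=100000, texture_size=2048, use_tqdm=True,
    )[0]
    glb.export('sample.glb')                                      # hypothetical output path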
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ --extra-index-url https://download.pytorch.org/whl/cu124
+
+ torch==2.6.0
+ torchvision==0.21.0
+ triton==3.2.0
+ pillow==12.0.0
+ imageio==2.37.2
+ imageio-ffmpeg==0.6.0
+ tqdm==4.67.1
+ easydict==1.13
+ opencv-python-headless==4.12.0.88
+ trimesh==4.10.1
+ transformers==4.46.3
+ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrast-0.3.5-cp310-cp310-linux_x86_64?download=true
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
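Note: the pinned wheels above carry cp310 tags and are pulled alongside the cu124 PyTorch index and torch 2.6.0, so this environment is effectively fixed to Python 3.10 with CUDA 12.4-compatible builds.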
trellis2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from . import models
+ from . import modules
+ from . import pipelines
+ from . import renderers
+ from . import representations
+ from . import utils
trellis2/models/__init__.py ADDED
@@ -0,0 +1,78 @@
+ import importlib
+
+ __attributes = {
+     # Sparse Structure
+     'SparseStructureEncoder': 'sparse_structure_vae',
+     'SparseStructureDecoder': 'sparse_structure_vae',
+     'SparseStructureFlowModel': 'sparse_structure_flow',
+
+     # SLat Generation
+     'SLatFlowModel': 'structured_latent_flow',
+     'ElasticSLatFlowModel': 'structured_latent_flow',
+
+     # SC-VAEs
+     'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae',
+     'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae',
+     'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae',
+     'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae'
+ }
+
+ __submodules = []
+
+ __all__ = list(__attributes.keys()) + __submodules
+
+ def __getattr__(name):
+     if name not in globals():
+         if name in __attributes:
+             module_name = __attributes[name]
+             module = importlib.import_module(f".{module_name}", __name__)
+             globals()[name] = getattr(module, name)
+         elif name in __submodules:
+             module = importlib.import_module(f".{name}", __name__)
+             globals()[name] = module
+         else:
+             raise AttributeError(f"module {__name__} has no attribute {name}")
+     return globals()[name]
+
+
+ def from_pretrained(path: str, **kwargs):
+     """
+     Load a model from a pretrained checkpoint.
+
+     Args:
+         path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
+             NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
+         **kwargs: Additional arguments for the model constructor.
+     """
+     import os
+     import json
+     from safetensors.torch import load_file
+     is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
+
+     if is_local:
+         config_file = f"{path}.json"
+         model_file = f"{path}.safetensors"
+     else:
+         from huggingface_hub import hf_hub_download
+         path_parts = path.split('/')
+         repo_id = f'{path_parts[0]}/{path_parts[1]}'
+         model_name = '/'.join(path_parts[2:])
+         config_file = hf_hub_download(repo_id, f"{model_name}.json")
+         model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
+
+     with open(config_file, 'r') as f:
+         config = json.load(f)
+     model = __getattr__(config['name'])(**config['args'], **kwargs)
+     model.load_state_dict(load_file(model_file), strict=False)
+
+     return model
+
+
+ # For Pylance
+ if __name__ == '__main__':
+     from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
+     from .sparse_structure_flow import SparseStructureFlowModel
+     from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel
+
+     from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder
+     from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder
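A hypothetical usage sketch of the from_pretrained convention documented above: '{path}.json' holds the config, '{path}.safetensors' holds the weights, and a Hub path is split into 'user/repo/<model_name>'. The checkpoint name 'ckpts/slat_flow' below is a made-up placeholder, not a file known to exist in the repository.

    from trellis2 import models

    # Local checkpoint: expects ./ckpts/slat_flow.json and ./ckpts/slat_flow.safetensors (placeholder name)
    model = models.from_pretrained('ckpts/slat_flow')

    # Hub checkpoint: repo 'JeffreyXiang/TRELLIS.2-4B', files 'ckpts/slat_flow.json' / '.safetensors' inside it
    model = models.from_pretrained('JeffreyXiang/TRELLIS.2-4B/ckpts/slat_flow')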
trellis2/models/sc_vaes/fdg_vae.py ADDED
@@ -0,0 +1,110 @@
+ from typing import *
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from ...modules import sparse as sp
+ from .sparse_unet_vae import (
+     SparseResBlock3d,
+     SparseConvNeXtBlock3d,
+
+     SparseResBlockDownsample3d,
+     SparseResBlockUpsample3d,
+     SparseResBlockS2C3d,
+     SparseResBlockC2S3d,
+ )
+ from .sparse_unet_vae import (
+     SparseUnetVaeEncoder,
+     SparseUnetVaeDecoder,
+ )
+ from ...representations import Mesh
+ from o_voxel.convert import flexible_dual_grid_to_mesh
+
+
+ class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
+     def __init__(
+         self,
+         model_channels: List[int],
+         latent_channels: int,
+         num_blocks: List[int],
+         block_type: List[str],
+         down_block_type: List[str],
+         block_args: List[Dict[str, Any]],
+         use_fp16: bool = False,
+     ):
+         super().__init__(
+             6,
+             model_channels,
+             latent_channels,
+             num_blocks,
+             block_type,
+             down_block_type,
+             block_args,
+             use_fp16,
+         )
+
+     def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
+         x = vertices.replace(torch.cat([
+             vertices.feats - 0.5,
+             intersected.feats.float() - 0.5,
+         ], dim=1))
+         return super().forward(x, sample_posterior, return_raw)
+
+
+ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
+     def __init__(
+         self,
+         resolution: int,
+         model_channels: List[int],
+         latent_channels: int,
+         num_blocks: List[int],
+         block_type: List[str],
+         up_block_type: List[str],
+         block_args: List[Dict[str, Any]],
+         voxel_margin: float = 0.5,
+         use_fp16: bool = False,
+     ):
+         self.resolution = resolution
+         self.voxel_margin = voxel_margin
+
+         super().__init__(
+             7,
+             model_channels,
+             latent_channels,
+             num_blocks,
+             block_type,
+             up_block_type,
+             block_args,
+             use_fp16,
+         )
+
+     def set_resolution(self, resolution: int) -> None:
+         self.resolution = resolution
+
+     def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs):
+         decoded = super().forward(x, **kwargs)
+         if self.training:
+             h, subs_gt, subs = decoded
+             vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
+             intersected_logits = h.replace(h.feats[..., 3:6])
+             quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
+             mesh = [Mesh(flexible_dual_grid_to_mesh(
+                 h.coords[:, 1:], v.feats, i.feats, q.feats,
+                 aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+                 grid_size=self.resolution,
+                 train=True
+             )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)]
+             return mesh, vertices, intersected_logits, subs_gt, subs
+         else:
+             out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
+             h = out_list[0]
+             vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
+             intersected = h.replace(h.feats[..., 3:6] > 0)
+             quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
+             mesh = [Mesh(*flexible_dual_grid_to_mesh(
+                 h.coords[:, 1:], v.feats, i.feats, q.feats,
+                 aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+                 grid_size=self.resolution,
+                 train=False
+             )) for v, i, q in zip(vertices, intersected, quad_lerp)]
+             out_list[0] = mesh
+             return out_list[0] if len(out_list) == 1 else tuple(out_list)
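In short, the encoder above consumes 6 feature channels per active voxel (the dual-grid vertex offsets and the three intersection flags, both re-centered by -0.5), while the decoder emits 7 channels: channels 0-2 become vertex offsets via a sigmoid stretched by voxel_margin, channels 3-5 are intersection logits (thresholded at 0 at inference), and channel 6 is a softplus quad-lerp weight passed to flexible_dual_grid_to_mesh.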
trellis2/models/sc_vaes/sparse_unet_vae.py ADDED
@@ -0,0 +1,522 @@
+ from typing import *
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module
+ from ...modules import sparse as sp
+ from ...modules.norm import LayerNorm32
+
+
+ class SparseResBlock3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         out_channels: Optional[int] = None,
+         downsample: bool = False,
+         upsample: bool = False,
+         resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest',
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.downsample = downsample
+         self.upsample = upsample
+         self.resample_mode = resample_mode
+         self.use_checkpoint = use_checkpoint
+
+         assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
+
+         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+         if resample_mode == 'nearest':
+             self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+         elif resample_mode == 'spatial2channel' and not self.downsample:
+             self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
+         elif resample_mode == 'spatial2channel' and self.downsample:
+             self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
+         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+         if resample_mode == 'nearest':
+             self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+         elif resample_mode == 'spatial2channel' and self.downsample:
+             self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
+         elif resample_mode == 'spatial2channel' and not self.downsample:
+             self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
+         self.updown = None
+         if self.downsample:
+             if resample_mode == 'nearest':
+                 self.updown = sp.SparseDownsample(2)
+             elif resample_mode == 'spatial2channel':
+                 self.updown = sp.SparseSpatial2Channel(2)
+         elif self.upsample:
+             self.to_subdiv = sp.SparseLinear(channels, 8)
+             if resample_mode == 'nearest':
+                 self.updown = sp.SparseUpsample(2)
+             elif resample_mode == 'spatial2channel':
+                 self.updown = sp.SparseChannel2Spatial(2)
+
+     def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+         if self.downsample:
+             x = self.updown(x)
+         elif self.upsample:
+             x = self.updown(x, subdiv.replace(subdiv.feats > 0))
+         return x
+
+     def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         subdiv = None
+         if self.upsample:
+             subdiv = self.to_subdiv(x)
+         h = x.replace(self.norm1(x.feats))
+         h = h.replace(F.silu(h.feats))
+         if self.resample_mode == 'spatial2channel':
+             h = self.conv1(h)
+         h = self._updown(h, subdiv)
+         x = self._updown(x, subdiv)
+         if self.resample_mode == 'nearest':
+             h = self.conv1(h)
+         h = h.replace(self.norm2(h.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv2(h)
+         h = h + self.skip_connection(x)
+         if self.upsample:
+             return h, subdiv
+         return h
+
+     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+         else:
+             return self._forward(x)
+
+
+ class SparseResBlockDownsample3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         out_channels: Optional[int] = None,
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+
+         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+         self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+         self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+         self.updown = sp.SparseDownsample(2)
+
+     def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         h = x.replace(self.norm1(x.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.updown(h)
+         x = self.updown(x)
+         h = self.conv1(h)
+         h = h.replace(self.norm2(h.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv2(h)
+         h = h + self.skip_connection(x)
+         return h
+
+     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+         else:
+             return self._forward(x)
+
+
+ class SparseResBlockUpsample3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         out_channels: Optional[int] = None,
+         use_checkpoint: bool = False,
+         pred_subdiv: bool = True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+         self.pred_subdiv = pred_subdiv
+
+         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+         self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+         self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+         if self.pred_subdiv:
+             self.to_subdiv = sp.SparseLinear(channels, 8)
+         self.updown = sp.SparseUpsample(2)
+
+     def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+         if self.pred_subdiv:
+             subdiv = self.to_subdiv(x)
+         h = x.replace(self.norm1(x.feats))
+         h = h.replace(F.silu(h.feats))
+         subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
+         h = self.updown(h, subdiv_binarized)
+         x = self.updown(x, subdiv_binarized)
+         h = self.conv1(h)
+         h = h.replace(self.norm2(h.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv2(h)
+         h = h + self.skip_connection(x)
+         if self.pred_subdiv:
+             return h, subdiv
+         else:
+             return h
+
+     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+         else:
+             return self._forward(x)
+
+
+ class SparseResBlockS2C3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         out_channels: Optional[int] = None,
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+
+         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+         self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3)
+         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+         self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1))
+         self.updown = sp.SparseSpatial2Channel(2)
+
+     def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         h = x.replace(self.norm1(x.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv1(h)
+         h = self.updown(h)
+         x = self.updown(x)
+         h = h.replace(self.norm2(h.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv2(h)
+         h = h + self.skip_connection(x)
+         return h
+
+     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+         else:
+             return self._forward(x)
+
+
+ class SparseResBlockC2S3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         out_channels: Optional[int] = None,
+         use_checkpoint: bool = False,
+         pred_subdiv: bool = True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+         self.pred_subdiv = pred_subdiv
+
+         self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+         self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3)
+         self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+         self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
+         if pred_subdiv:
+             self.to_subdiv = sp.SparseLinear(channels, 8)
+         self.updown = sp.SparseChannel2Spatial(2)
+
+     def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+         if self.pred_subdiv:
+             subdiv = self.to_subdiv(x)
+         h = x.replace(self.norm1(x.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv1(h)
+         subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
+         h = self.updown(h, subdiv_binarized)
+         x = self.updown(x, subdiv_binarized)
+         h = h.replace(self.norm2(h.feats))
+         h = h.replace(F.silu(h.feats))
+         h = self.conv2(h)
+         h = h + self.skip_connection(x)
+         if self.pred_subdiv:
+             return h, subdiv
+         else:
+             return h
+
+     def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
+         else:
+             return self._forward(x, subdiv)
+
+
+ class SparseConvNeXtBlock3d(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         mlp_ratio: float = 4.0,
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.use_checkpoint = use_checkpoint
+
+         self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+         self.conv = sp.SparseConv3d(channels, channels, 3)
+         self.mlp = nn.Sequential(
+             nn.Linear(channels, int(channels * mlp_ratio)),
+             nn.SiLU(),
+             zero_module(nn.Linear(int(channels * mlp_ratio), channels)),
+         )
+
+     def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         h = self.conv(x)
+         h = h.replace(self.norm(h.feats))
+         h = h.replace(self.mlp(h.feats))
+         return h + x
+
+     def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+         if self.use_checkpoint:
+             return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+         else:
+             return self._forward(x)
+
+
+ class SparseUnetVaeEncoder(nn.Module):
+     """
+     Sparse Swin Transformer Unet VAE model.
+     """
+     def __init__(
+         self,
+         in_channels: int,
+         model_channels: List[int],
+         latent_channels: int,
+         num_blocks: List[int],
+         block_type: List[str],
+         down_block_type: List[str],
+         block_args: List[Dict[str, Any]],
+         use_fp16: bool = False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.num_blocks = num_blocks
+         self.dtype = torch.float16 if use_fp16 else torch.float32
+         self.dtype = torch.float16 if use_fp16 else torch.float32
+
+         self.input_layer = sp.SparseLinear(in_channels, model_channels[0])
+         self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels)
+
+         self.blocks = nn.ModuleList([])
+         for i in range(len(num_blocks)):
+             self.blocks.append(nn.ModuleList([]))
+             for j in range(num_blocks[i]):
+                 self.blocks[-1].append(
+                     globals()[block_type[i]](
+                         model_channels[i],
+                         **block_args[i],
+                     )
+                 )
+             if i < len(num_blocks) - 1:
+                 self.blocks[-1].append(
+                     globals()[down_block_type[i]](
+                         model_channels[i],
+                         model_channels[i+1],
+                         **block_args[i],
+                     )
+                 )
+
+         self.initialize_weights()
+         if use_fp16:
+             self.convert_to_fp16()
+
+     @property
+     def device(self) -> torch.device:
+         """
+         Return the device of the model.
+         """
+         return next(self.parameters()).device
+
+     def convert_to_fp16(self) -> None:
+         """
+         Convert the torso of the model to float16.
+         """
+         self.blocks.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self) -> None:
+         """
+         Convert the torso of the model to float32.
+         """
+         self.blocks.apply(convert_module_to_f32)
+
+     def initialize_weights(self) -> None:
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+     def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False):
+         h = self.input_layer(x)
+         h = h.type(self.dtype)
+         for i, res in enumerate(self.blocks):
+             for j, block in enumerate(res):
+                 h = block(h)
+         h = h.type(x.dtype)
+         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+         h = self.to_latent(h)
+
+         # Sample from the posterior distribution
+         mean, logvar = h.feats.chunk(2, dim=-1)
+         if sample_posterior:
+             std = torch.exp(0.5 * logvar)
+             z = mean + std * torch.randn_like(std)
+         else:
+             z = mean
+         z = h.replace(z)
+
+         if return_raw:
+             return z, mean, logvar
+         else:
+             return z
+
+
+ class SparseUnetVaeDecoder(nn.Module):
+     """
+     Sparse Swin Transformer Unet VAE model.
+     """
+     def __init__(
+         self,
+         out_channels: int,
+         model_channels: List[int],
+         latent_channels: int,
+         num_blocks: List[int],
+         block_type: List[str],
+         up_block_type: List[str],
+         block_args: List[Dict[str, Any]],
+         use_fp16: bool = False,
+         pred_subdiv: bool = True,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.model_channels = model_channels
+         self.num_blocks = num_blocks
+         self.use_fp16 = use_fp16
+         self.pred_subdiv = pred_subdiv
+         self.dtype = torch.float16 if use_fp16 else torch.float32
+         self.low_vram = False
+
+         self.output_layer = sp.SparseLinear(model_channels[-1], out_channels)
+         self.from_latent = sp.SparseLinear(latent_channels, model_channels[0])
+
+         self.blocks = nn.ModuleList([])
+         for i in range(len(num_blocks)):
+             self.blocks.append(nn.ModuleList([]))
+             for j in range(num_blocks[i]):
+                 self.blocks[-1].append(
+                     globals()[block_type[i]](
+                         model_channels[i],
+                         **block_args[i],
+                     )
+                 )
+             if i < len(num_blocks) - 1:
+                 self.blocks[-1].append(
+                     globals()[up_block_type[i]](
+                         model_channels[i],
+                         model_channels[i+1],
+                         pred_subdiv=pred_subdiv,
+                         **block_args[i],
+                     )
+                 )
+
+         self.initialize_weights()
+         if use_fp16:
+             self.convert_to_fp16()
+
+     @property
+     def device(self) -> torch.device:
+         """
+         Return the device of the model.
+         """
+         return next(self.parameters()).device
+
+     def convert_to_fp16(self) -> None:
+         """
+         Convert the torso of the model to float16.
+         """
+         self.blocks.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self) -> None:
+         """
+         Convert the torso of the model to float32.
+         """
+         self.blocks.apply(convert_module_to_f32)
+
+     def initialize_weights(self) -> None:
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+     def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor:
+         assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs"
+         assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs"
+
+         h = self.from_latent(x)
+         h = h.type(self.dtype)
+         subs_gt = []
+         subs = []
+         for i, res in enumerate(self.blocks):
+             for j, block in enumerate(res):
+                 if i < len(self.blocks) - 1 and j == len(res) - 1:
+                     if self.pred_subdiv:
+                         if self.training:
+                             subs_gt.append(h.get_spatial_cache('subdivision'))
+                         h, sub = block(h)
+                         subs.append(sub)
+                     else:
+                         h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
+                 else:
+                     h = block(h)
+         h = h.type(x.dtype)
+         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+         h = self.output_layer(h)
+         if self.training and self.pred_subdiv:
+             return h, subs_gt, subs
+         else:
+             if return_subs:
+                 return h, subs
+             else:
+                 return h
+
+     def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor:
+         assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling"
+
+         h = self.from_latent(x)
+         h = h.type(self.dtype)
+         for i, res in enumerate(self.blocks):
+             if i == upsample_times:
+                 return h.coords
+             for j, block in enumerate(res):
+                 if i < len(self.blocks) - 1 and j == len(res) - 1:
+                     h, sub = block(h)
+                 else:
+                     h = block(h)
trellis2/models/sparse_elastic_mixin.py ADDED
@@ -0,0 +1,24 @@
+ from contextlib import contextmanager
+ from typing import *
+ import math
+ from ..modules import sparse as sp
+ from ..utils.elastic_utils import ElasticModuleMixin
+
+
+ class SparseTransformerElasticMixin(ElasticModuleMixin):
+     def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
+         return x.feats.shape[0]
+
+     @contextmanager
+     def with_mem_ratio(self, mem_ratio=1.0):
+         if mem_ratio == 1.0:
+             yield 1.0
+             return
+         num_blocks = len(self.blocks)
+         num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
+         exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
+         for i in range(num_blocks):
+             self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
+         yield exact_mem_ratio
+         for i in range(num_blocks):
+             self.blocks[i].use_checkpoint = False
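A worked example of the checkpointing arithmetic in with_mem_ratio above, with made-up numbers (24 blocks, a requested mem_ratio of 0.5):

    import math

    num_blocks = 24                      # hypothetical block count
    mem_ratio = 0.5                      # requested memory ratio
    num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)  # 13
    exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks                        # 0.5
    # Inside the context the first 13 blocks run with use_checkpoint=True; on exit the flag
    # is reset to False for every block, exactly as the mixin does.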
trellis2/models/sparse_structure_flow.py ADDED
@@ -0,0 +1,248 @@
1
+ from typing import *
2
+ from functools import partial
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from ..trainers.utils import str_to_dtype
8
+ from ..modules.utils import convert_module_to, manual_cast
9
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
10
+ from ..modules.attention import RotaryPositionEmbedder
11
+
12
+
13
+ class TimestepEmbedder(nn.Module):
14
+ """
15
+ Embeds scalar timesteps into vector representations.
16
+ """
17
+ def __init__(self, hidden_size, frequency_embedding_size=256):
18
+ super().__init__()
19
+ self.mlp = nn.Sequential(
20
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
21
+ nn.SiLU(),
22
+ nn.Linear(hidden_size, hidden_size, bias=True),
23
+ )
24
+ self.frequency_embedding_size = frequency_embedding_size
25
+
26
+ @staticmethod
27
+ def timestep_embedding(t, dim, max_period=10000):
28
+ """
29
+ Create sinusoidal timestep embeddings.
30
+
31
+ Args:
32
+ t: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ dim: the dimension of the output.
35
+ max_period: controls the minimum frequency of the embeddings.
36
+
37
+ Returns:
38
+ an (N, D) Tensor of positional embeddings.
39
+ """
40
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
41
+ half = dim // 2
42
+ freqs = torch.exp(
43
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
44
+ ).to(device=t.device)
45
+ args = t[:, None].float() * freqs[None]
46
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
47
+ if dim % 2:
48
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
49
+ return embedding
50
+
51
+ def forward(self, t):
52
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
53
+ t_emb = self.mlp(t_freq)
54
+ return t_emb
55
+
56
+
57
+ class SparseStructureFlowModel(nn.Module):
58
+ def __init__(
59
+ self,
60
+ resolution: int,
61
+ in_channels: int,
62
+ model_channels: int,
63
+ cond_channels: int,
64
+ out_channels: int,
65
+ num_blocks: int,
66
+ num_heads: Optional[int] = None,
67
+ num_head_channels: Optional[int] = 64,
68
+ mlp_ratio: float = 4,
69
+ pe_mode: Literal["ape", "rope"] = "ape",
70
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
71
+ dtype: str = 'float32',
72
+ use_checkpoint: bool = False,
73
+ share_mod: bool = False,
74
+ initialization: str = 'vanilla',
75
+ qk_rms_norm: bool = False,
76
+ qk_rms_norm_cross: bool = False,
77
+ **kwargs
78
+ ):
79
+ super().__init__()
80
+ self.resolution = resolution
81
+ self.in_channels = in_channels
82
+ self.model_channels = model_channels
83
+ self.cond_channels = cond_channels
84
+ self.out_channels = out_channels
85
+ self.num_blocks = num_blocks
86
+ self.num_heads = num_heads or model_channels // num_head_channels
87
+ self.mlp_ratio = mlp_ratio
88
+ self.pe_mode = pe_mode
89
+ self.use_checkpoint = use_checkpoint
90
+ self.share_mod = share_mod
91
+ self.initialization = initialization
92
+ self.qk_rms_norm = qk_rms_norm
93
+ self.qk_rms_norm_cross = qk_rms_norm_cross
94
+ self.dtype = str_to_dtype(dtype)
95
+
96
+ self.t_embedder = TimestepEmbedder(model_channels)
97
+ if share_mod:
98
+ self.adaLN_modulation = nn.Sequential(
99
+ nn.SiLU(),
100
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
101
+ )
102
+
103
+ if pe_mode == "ape":
104
+ pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
105
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
106
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
107
+ pos_emb = pos_embedder(coords)
108
+ self.register_buffer("pos_emb", pos_emb)
109
+ elif pe_mode == "rope":
110
+ pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
111
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
112
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
113
+ rope_phases = pos_embedder(coords)
114
+ self.register_buffer("rope_phases", rope_phases)
115
+
116
+ if pe_mode != "rope":
117
+ self.rope_phases = None
118
+
119
+ self.input_layer = nn.Linear(in_channels, model_channels)
120
+
121
+ self.blocks = nn.ModuleList([
122
+ ModulatedTransformerCrossBlock(
123
+ model_channels,
124
+ cond_channels,
125
+ num_heads=self.num_heads,
126
+ mlp_ratio=self.mlp_ratio,
127
+ attn_mode='full',
128
+ use_checkpoint=self.use_checkpoint,
129
+ use_rope=(pe_mode == "rope"),
130
+ rope_freq=rope_freq,
131
+ share_mod=share_mod,
132
+ qk_rms_norm=self.qk_rms_norm,
133
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
134
+ )
135
+ for _ in range(num_blocks)
136
+ ])
137
+
138
+ self.out_layer = nn.Linear(model_channels, out_channels)
139
+
140
+ self.initialize_weights()
141
+ self.convert_to(self.dtype)
142
+
143
+ @property
144
+ def device(self) -> torch.device:
145
+ """
146
+ Return the device of the model.
147
+ """
148
+ return next(self.parameters()).device
149
+
150
+ def convert_to(self, dtype: torch.dtype) -> None:
151
+ """
152
+ Convert the torso of the model to the specified dtype.
153
+ """
154
+ self.dtype = dtype
155
+ self.blocks.apply(partial(convert_module_to, dtype=dtype))
156
+
157
+ def initialize_weights(self) -> None:
158
+ if self.initialization == 'vanilla':
159
+ # Initialize transformer layers:
160
+ def _basic_init(module):
161
+ if isinstance(module, nn.Linear):
162
+ torch.nn.init.xavier_uniform_(module.weight)
163
+ if module.bias is not None:
164
+ nn.init.constant_(module.bias, 0)
165
+ self.apply(_basic_init)
166
+
167
+ # Initialize timestep embedding MLP:
168
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
169
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
170
+
171
+ # Zero-out adaLN modulation layers in DiT blocks:
172
+ if self.share_mod:
173
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
174
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
175
+ else:
176
+ for block in self.blocks:
177
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
178
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
179
+
180
+ # Zero-out output layers:
181
+ nn.init.constant_(self.out_layer.weight, 0)
182
+ nn.init.constant_(self.out_layer.bias, 0)
183
+
184
+ elif self.initialization == 'scaled':
185
+ # Initialize transformer layers:
186
+ def _basic_init(module):
187
+ if isinstance(module, nn.Linear):
188
+ torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
189
+ if module.bias is not None:
190
+ nn.init.constant_(module.bias, 0)
191
+ self.apply(_basic_init)
192
+
193
+ # Scaled init for to_out and ffn2
194
+ def _scaled_init(module):
195
+ if isinstance(module, nn.Linear):
196
+ torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
197
+ if module.bias is not None:
198
+ nn.init.constant_(module.bias, 0)
199
+ for block in self.blocks:
200
+ block.self_attn.to_out.apply(_scaled_init)
201
+ block.cross_attn.to_out.apply(_scaled_init)
202
+ block.mlp.mlp[2].apply(_scaled_init)
203
+
204
+ # Initialize input layer to make the initial representation have variance 1
205
+ nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
206
+ nn.init.zeros_(self.input_layer.bias)
207
+
208
+ # Initialize timestep embedding MLP:
209
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
210
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
211
+
212
+ # Zero-out adaLN modulation layers in DiT blocks:
213
+ if self.share_mod:
214
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
215
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
216
+ else:
217
+ for block in self.blocks:
218
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
219
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
220
+
221
+ # Zero-out output layers:
222
+ nn.init.constant_(self.out_layer.weight, 0)
223
+ nn.init.constant_(self.out_layer.bias, 0)
224
+
225
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
226
+ assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
227
+ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
228
+
229
+ h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
230
+
231
+ h = self.input_layer(h)
232
+ if self.pe_mode == "ape":
233
+ h = h + self.pos_emb[None]
234
+ t_emb = self.t_embedder(t)
235
+ if self.share_mod:
236
+ t_emb = self.adaLN_modulation(t_emb)
237
+ t_emb = manual_cast(t_emb, self.dtype)
238
+ h = manual_cast(h, self.dtype)
239
+ cond = manual_cast(cond, self.dtype)
240
+ for block in self.blocks:
241
+ h = block(h, t_emb, cond, self.rope_phases)
242
+ h = manual_cast(h, x.dtype)
243
+ h = F.layer_norm(h, h.shape[-1:])
244
+ h = self.out_layer(h)
245
+
246
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
247
+
248
+ return h
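For reference, a minimal sketch of driving the dense flow transformer above. The class name `SparseStructureFlowModel` is an assumption (the class statement is not visible in this hunk), and all sizes are illustrative; the forward contract itself follows the shape assert in the code.

import torch
from trellis2.modules.attention import config as attn_config

attn_config.set_backend("sdpa")  # avoid a hard flash-attn requirement for this sketch

# Hypothetical class name and illustrative hyperparameters.
from trellis2.models.sparse_structure_flow import SparseStructureFlowModel
model = SparseStructureFlowModel(
    resolution=16, in_channels=8, model_channels=512, cond_channels=1024,
    out_channels=8, num_blocks=4, num_head_channels=64, pe_mode="ape",
)

x = torch.randn(1, 8, 16, 16, 16)      # dense latent grid, [N, C, D, H, W]
t = torch.full((1,), 0.5)              # flow timestep per sample
cond = torch.randn(1, 77, 1024)        # conditioning tokens, [N, L, C_cond]
v = model(x, t, cond)                  # output has the same shape as x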
trellis2/models/sparse_structure_vae.py ADDED
@@ -0,0 +1,306 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
6
+ from ..modules.spatial import pixel_shuffle_3d
7
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+
9
+
10
+ def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
11
+ """
12
+ Return a normalization layer.
13
+ """
14
+ if norm_type == "group":
15
+ return GroupNorm32(32, *args, **kwargs)
16
+ elif norm_type == "layer":
17
+ return ChannelLayerNorm32(*args, **kwargs)
18
+ else:
19
+ raise ValueError(f"Invalid norm type {norm_type}")
20
+
21
+
22
+ class ResBlock3d(nn.Module):
23
+ def __init__(
24
+ self,
25
+ channels: int,
26
+ out_channels: Optional[int] = None,
27
+ norm_type: Literal["group", "layer"] = "layer",
28
+ ):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.out_channels = out_channels or channels
32
+
33
+ self.norm1 = norm_layer(norm_type, channels)
34
+ self.norm2 = norm_layer(norm_type, self.out_channels)
35
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
37
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ h = self.norm1(x)
41
+ h = F.silu(h)
42
+ h = self.conv1(h)
43
+ h = self.norm2(h)
44
+ h = F.silu(h)
45
+ h = self.conv2(h)
46
+ h = h + self.skip_connection(x)
47
+ return h
48
+
49
+
50
+ class DownsampleBlock3d(nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_channels: int,
54
+ out_channels: int,
55
+ mode: Literal["conv", "avgpool"] = "conv",
56
+ ):
57
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
58
+
59
+ super().__init__()
60
+ self.in_channels = in_channels
61
+ self.out_channels = out_channels
62
+
63
+ if mode == "conv":
64
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
65
+ elif mode == "avgpool":
66
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ if hasattr(self, "conv"):
70
+ return self.conv(x)
71
+ else:
72
+ return F.avg_pool3d(x, 2)
73
+
74
+
75
+ class UpsampleBlock3d(nn.Module):
76
+ def __init__(
77
+ self,
78
+ in_channels: int,
79
+ out_channels: int,
80
+ mode: Literal["conv", "nearest"] = "conv",
81
+ ):
82
+ assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
83
+
84
+ super().__init__()
85
+ self.in_channels = in_channels
86
+ self.out_channels = out_channels
87
+
88
+ if mode == "conv":
89
+ self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
90
+ elif mode == "nearest":
91
+ assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ if hasattr(self, "conv"):
95
+ x = self.conv(x)
96
+ return pixel_shuffle_3d(x, 2)
97
+ else:
98
+ return F.interpolate(x, scale_factor=2, mode="nearest")
99
+
100
+
101
+ class SparseStructureEncoder(nn.Module):
102
+ """
103
+ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
104
+
105
+ Args:
106
+ in_channels (int): Channels of the input.
107
+ latent_channels (int): Channels of the latent representation.
108
+ num_res_blocks (int): Number of residual blocks at each resolution.
109
+ channels (List[int]): Channels of the encoder blocks.
110
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
111
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
112
+ use_fp16 (bool): Whether to use FP16.
113
+ """
114
+ def __init__(
115
+ self,
116
+ in_channels: int,
117
+ latent_channels: int,
118
+ num_res_blocks: int,
119
+ channels: List[int],
120
+ num_res_blocks_middle: int = 2,
121
+ norm_type: Literal["group", "layer"] = "layer",
122
+ use_fp16: bool = False,
123
+ ):
124
+ super().__init__()
125
+ self.in_channels = in_channels
126
+ self.latent_channels = latent_channels
127
+ self.num_res_blocks = num_res_blocks
128
+ self.channels = channels
129
+ self.num_res_blocks_middle = num_res_blocks_middle
130
+ self.norm_type = norm_type
131
+ self.use_fp16 = use_fp16
132
+ self.dtype = torch.float16 if use_fp16 else torch.float32
133
+
134
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
135
+
136
+ self.blocks = nn.ModuleList([])
137
+ for i, ch in enumerate(channels):
138
+ self.blocks.extend([
139
+ ResBlock3d(ch, ch)
140
+ for _ in range(num_res_blocks)
141
+ ])
142
+ if i < len(channels) - 1:
143
+ self.blocks.append(
144
+ DownsampleBlock3d(ch, channels[i+1])
145
+ )
146
+
147
+ self.middle_block = nn.Sequential(*[
148
+ ResBlock3d(channels[-1], channels[-1])
149
+ for _ in range(num_res_blocks_middle)
150
+ ])
151
+
152
+ self.out_layer = nn.Sequential(
153
+ norm_layer(norm_type, channels[-1]),
154
+ nn.SiLU(),
155
+ nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
156
+ )
157
+
158
+ if use_fp16:
159
+ self.convert_to_fp16()
160
+
161
+ @property
162
+ def device(self) -> torch.device:
163
+ """
164
+ Return the device of the model.
165
+ """
166
+ return next(self.parameters()).device
167
+
168
+ def convert_to_fp16(self) -> None:
169
+ """
170
+ Convert the torso of the model to float16.
171
+ """
172
+ self.use_fp16 = True
173
+ self.dtype = torch.float16
174
+ self.blocks.apply(convert_module_to_f16)
175
+ self.middle_block.apply(convert_module_to_f16)
176
+
177
+ def convert_to_fp32(self) -> None:
178
+ """
179
+ Convert the torso of the model to float32.
180
+ """
181
+ self.use_fp16 = False
182
+ self.dtype = torch.float32
183
+ self.blocks.apply(convert_module_to_f32)
184
+ self.middle_block.apply(convert_module_to_f32)
185
+
186
+ def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
187
+ h = self.input_layer(x)
188
+ h = h.type(self.dtype)
189
+
190
+ for block in self.blocks:
191
+ h = block(h)
192
+ h = self.middle_block(h)
193
+
194
+ h = h.type(x.dtype)
195
+ h = self.out_layer(h)
196
+
197
+ mean, logvar = h.chunk(2, dim=1)
198
+
199
+ if sample_posterior:
200
+ std = torch.exp(0.5 * logvar)
201
+ z = mean + std * torch.randn_like(std)
202
+ else:
203
+ z = mean
204
+
205
+ if return_raw:
206
+ return z, mean, logvar
207
+ return z
208
+
209
+
210
+ class SparseStructureDecoder(nn.Module):
211
+ """
212
+ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
213
+
214
+ Args:
215
+ out_channels (int): Channels of the output.
216
+ latent_channels (int): Channels of the latent representation.
217
+ num_res_blocks (int): Number of residual blocks at each resolution.
218
+ channels (List[int]): Channels of the decoder blocks.
219
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
220
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
221
+ use_fp16 (bool): Whether to use FP16.
222
+ """
223
+ def __init__(
224
+ self,
225
+ out_channels: int,
226
+ latent_channels: int,
227
+ num_res_blocks: int,
228
+ channels: List[int],
229
+ num_res_blocks_middle: int = 2,
230
+ norm_type: Literal["group", "layer"] = "layer",
231
+ use_fp16: bool = False,
232
+ ):
233
+ super().__init__()
234
+ self.out_channels = out_channels
235
+ self.latent_channels = latent_channels
236
+ self.num_res_blocks = num_res_blocks
237
+ self.channels = channels
238
+ self.num_res_blocks_middle = num_res_blocks_middle
239
+ self.norm_type = norm_type
240
+ self.use_fp16 = use_fp16
241
+ self.dtype = torch.float16 if use_fp16 else torch.float32
242
+
243
+ self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
244
+
245
+ self.middle_block = nn.Sequential(*[
246
+ ResBlock3d(channels[0], channels[0])
247
+ for _ in range(num_res_blocks_middle)
248
+ ])
249
+
250
+ self.blocks = nn.ModuleList([])
251
+ for i, ch in enumerate(channels):
252
+ self.blocks.extend([
253
+ ResBlock3d(ch, ch)
254
+ for _ in range(num_res_blocks)
255
+ ])
256
+ if i < len(channels) - 1:
257
+ self.blocks.append(
258
+ UpsampleBlock3d(ch, channels[i+1])
259
+ )
260
+
261
+ self.out_layer = nn.Sequential(
262
+ norm_layer(norm_type, channels[-1]),
263
+ nn.SiLU(),
264
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1)
265
+ )
266
+
267
+ if use_fp16:
268
+ self.convert_to_fp16()
269
+
270
+ @property
271
+ def device(self) -> torch.device:
272
+ """
273
+ Return the device of the model.
274
+ """
275
+ return next(self.parameters()).device
276
+
277
+ def convert_to_fp16(self) -> None:
278
+ """
279
+ Convert the torso of the model to float16.
280
+ """
281
+ self.use_fp16 = True
282
+ self.dtype = torch.float16
283
+ self.blocks.apply(convert_module_to_f16)
284
+ self.middle_block.apply(convert_module_to_f16)
285
+
286
+ def convert_to_fp32(self) -> None:
287
+ """
288
+ Convert the torso of the model to float32.
289
+ """
290
+ self.use_fp16 = False
291
+ self.dtype = torch.float32
292
+ self.blocks.apply(convert_module_to_f32)
293
+ self.middle_block.apply(convert_module_to_f32)
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ h = self.input_layer(x)
297
+
298
+ h = h.type(self.dtype)
299
+
300
+ h = self.middle_block(h)
301
+ for block in self.blocks:
302
+ h = block(h)
303
+
304
+ h = h.type(x.dtype)
305
+ h = self.out_layer(h)
306
+ return h
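A minimal round-trip sketch for the encoder/decoder pair defined in this file; the channel widths, resolution, and batch size below are illustrative assumptions, not values from a released config.

import torch
from trellis2.models.sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder

enc = SparseStructureEncoder(in_channels=1, latent_channels=8, num_res_blocks=2, channels=[32, 64, 128])
dec = SparseStructureDecoder(out_channels=1, latent_channels=8, num_res_blocks=2, channels=[128, 64, 32])

occ = torch.rand(1, 1, 64, 64, 64)       # dense occupancy grid
z = enc(occ, sample_posterior=True)      # [1, 8, 16, 16, 16]: two conv downsamples, reparameterized sample
recon = dec(z)                           # logits back at input resolution, [1, 1, 64, 64, 64]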
trellis2/models/structured_latent_flow.py ADDED
@@ -0,0 +1,208 @@
1
+ from typing import *
2
+ from functools import partial
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from ..trainers.utils import str_to_dtype
8
+ from ..modules.utils import convert_module_to, manual_cast
9
+ from ..modules.transformer import AbsolutePositionEmbedder
10
+ from ..modules import sparse as sp
11
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
12
+ from .sparse_structure_flow import TimestepEmbedder
13
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin
14
+
15
+
16
+ class SLatFlowModel(nn.Module):
17
+ def __init__(
18
+ self,
19
+ resolution: int,
20
+ in_channels: int,
21
+ model_channels: int,
22
+ cond_channels: int,
23
+ out_channels: int,
24
+ num_blocks: int,
25
+ num_heads: Optional[int] = None,
26
+ num_head_channels: Optional[int] = 64,
27
+ mlp_ratio: float = 4,
28
+ pe_mode: Literal["ape", "rope"] = "ape",
29
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
30
+ dtype: str = 'float32',
31
+ use_checkpoint: bool = False,
32
+ share_mod: bool = False,
33
+ initialization: str = 'vanilla',
34
+ qk_rms_norm: bool = False,
35
+ qk_rms_norm_cross: bool = False,
36
+ ):
37
+ super().__init__()
38
+ self.resolution = resolution
39
+ self.in_channels = in_channels
40
+ self.model_channels = model_channels
41
+ self.cond_channels = cond_channels
42
+ self.out_channels = out_channels
43
+ self.num_blocks = num_blocks
44
+ self.num_heads = num_heads or model_channels // num_head_channels
45
+ self.mlp_ratio = mlp_ratio
46
+ self.pe_mode = pe_mode
47
+ self.use_checkpoint = use_checkpoint
48
+ self.share_mod = share_mod
49
+ self.initialization = initialization
50
+ self.qk_rms_norm = qk_rms_norm
51
+ self.qk_rms_norm_cross = qk_rms_norm_cross
52
+ self.dtype = str_to_dtype(dtype)
53
+
54
+ self.t_embedder = TimestepEmbedder(model_channels)
55
+ if share_mod:
56
+ self.adaLN_modulation = nn.Sequential(
57
+ nn.SiLU(),
58
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
59
+ )
60
+
61
+ if pe_mode == "ape":
62
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63
+
64
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
65
+
66
+ self.blocks = nn.ModuleList([
67
+ ModulatedSparseTransformerCrossBlock(
68
+ model_channels,
69
+ cond_channels,
70
+ num_heads=self.num_heads,
71
+ mlp_ratio=self.mlp_ratio,
72
+ attn_mode='full',
73
+ use_checkpoint=self.use_checkpoint,
74
+ use_rope=(pe_mode == "rope"),
75
+ rope_freq=rope_freq,
76
+ share_mod=self.share_mod,
77
+ qk_rms_norm=self.qk_rms_norm,
78
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
79
+ )
80
+ for _ in range(num_blocks)
81
+ ])
82
+
83
+ self.out_layer = sp.SparseLinear(model_channels, out_channels)
84
+
85
+ self.initialize_weights()
86
+ self.convert_to(self.dtype)
87
+
88
+ @property
89
+ def device(self) -> torch.device:
90
+ """
91
+ Return the device of the model.
92
+ """
93
+ return next(self.parameters()).device
94
+
95
+ def convert_to(self, dtype: torch.dtype) -> None:
96
+ """
97
+ Convert the torso of the model to the specified dtype.
98
+ """
99
+ self.dtype = dtype
100
+ self.blocks.apply(partial(convert_module_to, dtype=dtype))
101
+
102
+ def initialize_weights(self) -> None:
103
+ if self.initialization == 'vanilla':
104
+ # Initialize transformer layers:
105
+ def _basic_init(module):
106
+ if isinstance(module, nn.Linear):
107
+ torch.nn.init.xavier_uniform_(module.weight)
108
+ if module.bias is not None:
109
+ nn.init.constant_(module.bias, 0)
110
+ self.apply(_basic_init)
111
+
112
+ # Initialize timestep embedding MLP:
113
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
114
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
115
+
116
+ # Zero-out adaLN modulation layers in DiT blocks:
117
+ if self.share_mod:
118
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
119
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
120
+ else:
121
+ for block in self.blocks:
122
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
123
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
124
+
125
+ # Zero-out output layers:
126
+ nn.init.constant_(self.out_layer.weight, 0)
127
+ nn.init.constant_(self.out_layer.bias, 0)
128
+
129
+ elif self.initialization == 'scaled':
130
+ # Initialize transformer layers:
131
+ def _basic_init(module):
132
+ if isinstance(module, nn.Linear):
133
+ torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
134
+ if module.bias is not None:
135
+ nn.init.constant_(module.bias, 0)
136
+ self.apply(_basic_init)
137
+
138
+ # Scaled init for to_out and ffn2
139
+ def _scaled_init(module):
140
+ if isinstance(module, nn.Linear):
141
+ torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
142
+ if module.bias is not None:
143
+ nn.init.constant_(module.bias, 0)
144
+ for block in self.blocks:
145
+ block.self_attn.to_out.apply(_scaled_init)
146
+ block.cross_attn.to_out.apply(_scaled_init)
147
+ block.mlp.mlp[2].apply(_scaled_init)
148
+
149
+ # Initialize input layer to make the initial representation have variance 1
150
+ nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
151
+ nn.init.zeros_(self.input_layer.bias)
152
+
153
+ # Initialize timestep embedding MLP:
154
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
155
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
156
+
157
+ # Zero-out adaLN modulation layers in DiT blocks:
158
+ if self.share_mod:
159
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
160
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
161
+ else:
162
+ for block in self.blocks:
163
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
164
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
165
+
166
+ # Zero-out output layers:
167
+ nn.init.constant_(self.out_layer.weight, 0)
168
+ nn.init.constant_(self.out_layer.bias, 0)
169
+
170
+ def forward(
171
+ self,
172
+ x: sp.SparseTensor,
173
+ t: torch.Tensor,
174
+ cond: Union[torch.Tensor, List[torch.Tensor]],
175
+ concat_cond: Optional[sp.SparseTensor] = None,
176
+ **kwargs
177
+ ) -> sp.SparseTensor:
178
+ if concat_cond is not None:
179
+ x = sp.sparse_cat([x, concat_cond], dim=-1)
180
+ if isinstance(cond, list):
181
+ cond = sp.VarLenTensor.from_tensor_list(cond)
182
+
183
+ h = self.input_layer(x)
184
+ h = manual_cast(h, self.dtype)
185
+ t_emb = self.t_embedder(t)
186
+ if self.share_mod:
187
+ t_emb = self.adaLN_modulation(t_emb)
188
+ t_emb = manual_cast(t_emb, self.dtype)
189
+ cond = manual_cast(cond, self.dtype)
190
+
191
+ if self.pe_mode == "ape":
192
+ pe = self.pos_embedder(h.coords[:, 1:])
193
+ h = h + manual_cast(pe, self.dtype)
194
+ for block in self.blocks:
195
+ h = block(h, t_emb, cond)
196
+
197
+ h = manual_cast(h, x.dtype)
198
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
199
+ h = self.out_layer(h)
200
+ return h
201
+
202
+
203
+ class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
204
+ """
205
+ SLat Flow Model with elastic memory management.
206
+ Used for training with low VRAM.
207
+ """
208
+ pass
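A construction sketch for the structured-latent flow model above; the hyperparameters are illustrative assumptions. Building a `SparseTensor` input is omitted here since that class lives in the sparse module elsewhere in this commit.

from trellis2.models.structured_latent_flow import SLatFlowModel

model = SLatFlowModel(
    resolution=64, in_channels=8, model_channels=768, cond_channels=1024,
    out_channels=8, num_blocks=12, num_head_channels=64,
    pe_mode="ape", qk_rms_norm=True,
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")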
trellis2/modules/attention/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .full_attn import *
2
+ from .modules import *
3
+ from .rope import *
trellis2/modules/attention/config.py ADDED
@@ -0,0 +1,32 @@
1
+ from typing import *
2
+
3
+ BACKEND = 'flash_attn'
4
+ DEBUG = False
5
+
6
+ def __from_env():
7
+ import os
8
+
9
+ global BACKEND
10
+ global DEBUG
11
+
12
+ env_attn_backend = os.environ.get('ATTN_BACKEND')
13
+ env_attn_debug = os.environ.get('ATTN_DEBUG')
14
+
15
+ if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']:
16
+ BACKEND = env_attn_backend
17
+ if env_attn_debug is not None:
18
+ DEBUG = env_attn_debug == '1'
19
+
20
+ print(f"[ATTENTION] Using backend: {BACKEND}")
21
+
22
+
23
+ __from_env()
24
+
25
+
26
+ def set_backend(backend: Literal['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']):
27
+ global BACKEND
28
+ BACKEND = backend
29
+
30
+ def set_debug(debug: bool):
31
+ global DEBUG
32
+ DEBUG = debug
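The backend is chosen once at import time from the ATTN_BACKEND environment variable and can be overridden afterwards; a short sketch:

import os
os.environ["ATTN_BACKEND"] = "sdpa"        # only takes effect if set before the first import

from trellis2.modules.attention import config
config.set_backend("sdpa")                 # programmatic override at runtime
config.set_debug(True)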
trellis2/modules/attention/full_attn.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import config
5
+
6
+
7
+ __all__ = [
8
+ 'scaled_dot_product_attention',
9
+ ]
10
+
11
+
12
+ def _naive_sdpa(q, k, v):
13
+ """
14
+ Naive implementation of scaled dot product attention.
15
+ """
16
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
17
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
18
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
19
+ scale_factor = 1 / math.sqrt(q.size(-1))
20
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
21
+ attn_weight = torch.softmax(attn_weight, dim=-1)
22
+ out = attn_weight @ v
23
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
24
+ return out
25
+
26
+
27
+ @overload
28
+ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
29
+ """
30
+ Apply scaled dot product attention.
31
+
32
+ Args:
33
+ qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
34
+ """
35
+ ...
36
+
37
+ @overload
38
+ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
39
+ """
40
+ Apply scaled dot product attention.
41
+
42
+ Args:
43
+ q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
44
+ kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
45
+ """
46
+ ...
47
+
48
+ @overload
49
+ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
50
+ """
51
+ Apply scaled dot product attention.
52
+
53
+ Args:
54
+ q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
55
+ k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
56
+ v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
57
+
58
+ Note:
59
+ k and v are assumed to have the same sequence length.
60
+ """
61
+ ...
62
+
63
+ def scaled_dot_product_attention(*args, **kwargs):
64
+ arg_names_dict = {
65
+ 1: ['qkv'],
66
+ 2: ['q', 'kv'],
67
+ 3: ['q', 'k', 'v']
68
+ }
69
+ num_all_args = len(args) + len(kwargs)
70
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
71
+ for key in arg_names_dict[num_all_args][len(args):]:
72
+ assert key in kwargs, f"Missing argument {key}"
73
+
74
+ if num_all_args == 1:
75
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
76
+ assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
77
+ device = qkv.device
78
+
79
+ elif num_all_args == 2:
80
+ q = args[0] if len(args) > 0 else kwargs['q']
81
+ kv = args[1] if len(args) > 1 else kwargs['kv']
82
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
83
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
84
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
85
+ device = q.device
86
+
87
+ elif num_all_args == 3:
88
+ q = args[0] if len(args) > 0 else kwargs['q']
89
+ k = args[1] if len(args) > 1 else kwargs['k']
90
+ v = args[2] if len(args) > 2 else kwargs['v']
91
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
92
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
93
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
94
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
95
+ device = q.device
96
+
97
+ if config.BACKEND == 'xformers':
98
+ if 'xops' not in globals():
99
+ import xformers.ops as xops
100
+ if num_all_args == 1:
101
+ q, k, v = qkv.unbind(dim=2)
102
+ elif num_all_args == 2:
103
+ k, v = kv.unbind(dim=2)
104
+ out = xops.memory_efficient_attention(q, k, v)
105
+ elif config.BACKEND == 'flash_attn':
106
+ if 'flash_attn' not in globals():
107
+ import flash_attn
108
+ if num_all_args == 1:
109
+ out = flash_attn.flash_attn_qkvpacked_func(qkv)
110
+ elif num_all_args == 2:
111
+ out = flash_attn.flash_attn_kvpacked_func(q, kv)
112
+ elif num_all_args == 3:
113
+ out = flash_attn.flash_attn_func(q, k, v)
114
+ elif config.BACKEND == 'flash_attn_3':
115
+ if 'flash_attn_3' not in globals():
116
+ import flash_attn_interface as flash_attn_3
117
+ if num_all_args == 1:
118
+ out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
119
+ elif num_all_args == 2:
120
+ out = flash_attn_3.flash_attn_kvpacked_func(q, kv)
121
+ elif num_all_args == 3:
122
+ out = flash_attn_3.flash_attn_func(q, k, v)
123
+ elif config.BACKEND == 'sdpa':
124
+ if 'sdpa' not in globals():
125
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
126
+ if num_all_args == 1:
127
+ q, k, v = qkv.unbind(dim=2)
128
+ elif num_all_args == 2:
129
+ k, v = kv.unbind(dim=2)
130
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
131
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
132
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
133
+ out = sdpa(q, k, v) # [N, H, L, C]
134
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
135
+ elif config.BACKEND == 'naive':
136
+ if num_all_args == 1:
137
+ q, k, v = qkv.unbind(dim=2)
138
+ elif num_all_args == 2:
139
+ k, v = kv.unbind(dim=2)
140
+ out = _naive_sdpa(q, k, v)
141
+ else:
142
+ raise ValueError(f"Unknown attention module: {config.BACKEND}")
143
+
144
+ return out
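A quick sketch of the calling conventions above, using the 'sdpa' backend so it runs without flash-attn; shapes are illustrative.

import torch
from trellis2.modules.attention import config
from trellis2.modules.attention.full_attn import scaled_dot_product_attention

config.set_backend("sdpa")

N, L, H, C = 2, 256, 8, 64
qkv = torch.randn(N, L, 3, H, C)               # packed self-attention input
out = scaled_dot_product_attention(qkv)        # [N, L, H, C]

q = torch.randn(N, L, H, C)
kv = torch.randn(N, 77, 2, H, C)               # packed keys/values from 77 context tokens
out = scaled_dot_product_attention(q, kv)      # [N, L, H, C]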
trellis2/modules/attention/modules.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+ from .rope import RotaryPositionEmbedder
7
+
8
+
9
+ class MultiHeadRMSNorm(nn.Module):
10
+ def __init__(self, dim: int, heads: int):
11
+ super().__init__()
12
+ self.scale = dim ** 0.5
13
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
17
+
18
+
19
+ class MultiHeadAttention(nn.Module):
20
+ def __init__(
21
+ self,
22
+ channels: int,
23
+ num_heads: int,
24
+ ctx_channels: Optional[int]=None,
25
+ type: Literal["self", "cross"] = "self",
26
+ attn_mode: Literal["full", "windowed"] = "full",
27
+ window_size: Optional[int] = None,
28
+ shift_window: Optional[Tuple[int, int, int]] = None,
29
+ qkv_bias: bool = True,
30
+ use_rope: bool = False,
31
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
32
+ qk_rms_norm: bool = False,
33
+ ):
34
+ super().__init__()
35
+ assert channels % num_heads == 0
36
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
37
+ assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
38
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
39
+
40
+ if attn_mode == "windowed":
41
+ raise NotImplementedError("Windowed attention is not yet implemented")
42
+
43
+ self.channels = channels
44
+ self.head_dim = channels // num_heads
45
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
46
+ self.num_heads = num_heads
47
+ self._type = type
48
+ self.attn_mode = attn_mode
49
+ self.window_size = window_size
50
+ self.shift_window = shift_window
51
+ self.use_rope = use_rope
52
+ self.qk_rms_norm = qk_rms_norm
53
+
54
+ if self._type == "self":
55
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
56
+ else:
57
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
58
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
59
+
60
+ if self.qk_rms_norm:
61
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
62
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
63
+
64
+ self.to_out = nn.Linear(channels, channels)
65
+
66
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
67
+ B, L, C = x.shape
68
+ if self._type == "self":
69
+ qkv = self.to_qkv(x)
70
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
71
+
72
+ if self.attn_mode == "full":
73
+ if self.qk_rms_norm or self.use_rope:
74
+ q, k, v = qkv.unbind(dim=2)
75
+ if self.qk_rms_norm:
76
+ q = self.q_rms_norm(q)
77
+ k = self.k_rms_norm(k)
78
+ if self.use_rope:
79
+ assert phases is not None, "Phases must be provided for RoPE"
80
+ q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
81
+ k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
82
+ h = scaled_dot_product_attention(q, k, v)
83
+ else:
84
+ h = scaled_dot_product_attention(qkv)
85
+ elif self.attn_mode == "windowed":
86
+ raise NotImplementedError("Windowed attention is not yet implemented")
87
+ else:
88
+ Lkv = context.shape[1]
89
+ q = self.to_q(x)
90
+ kv = self.to_kv(context)
91
+ q = q.reshape(B, L, self.num_heads, -1)
92
+ kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
93
+ if self.qk_rms_norm:
94
+ q = self.q_rms_norm(q)
95
+ k, v = kv.unbind(dim=2)
96
+ k = self.k_rms_norm(k)
97
+ h = scaled_dot_product_attention(q, k, v)
98
+ else:
99
+ h = scaled_dot_product_attention(q, kv)
100
+ h = h.reshape(B, L, -1)
101
+ h = self.to_out(h)
102
+ return h
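A small usage sketch for the dense attention module above (illustrative sizes, 'sdpa' backend):

import torch
from trellis2.modules.attention import config
from trellis2.modules.attention.modules import MultiHeadAttention

config.set_backend("sdpa")

self_attn = MultiHeadAttention(channels=512, num_heads=8, type="self", qk_rms_norm=True)
cross_attn = MultiHeadAttention(channels=512, num_heads=8, ctx_channels=1024, type="cross")

x = torch.randn(2, 196, 512)                   # [B, L, C] tokens
ctx = torch.randn(2, 77, 1024)                 # conditioning tokens
y = self_attn(x)                               # [2, 196, 512]
y = cross_attn(y, context=ctx)                 # [2, 196, 512]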
trellis2/modules/attention/rope.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class RotaryPositionEmbedder(nn.Module):
7
+ def __init__(
8
+ self,
9
+ head_dim: int,
10
+ dim: int = 3,
11
+ rope_freq: Tuple[float, float] = (1.0, 10000.0)
12
+ ):
13
+ super().__init__()
14
+ assert head_dim % 2 == 0, "Head dim must be divisible by 2"
15
+ self.head_dim = head_dim
16
+ self.dim = dim
17
+ self.rope_freq = rope_freq
18
+ self.freq_dim = head_dim // 2 // dim
19
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
20
+ self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
21
+
22
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
23
+ self.freqs = self.freqs.to(indices.device)
24
+ phases = torch.outer(indices, self.freqs)
25
+ phases = torch.polar(torch.ones_like(phases), phases)
26
+ return phases
27
+
28
+ @staticmethod
29
+ def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
30
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
31
+ x_rotated = x_complex * phases.unsqueeze(-2)
32
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
33
+ return x_embed
34
+
35
+ def forward(self, indices: torch.Tensor) -> torch.Tensor:
36
+ """
37
+ Args:
38
+ indices (torch.Tensor): [..., N, C] tensor of spatial positions
39
+ """
40
+ assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
41
+ phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
42
+ if phases.shape[-1] < self.head_dim // 2:
43
+ padn = self.head_dim // 2 - phases.shape[-1]
44
+ phases = torch.cat([phases, torch.polar(
45
+ torch.ones(*phases.shape[:-1], padn, device=phases.device),
46
+ torch.zeros(*phases.shape[:-1], padn, device=phases.device)
47
+ )], dim=-1)
48
+ return phases
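A standalone sketch of the rotary embedder: 3D integer positions are mapped to complex phases, which are then applied to per-head query/key features (sizes illustrative).

import torch
from trellis2.modules.attention.rope import RotaryPositionEmbedder

rope = RotaryPositionEmbedder(head_dim=64, dim=3)

res = 4
coords = torch.stack(torch.meshgrid(*[torch.arange(res)] * 3, indexing="ij"), dim=-1).reshape(-1, 3)
phases = rope(coords.float())                          # [res**3, head_dim // 2], complex, zero-padded as needed

q = torch.randn(res**3, 8, 64)                         # [L, H, head_dim] queries
q_rot = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)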
trellis2/modules/norm.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .utils import manual_cast
4
+
5
+
6
+ class LayerNorm32(nn.LayerNorm):
7
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
8
+ x_dtype = x.dtype
9
+ x = manual_cast(x, torch.float32)
10
+ o = super().forward(x)
11
+ return manual_cast(o, x_dtype)
12
+
13
+
14
+ class GroupNorm32(nn.GroupNorm):
15
+ """
16
+ A GroupNorm layer that converts to float32 before the forward pass.
17
+ """
18
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
19
+ x_dtype = x.dtype
20
+ x = manual_cast(x, torch.float32)
21
+ o = super().forward(x)
22
+ return manual_cast(o, x_dtype)
23
+
24
+
25
+ class ChannelLayerNorm32(LayerNorm32):
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ DIM = x.dim()
28
+ x = x.permute(0, *range(2, DIM), 1).contiguous()
29
+ x = super().forward(x)
30
+ x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
31
+ return x
32
+
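These wrappers run normalization in float32 regardless of the activation dtype; a small sketch with the channels-first variant, assuming `manual_cast` behaves as a plain dtype cast:

import torch
from trellis2.modules.norm import ChannelLayerNorm32

norm = ChannelLayerNorm32(64)                          # LayerNorm over the channel dim of NCDHW tensors
x = torch.randn(2, 64, 8, 8, 8, dtype=torch.float16)
y = norm(x)                                            # normalized in fp32 internally, returned as fp16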
trellis2/modules/sparse/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ from . import config
2
+ import importlib
3
+
4
+ __attributes = {
5
+ 'VarLenTensor': 'basic',
6
+ 'varlen_cat': 'basic',
7
+ 'varlen_unbind': 'basic',
8
+ 'SparseTensor': 'basic',
9
+ 'sparse_cat': 'basic',
10
+ 'sparse_unbind': 'basic',
11
+ 'SparseGroupNorm': 'norm',
12
+ 'SparseLayerNorm': 'norm',
13
+ 'SparseGroupNorm32': 'norm',
14
+ 'SparseLayerNorm32': 'norm',
15
+ 'SparseReLU': 'nonlinearity',
16
+ 'SparseSiLU': 'nonlinearity',
17
+ 'SparseGELU': 'nonlinearity',
18
+ 'SparseActivation': 'nonlinearity',
19
+ 'SparseLinear': 'linear',
20
+ 'sparse_scaled_dot_product_attention': 'attention',
21
+ 'SerializeMode': 'attention',
22
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
23
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
24
+ 'sparse_windowed_scaled_dot_product_cross_attention': 'attention',
25
+ 'SparseRotaryPositionEmbedder': 'attention',
26
+ 'SparseMultiHeadAttention': 'attention',
27
+ 'SparseConv3d': 'conv',
28
+ 'SparseInverseConv3d': 'conv',
29
+ 'SparseDownsample': 'spatial',
30
+ 'SparseUpsample': 'spatial',
31
+ 'SparseSubdivide': 'spatial',
32
+ 'SparseSpatial2Channel': 'spatial',
33
+ 'SparseChannel2Spatial': 'spatial',
34
+ 'sparse_nearest_interpolate': 'spatial',
35
+ 'sparse_trilinear_interpolate': 'spatial',
36
+ 'encode_seq': 'serialize',
37
+ 'decode_seq': 'serialize',
38
+ }
39
+
40
+ __submodules = ['transformer', 'conv']
41
+
42
+ __all__ = list(__attributes.keys()) + __submodules
43
+
44
+ def __getattr__(name):
45
+ if name not in globals():
46
+ if name in __attributes:
47
+ module_name = __attributes[name]
48
+ module = importlib.import_module(f".{module_name}", __name__)
49
+ globals()[name] = getattr(module, name)
50
+ elif name in __submodules:
51
+ module = importlib.import_module(f".{name}", __name__)
52
+ globals()[name] = module
53
+ else:
54
+ raise AttributeError(f"module {__name__} has no attribute {name}")
55
+ return globals()[name]
56
+
57
+
58
+ # For Pylance
59
+ if __name__ == '__main__':
60
+ from .basic import *
61
+ from .norm import *
62
+ from .nonlinearity import *
63
+ from .linear import *
64
+ from .attention import *
65
+ from .conv import *
66
+ from .spatial import *
67
+ from .serialize import *
68
+ import transformer
69
+ import conv
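Attribute access on the package triggers the lazy loader above, so heavyweight sparse backends are only imported when first used; for example:

import trellis2.modules.sparse as sp

linear_cls = sp.SparseLinear     # first access imports trellis2.modules.sparse.linear and caches the symbol
conv_mod = sp.conv               # submodules listed in __submodules resolve the same way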
trellis2/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .full_attn import *
2
+ from .windowed_attn import *
3
+ from .modules import *
trellis2/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,214 @@
1
+ from typing import *
2
+ import torch
3
+ from .. import VarLenTensor
4
+ from .. import config
5
+
6
+
7
+ __all__ = [
8
+ 'sparse_scaled_dot_product_attention',
9
+ ]
10
+
11
+
12
+ @overload
13
+ def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor:
14
+ """
15
+ Apply scaled dot product attention to a sparse tensor.
16
+
17
+ Args:
18
+ qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
19
+ """
20
+ ...
21
+
22
+ @overload
23
+ def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor:
24
+ """
25
+ Apply scaled dot product attention to a sparse tensor.
26
+
27
+ Args:
28
+ q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs.
29
+ kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
30
+ """
31
+ ...
32
+
33
+ @overload
34
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor:
35
+ """
36
+ Apply scaled dot product attention to a sparse tensor.
37
+
38
+ Args:
39
+ q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
40
+ kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
41
+ """
42
+ ...
43
+
44
+ @overload
45
+ def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor:
46
+ """
47
+ Apply scaled dot product attention to a sparse tensor.
48
+
49
+ Args:
50
+ q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
51
+ k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
52
+ v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
53
+
54
+ Note:
55
+ k and v are assumed to have the same coordinate map.
56
+ """
57
+ ...
58
+
59
+ @overload
60
+ def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor:
61
+ """
62
+ Apply scaled dot product attention to a sparse tensor.
63
+
64
+ Args:
65
+ q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
66
+ k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
67
+ v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
68
+ """
69
+ ...
70
+
71
+ @overload
72
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor:
73
+ """
74
+ Apply scaled dot product attention to a sparse tensor.
75
+
76
+ Args:
77
+ q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
78
+ k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
79
+ v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
80
+ """
81
+ ...
82
+
83
+ def sparse_scaled_dot_product_attention(*args, **kwargs):
84
+ arg_names_dict = {
85
+ 1: ['qkv'],
86
+ 2: ['q', 'kv'],
87
+ 3: ['q', 'k', 'v']
88
+ }
89
+ num_all_args = len(args) + len(kwargs)
90
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
91
+ for key in arg_names_dict[num_all_args][len(args):]:
92
+ assert key in kwargs, f"Missing argument {key}"
93
+
94
+ if num_all_args == 1:
95
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
96
+ assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}"
97
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
98
+ device = qkv.device
99
+
100
+ s = qkv
101
+ q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
102
+ kv_seqlen = q_seqlen
103
+ qkv = qkv.feats # [T, 3, H, C]
104
+
105
+ elif num_all_args == 2:
106
+ q = args[0] if len(args) > 0 else kwargs['q']
107
+ kv = args[1] if len(args) > 1 else kwargs['kv']
108
+ assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \
109
+ isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \
110
+ f"Invalid types, got {type(q)} and {type(kv)}"
111
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
112
+ device = q.device
113
+
114
+ if isinstance(q, VarLenTensor):
115
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
116
+ s = q
117
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
118
+ q = q.feats # [T_Q, H, C]
119
+ else:
120
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
121
+ s = None
122
+ N, L, H, C = q.shape
123
+ q_seqlen = [L] * N
124
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
125
+
126
+ if isinstance(kv, VarLenTensor):
127
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
128
+ kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
129
+ kv = kv.feats # [T_KV, 2, H, C]
130
+ else:
131
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
132
+ N, L, _, H, C = kv.shape
133
+ kv_seqlen = [L] * N
134
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
135
+
136
+ elif num_all_args == 3:
137
+ q = args[0] if len(args) > 0 else kwargs['q']
138
+ k = args[1] if len(args) > 1 else kwargs['k']
139
+ v = args[2] if len(args) > 2 else kwargs['v']
140
+ assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \
141
+ isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \
142
+ f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
143
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
144
+ device = q.device
145
+
146
+ if isinstance(q, VarLenTensor):
147
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
148
+ s = q
149
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
150
+ q = q.feats # [T_Q, H, Ci]
151
+ else:
152
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
153
+ s = None
154
+ N, L, H, CI = q.shape
155
+ q_seqlen = [L] * N
156
+ q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
157
+
158
+ if isinstance(k, VarLenTensor):
159
+ assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
160
+ assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
161
+ kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
162
+ k = k.feats # [T_KV, H, Ci]
163
+ v = v.feats # [T_KV, H, Co]
164
+ else:
165
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
166
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
167
+ N, L, H, CI, CO = *k.shape, v.shape[-1]
168
+ kv_seqlen = [L] * N
169
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
170
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
171
+
172
+ if config.ATTN == 'xformers':
173
+ if 'xops' not in globals():
174
+ import xformers.ops as xops
175
+ if num_all_args == 1:
176
+ q, k, v = qkv.unbind(dim=1)
177
+ elif num_all_args == 2:
178
+ k, v = kv.unbind(dim=1)
179
+ q = q.unsqueeze(0)
180
+ k = k.unsqueeze(0)
181
+ v = v.unsqueeze(0)
182
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
183
+ out = xops.memory_efficient_attention(q, k, v, mask)[0]
184
+ elif config.ATTN == 'flash_attn':
185
+ if 'flash_attn' not in globals():
186
+ import flash_attn
187
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
188
+ if num_all_args in [2, 3]:
189
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
190
+ if num_all_args == 1:
191
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
192
+ elif num_all_args == 2:
193
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
194
+ elif num_all_args == 3:
195
+ out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
196
+ elif config.ATTN == 'flash_attn_3':
197
+ if 'flash_attn_3' not in globals():
198
+ import flash_attn_interface as flash_attn_3
199
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
200
+ if num_all_args in [2, 3]:
201
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
202
+ if num_all_args == 1:
203
+ out = flash_attn_3.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
204
+ elif num_all_args == 2:
205
+ out = flash_attn_3.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
206
+ elif num_all_args == 3:
207
+ out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
208
+ else:
209
+ raise ValueError(f"Unknown attention module: {config.ATTN}")
210
+
211
+ if s is not None:
212
+ return s.replace(out)
213
+ else:
214
+ return out.reshape(N, L, H, -1)
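A hedged sketch of variable-length attention over a ragged batch. It assumes `VarLenTensor.from_tensor_list` accepts per-sample `[L_i, 3, H, C]` tensors (the constructor lives in `basic.py`, not shown here) and that a flash-attn style backend is configured, so it needs a CUDA device.

import torch
import trellis2.modules.sparse as sp

H, C = 8, 64
# Two samples with different token counts, packed Q/K/V per token.
qkv_list = [torch.randn(n, 3, H, C, device="cuda", dtype=torch.float16) for n in (100, 37)]
qkv = sp.VarLenTensor.from_tensor_list(qkv_list)       # logical shape [N, *, 3, H, C]
out = sp.sparse_scaled_dot_product_attention(qkv)      # logical shape [N, *, H, C]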
trellis2/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,141 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .. import VarLenTensor, SparseTensor
6
+ from .full_attn import sparse_scaled_dot_product_attention
7
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
8
+ from .rope import SparseRotaryPositionEmbedder
9
+
10
+
11
+ class SparseMultiHeadRMSNorm(nn.Module):
12
+ def __init__(self, dim: int, heads: int):
13
+ super().__init__()
14
+ self.scale = dim ** 0.5
15
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
16
+
17
+ def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
18
+ x_type = x.dtype
19
+ x = x.float()
20
+ if isinstance(x, VarLenTensor):
21
+ x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
22
+ else:
23
+ x = F.normalize(x, dim=-1) * self.gamma * self.scale
24
+ return x.to(x_type)
25
+
26
+
27
+ class SparseMultiHeadAttention(nn.Module):
28
+ def __init__(
29
+ self,
30
+ channels: int,
31
+ num_heads: int,
32
+ ctx_channels: Optional[int] = None,
33
+ type: Literal["self", "cross"] = "self",
34
+ attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
35
+ window_size: Optional[int] = None,
36
+ shift_window: Optional[Tuple[int, int, int]] = None,
37
+ qkv_bias: bool = True,
38
+ use_rope: bool = False,
39
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
40
+ qk_rms_norm: bool = False,
41
+ ):
42
+ super().__init__()
43
+ assert channels % num_heads == 0
44
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
45
+ assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}"
46
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
47
+ assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
48
+ if attn_mode == 'double_windowed':
49
+ assert window_size % 2 == 0, "Window size must be even for double windowed attention"
50
+ assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention"
51
+ self.channels = channels
52
+ self.head_dim = channels // num_heads
53
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
54
+ self.num_heads = num_heads
55
+ self._type = type
56
+ self.attn_mode = attn_mode
57
+ self.window_size = window_size
58
+ self.shift_window = shift_window
59
+ self.use_rope = use_rope
60
+ self.qk_rms_norm = qk_rms_norm
61
+
62
+ if self._type == "self":
63
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
64
+ else:
65
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
66
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
67
+
68
+ if self.qk_rms_norm:
69
+ self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
70
+ self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
71
+
72
+ self.to_out = nn.Linear(channels, channels)
73
+
74
+ if use_rope:
75
+ self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
76
+
77
+ @staticmethod
78
+ def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
79
+ if isinstance(x, VarLenTensor):
80
+ return x.replace(module(x.feats))
81
+ else:
82
+ return module(x)
83
+
84
+ @staticmethod
85
+ def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
86
+ if isinstance(x, VarLenTensor):
87
+ return x.reshape(*shape)
88
+ else:
89
+ return x.reshape(*x.shape[:2], *shape)
90
+
91
+ def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
92
+ if isinstance(x, VarLenTensor):
93
+ x_feats = x.feats.unsqueeze(0)
94
+ else:
95
+ x_feats = x
96
+ x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
97
+ return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
98
+
99
+ def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
100
+ if self._type == "self":
101
+ qkv = self._linear(self.to_qkv, x)
102
+ qkv = self._fused_pre(qkv, num_fused=3)
103
+ if self.qk_rms_norm or self.use_rope:
104
+ q, k, v = qkv.unbind(dim=-3)
105
+ if self.qk_rms_norm:
106
+ q = self.q_rms_norm(q)
107
+ k = self.k_rms_norm(k)
108
+ if self.use_rope:
109
+ q, k = self.rope(q, k)
110
+ qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
111
+ if self.attn_mode == "full":
112
+ h = sparse_scaled_dot_product_attention(qkv)
113
+ elif self.attn_mode == "windowed":
114
+ h = sparse_windowed_scaled_dot_product_self_attention(
115
+ qkv, self.window_size, shift_window=self.shift_window
116
+ )
117
+ elif self.attn_mode == "double_windowed":
118
+ qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
119
+ qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
120
+ h0 = sparse_windowed_scaled_dot_product_self_attention(
121
+ qkv0, self.window_size, shift_window=(0, 0, 0)
122
+ )
123
+ h1 = sparse_windowed_scaled_dot_product_self_attention(
124
+ qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
125
+ )
126
+ h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
127
+ else:
128
+ q = self._linear(self.to_q, x)
129
+ q = self._reshape_chs(q, (self.num_heads, -1))
130
+ kv = self._linear(self.to_kv, context)
131
+ kv = self._fused_pre(kv, num_fused=2)
132
+ if self.qk_rms_norm:
133
+ q = self.q_rms_norm(q)
134
+ k, v = kv.unbind(dim=-3)
135
+ k = self.k_rms_norm(k)
136
+ h = sparse_scaled_dot_product_attention(q, k, v)
137
+ else:
138
+ h = sparse_scaled_dot_product_attention(q, kv)
139
+ h = self._reshape_chs(h, (-1,))
140
+ h = self._linear(self.to_out, h)
141
+ return h
trellis2/modules/sparse/attention/rope.py ADDED
@@ -0,0 +1,58 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+
6
+
7
+ class SparseRotaryPositionEmbedder(nn.Module):
8
+ def __init__(
9
+ self,
10
+ head_dim: int,
11
+ dim: int = 3,
12
+ rope_freq: Tuple[float, float] = (1.0, 10000.0)
13
+ ):
14
+ super().__init__()
15
+ assert head_dim % 2 == 0, "Head dim must be divisible by 2"
16
+ self.head_dim = head_dim
17
+ self.dim = dim
18
+ self.rope_freq = rope_freq
19
+ self.freq_dim = head_dim // 2 // dim
20
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
21
+ self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
22
+
23
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
24
+ self.freqs = self.freqs.to(indices.device)
25
+ phases = torch.outer(indices, self.freqs)
26
+ phases = torch.polar(torch.ones_like(phases), phases)
27
+ return phases
28
+
29
+ def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
30
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
31
+ x_rotated = x_complex * phases.unsqueeze(-2)
32
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
33
+ return x_embed
34
+
35
+ def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
36
+ """
37
+ Args:
38
+ q (SparseTensor): [..., N, H, D] tensor of queries
39
+ k (SparseTensor): [..., N, H, D] tensor of keys
40
+ """
41
+ assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
42
+ phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
43
+ phases = q.get_spatial_cache(phases_cache_name)
44
+ if phases is None:
45
+ coords = q.coords[..., 1:]
46
+ phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
47
+ if phases.shape[-1] < self.head_dim // 2:
48
+ padn = self.head_dim // 2 - phases.shape[-1]
49
+ phases = torch.cat([phases, torch.polar(
50
+ torch.ones(*phases.shape[:-1], padn, device=phases.device),
51
+ torch.zeros(*phases.shape[:-1], padn, device=phases.device)
52
+ )], dim=-1)
53
+ q.register_spatial_cache(phases_cache_name, phases)
54
+ q_embed = q.replace(self._rotary_embedding(q.feats, phases))
55
+ if k is None:
56
+ return q_embed
57
+ k_embed = k.replace(self._rotary_embedding(k.feats, phases))
58
+ return q_embed, k_embed
trellis2/modules/sparse/attention/windowed_attn.py ADDED
@@ -0,0 +1,190 @@
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from .. import SparseTensor
5
+ from .. import config
6
+
7
+
8
+ __all__ = [
9
+ 'sparse_windowed_scaled_dot_product_self_attention',
10
+ 'sparse_windowed_scaled_dot_product_cross_attention',
11
+ ]
12
+
13
+
14
+ def calc_window_partition(
15
+ tensor: SparseTensor,
16
+ window_size: Union[int, Tuple[int, ...]],
17
+ shift_window: Union[int, Tuple[int, ...]] = 0,
18
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
19
+ """
20
+ Calculate serialization and partitioning for a set of coordinates.
21
+
22
+ Args:
23
+ tensor (SparseTensor): The input tensor.
24
+ window_size (int): The window size to use.
25
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates.
26
+
27
+ Returns:
28
+ (torch.Tensor): Forward indices.
29
+ (torch.Tensor): Backward indices.
30
+ (torch.Tensor): Sequence lengths.
31
+ (dict): Attn func args.
32
+ """
33
+ DIM = tensor.coords.shape[1] - 1
34
+ shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
35
+ window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
36
+ shifted_coords = tensor.coords.clone().detach()
37
+ shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
38
+
39
+ MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
40
+ NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
41
+ OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
42
+
43
+ shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
44
+ shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
45
+ fwd_indices = torch.argsort(shifted_indices)
46
+ bwd_indices = torch.empty_like(fwd_indices)
47
+ bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
48
+ seq_lens = torch.bincount(shifted_indices)
49
+ mask = seq_lens != 0
50
+ seq_lens = seq_lens[mask]
51
+
52
+ if config.ATTN == 'xformers':
53
+ if 'xops' not in globals():
54
+ import xformers.ops as xops
55
+ attn_func_args = {
56
+ 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
57
+ }
58
+ elif config.ATTN == 'flash_attn':
59
+ attn_func_args = {
60
+ 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
61
+ 'max_seqlen': torch.max(seq_lens)
62
+ }
63
+
64
+ return fwd_indices, bwd_indices, seq_lens, attn_func_args
65
+
66
+
67
+ def sparse_windowed_scaled_dot_product_self_attention(
68
+ qkv: SparseTensor,
69
+ window_size: int,
70
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
71
+ ) -> SparseTensor:
72
+ """
73
+ Apply windowed scaled dot product self attention to a sparse tensor.
74
+
75
+ Args:
76
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
77
+ window_size (int): The window size to use.
78
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
79
+
80
+ Returns:
81
+ (SparseTensor): [N, *, H, C] sparse tensor containing the output features.
82
+ """
83
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
84
+
85
+ serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
86
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
87
+ if serialization_spatial_cache is None:
88
+ fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
89
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
90
+ else:
91
+ fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
92
+
93
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
94
+
95
+ if config.DEBUG:
96
+ start = 0
97
+ qkv_coords = qkv.coords[fwd_indices]
98
+ for i in range(len(seq_lens)):
99
+ seq_coords = qkv_coords[start:start+seq_lens[i]]
100
+ assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
101
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
102
+ start += seq_lens[i]
103
+
104
+ if config.ATTN == 'xformers':
105
+ if 'xops' not in globals():
106
+ import xformers.ops as xops
107
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
108
+ q = q.unsqueeze(0) # [1, M, H, C]
109
+ k = k.unsqueeze(0) # [1, M, H, C]
110
+ v = v.unsqueeze(0) # [1, M, H, C]
111
+ out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
112
+ elif config.ATTN == 'flash_attn':
113
+ if 'flash_attn' not in globals():
114
+ import flash_attn
115
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
116
+
117
+ out = out[bwd_indices] # [T, H, C]
118
+
119
+ if config.DEBUG:
120
+ qkv_coords = qkv_coords[bwd_indices]
121
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
122
+
123
+ return qkv.replace(out)
124
+
125
+
126
+ def sparse_windowed_scaled_dot_product_cross_attention(
127
+ q: SparseTensor,
128
+ kv: SparseTensor,
129
+ q_window_size: int,
130
+ kv_window_size: int,
131
+ q_shift_window: Tuple[int, int, int] = (0, 0, 0),
132
+ kv_shift_window: Tuple[int, int, int] = (0, 0, 0),
133
+ ) -> SparseTensor:
134
+ """
135
+ Apply windowed scaled dot product cross attention to two sparse tensors.
136
+
137
+ Args:
138
+ q (SparseTensor): [N, *, H, C] sparse tensor containing Qs.
139
+ kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs.
140
+ q_window_size (int): The window size to use for Qs.
141
+ kv_window_size (int): The window size to use for Ks and Vs.
142
+ q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs.
143
+ kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs.
144
+
145
+ Returns:
146
+ (SparseTensor): [N, *, H, C] sparse tensor containing the output features.
147
+ """
148
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
149
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
150
+
151
+ q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}'
152
+ q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name)
153
+ if q_serialization_spatial_cache is None:
154
+ q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window)
155
+ q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args))
156
+ else:
157
+ q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache
158
+ kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}'
159
+ kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name)
160
+ if kv_serialization_spatial_cache is None:
161
+ kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window)
162
+ kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args))
163
+ else:
164
+ kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache
165
+
166
+ assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match"
167
+
168
+ q_feats = q.feats[q_fwd_indices] # [M, H, C]
169
+ kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C]
170
+
171
+ if config.ATTN == 'xformers':
172
+ if 'xops' not in globals():
173
+ import xformers.ops as xops
174
+ k, v = kv_feats.unbind(dim=1) # [M, H, C]
175
+ q = q.unsqueeze(0) # [1, M, H, C]
176
+ k = k.unsqueeze(0) # [1, M, H, C]
177
+ v = v.unsqueeze(0) # [1, M, H, C]
178
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens)
179
+ out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C]
180
+ elif config.ATTN == 'flash_attn':
181
+ if 'flash_attn' not in globals():
182
+ import flash_attn
183
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats,
184
+ cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'],
185
+ max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'],
186
+ ) # [M, H, C]
187
+
188
+ out = out[q_bwd_indices] # [T, H, C]
189
+
190
+ return q.replace(out)
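Usage sketch (illustrative only, not part of this commit): voxels are serialized into non-overlapping windows, attention runs block-diagonally per window, and the partition is cached on the tensor so unshifted and shifted passes can alternate cheaply. It assumes a CUDA device and that the selected attention backend (flash_attn by default) is installed.

import torch
from trellis2.modules.sparse import SparseTensor
from trellis2.modules.sparse.attention.windowed_attn import (
    sparse_windowed_scaled_dot_product_self_attention,
)

coords = torch.cat([
    torch.zeros(2048, 1, dtype=torch.int32),
    torch.randint(0, 64, (2048, 3), dtype=torch.int32),
], dim=1).cuda()
qkv = SparseTensor(
    feats=torch.randn(2048, 3, 8, 64, device='cuda', dtype=torch.float16),  # [N, 3, H, C]
    coords=coords,
)

out = sparse_windowed_scaled_dot_product_self_attention(qkv, window_size=8)
out_shifted = sparse_windowed_scaled_dot_product_self_attention(
    qkv, window_size=8, shift_window=(4, 4, 4)   # Swin-style shifted windows
)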
trellis2/modules/sparse/basic.py ADDED
@@ -0,0 +1,836 @@
1
+ from typing import *
2
+ from fractions import Fraction
3
+ import torch
4
+ from . import config
5
+
6
+
7
+ __all__ = [
8
+ 'VarLenTensor',
9
+ 'varlen_cat',
10
+ 'varlen_unbind',
11
+ 'SparseTensor',
12
+ 'sparse_cat',
13
+ 'sparse_unbind',
14
+ ]
15
+
16
+
17
+ class VarLenTensor:
18
+ """
19
+ Sequential tensor with variable length.
20
+
21
+ Args:
22
+ feats (torch.Tensor): Features of the varlen tensor.
23
+ layout (List[slice]): Layout of the varlen tensor for each batch
24
+ """
25
+ def __init__(self, feats: torch.Tensor, layout: List[slice]=None):
26
+ self.feats = feats
27
+ self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
28
+ self._cache = {}
29
+
30
+ @staticmethod
31
+ def layout_from_seqlen(seqlen: list) -> List[slice]:
32
+ """
33
+ Create a layout from a tensor of sequence lengths.
34
+ """
35
+ layout = []
36
+ start = 0
37
+ for l in seqlen:
38
+ layout.append(slice(start, start + l))
39
+ start += l
40
+ return layout
41
+
42
+ @staticmethod
43
+ def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
44
+ """
45
+ Create a VarLenTensor from a list of tensors.
46
+ """
47
+ feats = torch.cat(tensor_list, dim=0)
48
+ layout = []
49
+ start = 0
50
+ for tensor in tensor_list:
51
+ layout.append(slice(start, start + tensor.shape[0]))
52
+ start += tensor.shape[0]
53
+ return VarLenTensor(feats, layout)
54
+
55
+ def to_tensor_list(self) -> List[torch.Tensor]:
56
+ """
57
+ Convert a VarLenTensor to a list of tensors.
58
+ """
59
+ tensor_list = []
60
+ for s in self.layout:
61
+ tensor_list.append(self.feats[s])
62
+ return tensor_list
63
+
64
+ def __len__(self) -> int:
65
+ return len(self.layout)
66
+
67
+ @property
68
+ def shape(self) -> torch.Size:
69
+ return torch.Size([len(self.layout), *self.feats.shape[1:]])
70
+
71
+ def dim(self) -> int:
72
+ return len(self.shape)
73
+
74
+ @property
75
+ def ndim(self) -> int:
76
+ return self.dim()
77
+
78
+ @property
79
+ def dtype(self):
80
+ return self.feats.dtype
81
+
82
+ @property
83
+ def device(self):
84
+ return self.feats.device
85
+
86
+ @property
87
+ def seqlen(self) -> torch.LongTensor:
88
+ if 'seqlen' not in self._cache:
89
+ self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
90
+ return self._cache['seqlen']
91
+
92
+ @property
93
+ def cum_seqlen(self) -> torch.LongTensor:
94
+ if 'cum_seqlen' not in self._cache:
95
+ self._cache['cum_seqlen'] = torch.cat([
96
+ torch.tensor([0], dtype=torch.long, device=self.device),
97
+ self.seqlen.cumsum(dim=0)
98
+ ], dim=0)
99
+ return self._cache['cum_seqlen']
100
+
101
+ @property
102
+ def batch_boardcast_map(self) -> torch.LongTensor:
103
+ """
104
+ Get the broadcast map for the varlen tensor.
105
+ """
106
+ if 'batch_boardcast_map' not in self._cache:
107
+ self._cache['batch_boardcast_map'] = torch.repeat_interleave(
108
+ torch.arange(len(self.layout), device=self.device),
109
+ self.seqlen,
110
+ )
111
+ return self._cache['batch_boardcast_map']
112
+
113
+ @overload
114
+ def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
115
+
116
+ @overload
117
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
118
+
119
+ def to(self, *args, **kwargs) -> 'VarLenTensor':
120
+ device = None
121
+ dtype = None
122
+ if len(args) == 2:
123
+ device, dtype = args
124
+ elif len(args) == 1:
125
+ if isinstance(args[0], torch.dtype):
126
+ dtype = args[0]
127
+ else:
128
+ device = args[0]
129
+ if 'dtype' in kwargs:
130
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
131
+ dtype = kwargs['dtype']
132
+ if 'device' in kwargs:
133
+ assert device is None, "to() received multiple values for argument 'device'"
134
+ device = kwargs['device']
135
+ non_blocking = kwargs.get('non_blocking', False)
136
+ copy = kwargs.get('copy', False)
137
+
138
+ new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
139
+ return self.replace(new_feats)
140
+
141
+ def type(self, dtype):
142
+ new_feats = self.feats.type(dtype)
143
+ return self.replace(new_feats)
144
+
145
+ def cpu(self) -> 'VarLenTensor':
146
+ new_feats = self.feats.cpu()
147
+ return self.replace(new_feats)
148
+
149
+ def cuda(self) -> 'VarLenTensor':
150
+ new_feats = self.feats.cuda()
151
+ return self.replace(new_feats)
152
+
153
+ def half(self) -> 'VarLenTensor':
154
+ new_feats = self.feats.half()
155
+ return self.replace(new_feats)
156
+
157
+ def float(self) -> 'VarLenTensor':
158
+ new_feats = self.feats.float()
159
+ return self.replace(new_feats)
160
+
161
+ def detach(self) -> 'VarLenTensor':
162
+ new_feats = self.feats.detach()
163
+ return self.replace(new_feats)
164
+
165
+ def reshape(self, *shape) -> 'VarLenTensor':
166
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
167
+ return self.replace(new_feats)
168
+
169
+ def unbind(self, dim: int) -> List['VarLenTensor']:
170
+ return varlen_unbind(self, dim)
171
+
172
+ def replace(self, feats: torch.Tensor) -> 'VarLenTensor':
173
+ new_tensor = VarLenTensor(
174
+ feats=feats,
175
+ layout=self.layout,
176
+ )
177
+ new_tensor._cache = self._cache
178
+ return new_tensor
179
+
180
+ def to_dense(self, max_length=None) -> torch.Tensor:
181
+ """
182
+ Convert a VarLenTensor to a dense (padded) representation without a for-loop.
183
+
184
+ Returns:
185
+ dense (torch.Tensor): (N, L, C) dense tensor
186
+ mask (torch.BoolTensor): (N, L) mask indicating valid positions
187
+ """
188
+ N = len(self)
189
+ L = max_length or self.seqlen.max().item()
190
+ spatial = self.feats.shape[1:]
191
+ idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
192
+ mask = (idx < self.seqlen.unsqueeze(1))
193
+ mapping = mask.reshape(-1).cumsum(dim=0) - 1
194
+ dense = self.feats[mapping]
195
+ dense = dense.reshape(N, L, *spatial)
196
+ return dense, mask
197
+
198
+ def __neg__(self) -> 'VarLenTensor':
199
+ return self.replace(-self.feats)
200
+
201
+ def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
202
+ if isinstance(other, torch.Tensor):
203
+ try:
204
+ other = torch.broadcast_to(other, self.shape)
205
+ other = other[self.batch_boardcast_map]
206
+ except:
207
+ pass
208
+ if isinstance(other, VarLenTensor):
209
+ other = other.feats
210
+ new_feats = op(self.feats, other)
211
+ new_tensor = self.replace(new_feats)
212
+ return new_tensor
213
+
214
+ def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
215
+ return self.__elemwise__(other, torch.add)
216
+
217
+ def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
218
+ return self.__elemwise__(other, torch.add)
219
+
220
+ def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
221
+ return self.__elemwise__(other, torch.sub)
222
+
223
+ def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
224
+ return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
225
+
226
+ def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
227
+ return self.__elemwise__(other, torch.mul)
228
+
229
+ def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
230
+ return self.__elemwise__(other, torch.mul)
231
+
232
+ def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
233
+ return self.__elemwise__(other, torch.div)
234
+
235
+ def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
236
+ return self.__elemwise__(other, lambda x, y: torch.div(y, x))
237
+
238
+ def __getitem__(self, idx):
239
+ if isinstance(idx, int):
240
+ idx = [idx]
241
+ elif isinstance(idx, slice):
242
+ idx = range(*idx.indices(self.shape[0]))
243
+ elif isinstance(idx, list):
244
+ assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
245
+ elif isinstance(idx, torch.Tensor):
246
+ if idx.dtype == torch.bool:
247
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
248
+ idx = idx.nonzero().squeeze(1)
249
+ elif idx.dtype in [torch.int32, torch.int64]:
250
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
251
+ else:
252
+ raise ValueError(f"Unknown index type: {idx.dtype}")
253
+ else:
254
+ raise ValueError(f"Unknown index type: {type(idx)}")
255
+
256
+ new_feats = []
257
+ new_layout = []
258
+ start = 0
259
+ for new_idx, old_idx in enumerate(idx):
260
+ new_feats.append(self.feats[self.layout[old_idx]])
261
+ new_layout.append(slice(start, start + len(new_feats[-1])))
262
+ start += len(new_feats[-1])
263
+ new_feats = torch.cat(new_feats, dim=0).contiguous()
264
+ new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
265
+ return new_tensor
266
+
267
+ def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
268
+ if isinstance(dim, int):
269
+ dim = (dim,)
270
+
271
+ if op =='mean':
272
+ red = self.feats.mean(dim=dim, keepdim=keepdim)
273
+ elif op =='sum':
274
+ red = self.feats.sum(dim=dim, keepdim=keepdim)
275
+ elif op == 'prod':
276
+ red = self.feats.prod(dim=dim, keepdim=keepdim)
277
+ else:
278
+ raise ValueError(f"Unsupported reduce operation: {op}")
279
+
280
+ if dim is None or 0 in dim:
281
+ return red
282
+
283
+ red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
284
+ return red
285
+
286
+ def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
287
+ return self.reduce(op='mean', dim=dim, keepdim=keepdim)
288
+
289
+ def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
290
+ return self.reduce(op='sum', dim=dim, keepdim=keepdim)
291
+
292
+ def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
293
+ return self.reduce(op='prod', dim=dim, keepdim=keepdim)
294
+
295
+ def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
296
+ mean = self.mean(dim=dim, keepdim=True)
297
+ mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
298
+ std = (mean2 - mean ** 2).sqrt()
299
+ return std
300
+
301
+ def __repr__(self) -> str:
302
+ return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
303
+
304
+
305
+ def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor:
306
+ """
307
+ Concatenate a list of varlen tensors.
308
+
309
+ Args:
310
+ inputs (List[VarLenTensor]): List of varlen tensors to concatenate.
311
+ """
312
+ if dim == 0:
313
+ new_feats = torch.cat([input.feats for input in inputs], dim=0)
314
+ start = 0
315
+ new_layout = []
316
+ for input in inputs:
317
+ for l in input.layout:
318
+ new_layout.append(slice(start, start + l.stop - l.start))
319
+ start += l.stop - l.start
320
+ output = VarLenTensor(feats=new_feats, layout=new_layout)
321
+ else:
322
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
323
+ output = inputs[0].replace(feats)
324
+
325
+ return output
326
+
327
+
328
+ def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
329
+ """
330
+ Unbind a varlen tensor along a dimension.
331
+
332
+ Args:
333
+ input (VarLenTensor): Varlen tensor to unbind.
334
+ dim (int): Dimension to unbind.
335
+ """
336
+ if dim == 0:
337
+ return [input[i] for i in range(len(input))]
338
+ else:
339
+ feats = input.feats.unbind(dim)
340
+ return [input.replace(f) for f in feats]
341
+
342
+
343
+ class SparseTensor(VarLenTensor):
344
+ """
345
+ Sparse tensor with support for both torchsparse and spconv backends.
346
+
347
+ Parameters:
348
+ - feats (torch.Tensor): Features of the sparse tensor.
349
+ - coords (torch.Tensor): Coordinates of the sparse tensor.
350
+ - shape (torch.Size): Shape of the sparse tensor.
351
+ - layout (List[slice]): Layout of the sparse tensor for each batch
352
+ - data (SparseTensorData): Sparse tensor data used for convolution
353
+
354
+ NOTE:
355
+ - Data corresponding to the same batch should be contiguous.
356
+ - Coords should be in [0, 1023]
357
+ """
358
+ SparseTensorData = None
359
+
360
+ @overload
361
+ def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...
362
+
363
+ @overload
364
+ def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...
365
+
366
+ def __init__(self, *args, **kwargs):
367
+ # Lazy import of sparse tensor backend
368
+ if self.SparseTensorData is None:
369
+ import importlib
370
+ if config.CONV == 'torchsparse':
371
+ self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
372
+ elif config.CONV == 'spconv':
373
+ self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
374
+
375
+ method_id = 0
376
+ if len(args) != 0:
377
+ method_id = 0 if isinstance(args[0], torch.Tensor) else 1
378
+ else:
379
+ method_id = 1 if 'data' in kwargs else 0
380
+
381
+ if method_id == 0:
382
+ feats, coords, shape = args + (None,) * (3 - len(args))
383
+ if 'feats' in kwargs:
384
+ feats = kwargs['feats']
385
+ del kwargs['feats']
386
+ if 'coords' in kwargs:
387
+ coords = kwargs['coords']
388
+ del kwargs['coords']
389
+ if 'shape' in kwargs:
390
+ shape = kwargs['shape']
391
+ del kwargs['shape']
392
+
393
+ if config.CONV == 'torchsparse':
394
+ self.data = self.SparseTensorData(feats, coords, **kwargs)
395
+ elif config.CONV == 'spconv':
396
+ spatial_shape = list(coords.max(0)[0] + 1)
397
+ self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
398
+ self.data._features = feats
399
+ else:
400
+ self.data = {
401
+ 'feats': feats,
402
+ 'coords': coords,
403
+ }
404
+ elif method_id == 1:
405
+ data, shape = args + (None,) * (2 - len(args))
406
+ if 'data' in kwargs:
407
+ data = kwargs['data']
408
+ del kwargs['data']
409
+ if 'shape' in kwargs:
410
+ shape = kwargs['shape']
411
+ del kwargs['shape']
412
+
413
+ self.data = data
414
+
415
+ self._shape = shape
416
+ self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
417
+ self._spatial_cache = kwargs.get('spatial_cache', {})
418
+
419
+ if config.DEBUG:
420
+ try:
421
+ assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
422
+ assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
423
+ assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
424
+ for i in range(self.shape[0]):
425
+ assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
426
+ except Exception as e:
427
+ print('Debugging information:')
428
+ print(f"- Shape: {self.shape}")
429
+ print(f"- Layout: {self.layout}")
430
+ print(f"- Scale: {self._scale}")
431
+ print(f"- Coords: {self.coords}")
432
+ raise e
433
+
434
+ @staticmethod
435
+ def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
436
+ """
437
+ Create a SparseTensor from a list of tensors.
438
+ """
439
+ feats = torch.cat(feats_list, dim=0)
440
+ coords = []
441
+ for i, coord in enumerate(coords_list):
442
+ coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
443
+ coords.append(coord)
444
+ coords = torch.cat(coords, dim=0)
445
+ return SparseTensor(feats, coords)
446
+
447
+ def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
448
+ """
449
+ Convert a SparseTensor to list of tensors.
450
+ """
451
+ feats_list = []
452
+ coords_list = []
453
+ for s in self.layout:
454
+ feats_list.append(self.feats[s])
455
+ coords_list.append(self.coords[s])
456
+ return feats_list, coords_list
457
+
458
+ def __len__(self) -> int:
459
+ return len(self.layout)
460
+
461
+ def __cal_shape(self, feats, coords):
462
+ shape = []
463
+ shape.append(coords[:, 0].max().item() + 1)
464
+ shape.extend([*feats.shape[1:]])
465
+ return torch.Size(shape)
466
+
467
+ def __cal_layout(self, coords, batch_size):
468
+ seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
469
+ offset = torch.cumsum(seq_len, dim=0)
470
+ layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
471
+ return layout
472
+
473
+ def __cal_spatial_shape(self, coords):
474
+ return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist())
475
+
476
+ @property
477
+ def shape(self) -> torch.Size:
478
+ if self._shape is None:
479
+ self._shape = self.__cal_shape(self.feats, self.coords)
480
+ return self._shape
481
+
482
+ @property
483
+ def layout(self) -> List[slice]:
484
+ layout = self.get_spatial_cache('layout')
485
+ if layout is None:
486
+ layout = self.__cal_layout(self.coords, self.shape[0])
487
+ self.register_spatial_cache('layout', layout)
488
+ return layout
489
+
490
+ @property
491
+ def spatial_shape(self) -> torch.Size:
492
+ spatial_shape = self.get_spatial_cache('shape')
493
+ if spatial_shape is None:
494
+ spatial_shape = self.__cal_spatial_shape(self.coords)
495
+ self.register_spatial_cache('shape', spatial_shape)
496
+ return spatial_shape
497
+
498
+ @property
499
+ def feats(self) -> torch.Tensor:
500
+ if config.CONV == 'torchsparse':
501
+ return self.data.F
502
+ elif config.CONV == 'spconv':
503
+ return self.data.features
504
+ else:
505
+ return self.data['feats']
506
+
507
+ @feats.setter
508
+ def feats(self, value: torch.Tensor):
509
+ if config.CONV == 'torchsparse':
510
+ self.data.F = value
511
+ elif config.CONV == 'spconv':
512
+ self.data.features = value
513
+ else:
514
+ self.data['feats'] = value
515
+
516
+ @property
517
+ def coords(self) -> torch.Tensor:
518
+ if config.CONV == 'torchsparse':
519
+ return self.data.C
520
+ elif config.CONV == 'spconv':
521
+ return self.data.indices
522
+ else:
523
+ return self.data['coords']
524
+
525
+ @coords.setter
526
+ def coords(self, value: torch.Tensor):
527
+ if config.CONV == 'torchsparse':
528
+ self.data.C = value
529
+ elif config.CONV == 'spconv':
530
+ self.data.indices = value
531
+ else:
532
+ self.data['coords'] = value
533
+
534
+ @property
535
+ def dtype(self):
536
+ return self.feats.dtype
537
+
538
+ @property
539
+ def device(self):
540
+ return self.feats.device
541
+
542
+ @property
543
+ def seqlen(self) -> torch.LongTensor:
544
+ seqlen = self.get_spatial_cache('seqlen')
545
+ if seqlen is None:
546
+ seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
547
+ self.register_spatial_cache('seqlen', seqlen)
548
+ return seqlen
549
+
550
+ @property
551
+ def cum_seqlen(self) -> torch.LongTensor:
552
+ cum_seqlen = self.get_spatial_cache('cum_seqlen')
553
+ if cum_seqlen is None:
554
+ cum_seqlen = torch.cat([
555
+ torch.tensor([0], dtype=torch.long, device=self.device),
556
+ self.seqlen.cumsum(dim=0)
557
+ ], dim=0)
558
+ self.register_spatial_cache('cum_seqlen', cum_seqlen)
559
+ return cum_seqlen
560
+
561
+ @property
562
+ def batch_boardcast_map(self) -> torch.LongTensor:
563
+ """
564
+ Get the broadcast map for the varlen tensor.
565
+ """
566
+ batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
567
+ if batch_boardcast_map is None:
568
+ batch_boardcast_map = torch.repeat_interleave(
569
+ torch.arange(len(self.layout), device=self.device),
570
+ self.seqlen,
571
+ )
572
+ self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
573
+ return batch_boardcast_map
574
+
575
+ @overload
576
+ def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
577
+
578
+ @overload
579
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
580
+
581
+ def to(self, *args, **kwargs) -> 'SparseTensor':
582
+ device = None
583
+ dtype = None
584
+ if len(args) == 2:
585
+ device, dtype = args
586
+ elif len(args) == 1:
587
+ if isinstance(args[0], torch.dtype):
588
+ dtype = args[0]
589
+ else:
590
+ device = args[0]
591
+ if 'dtype' in kwargs:
592
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
593
+ dtype = kwargs['dtype']
594
+ if 'device' in kwargs:
595
+ assert device is None, "to() received multiple values for argument 'device'"
596
+ device = kwargs['device']
597
+ non_blocking = kwargs.get('non_blocking', False)
598
+ copy = kwargs.get('copy', False)
599
+
600
+ new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
601
+ new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
602
+ return self.replace(new_feats, new_coords)
603
+
604
+ def type(self, dtype):
605
+ new_feats = self.feats.type(dtype)
606
+ return self.replace(new_feats)
607
+
608
+ def cpu(self) -> 'SparseTensor':
609
+ new_feats = self.feats.cpu()
610
+ new_coords = self.coords.cpu()
611
+ return self.replace(new_feats, new_coords)
612
+
613
+ def cuda(self) -> 'SparseTensor':
614
+ new_feats = self.feats.cuda()
615
+ new_coords = self.coords.cuda()
616
+ return self.replace(new_feats, new_coords)
617
+
618
+ def half(self) -> 'SparseTensor':
619
+ new_feats = self.feats.half()
620
+ return self.replace(new_feats)
621
+
622
+ def float(self) -> 'SparseTensor':
623
+ new_feats = self.feats.float()
624
+ return self.replace(new_feats)
625
+
626
+ def detach(self) -> 'SparseTensor':
627
+ new_coords = self.coords.detach()
628
+ new_feats = self.feats.detach()
629
+ return self.replace(new_feats, new_coords)
630
+
631
+ def reshape(self, *shape) -> 'SparseTensor':
632
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
633
+ return self.replace(new_feats)
634
+
635
+ def unbind(self, dim: int) -> List['SparseTensor']:
636
+ return sparse_unbind(self, dim)
637
+
638
+ def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
639
+ if config.CONV == 'torchsparse':
640
+ new_data = self.SparseTensorData(
641
+ feats=feats,
642
+ coords=self.data.coords if coords is None else coords,
643
+ stride=self.data.stride,
644
+ spatial_range=self.data.spatial_range,
645
+ )
646
+ new_data._caches = self.data._caches
647
+ elif config.CONV == 'spconv':
648
+ new_data = self.SparseTensorData(
649
+ self.data.features.reshape(self.data.features.shape[0], -1),
650
+ self.data.indices,
651
+ self.data.spatial_shape,
652
+ self.data.batch_size,
653
+ self.data.grid,
654
+ self.data.voxel_num,
655
+ self.data.indice_dict
656
+ )
657
+ new_data._features = feats
658
+ new_data.benchmark = self.data.benchmark
659
+ new_data.benchmark_record = self.data.benchmark_record
660
+ new_data.thrust_allocator = self.data.thrust_allocator
661
+ new_data._timer = self.data._timer
662
+ new_data.force_algo = self.data.force_algo
663
+ new_data.int8_scale = self.data.int8_scale
664
+ if coords is not None:
665
+ new_data.indices = coords
666
+ else:
667
+ new_data = {
668
+ 'feats': feats,
669
+ 'coords': self.data['coords'] if coords is None else coords,
670
+ }
671
+ new_tensor = SparseTensor(
672
+ new_data,
673
+ shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
674
+ scale=self._scale,
675
+ spatial_cache=self._spatial_cache
676
+ )
677
+ return new_tensor
678
+
679
+ def to_dense(self) -> torch.Tensor:
680
+ if config.CONV == 'torchsparse':
681
+ return self.data.dense()
682
+ elif config.CONV == 'spconv':
683
+ return self.data.dense()
684
+ else:
685
+ spatial_shape = self.spatial_shape
686
+ ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
687
+ idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1)
688
+ ret[tuple(idx)] = self.feats
689
+ return ret
690
+
691
+ @staticmethod
692
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
693
+ N, C = dim
694
+ x = torch.arange(aabb[0], aabb[3] + 1)
695
+ y = torch.arange(aabb[1], aabb[4] + 1)
696
+ z = torch.arange(aabb[2], aabb[5] + 1)
697
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
698
+ coords = torch.cat([
699
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
700
+ coords.repeat(N, 1),
701
+ ], dim=1).to(dtype=torch.int32, device=device)
702
+ feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
703
+ return SparseTensor(feats=feats, coords=coords)
704
+
705
+ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
706
+ new_cache = {}
707
+ for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
708
+ if k in self._spatial_cache:
709
+ new_cache[k] = self._spatial_cache[k]
710
+ if k in other._spatial_cache:
711
+ if k not in new_cache:
712
+ new_cache[k] = other._spatial_cache[k]
713
+ else:
714
+ new_cache[k].update(other._spatial_cache[k])
715
+ return new_cache
716
+
717
+ def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
718
+ if isinstance(other, torch.Tensor):
719
+ try:
720
+ other = torch.broadcast_to(other, self.shape)
721
+ other = other[self.batch_boardcast_map]
722
+ except:
723
+ pass
724
+ if isinstance(other, VarLenTensor):
725
+ other = other.feats
726
+ new_feats = op(self.feats, other)
727
+ new_tensor = self.replace(new_feats)
728
+ if isinstance(other, SparseTensor):
729
+ new_tensor._spatial_cache = self.__merge_sparse_cache(other)
730
+ return new_tensor
731
+
732
+ def __getitem__(self, idx):
733
+ if isinstance(idx, int):
734
+ idx = [idx]
735
+ elif isinstance(idx, slice):
736
+ idx = range(*idx.indices(self.shape[0]))
737
+ elif isinstance(idx, list):
738
+ assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
739
+ elif isinstance(idx, torch.Tensor):
740
+ if idx.dtype == torch.bool:
741
+ assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
742
+ idx = idx.nonzero().squeeze(1)
743
+ elif idx.dtype in [torch.int32, torch.int64]:
744
+ assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
745
+ else:
746
+ raise ValueError(f"Unknown index type: {idx.dtype}")
747
+ else:
748
+ raise ValueError(f"Unknown index type: {type(idx)}")
749
+
750
+ new_coords = []
751
+ new_feats = []
752
+ new_layout = []
753
+ new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
754
+ start = 0
755
+ for new_idx, old_idx in enumerate(idx):
756
+ new_coords.append(self.coords[self.layout[old_idx]].clone())
757
+ new_coords[-1][:, 0] = new_idx
758
+ new_feats.append(self.feats[self.layout[old_idx]])
759
+ new_layout.append(slice(start, start + len(new_coords[-1])))
760
+ start += len(new_coords[-1])
761
+ new_coords = torch.cat(new_coords, dim=0).contiguous()
762
+ new_feats = torch.cat(new_feats, dim=0).contiguous()
763
+ new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
764
+ new_tensor.register_spatial_cache('layout', new_layout)
765
+ return new_tensor
766
+
767
+ def clear_spatial_cache(self) -> None:
768
+ """
769
+ Clear all spatial caches.
770
+ """
771
+ self._spatial_cache = {}
772
+
773
+ def register_spatial_cache(self, key, value) -> None:
774
+ """
775
+ Register a spatial cache.
776
+ The spatial cache can be anything you want to cache.
777
+ Registration and retrieval of the cache are keyed by the current scale.
778
+ """
779
+ scale_key = str(self._scale)
780
+ if scale_key not in self._spatial_cache:
781
+ self._spatial_cache[scale_key] = {}
782
+ self._spatial_cache[scale_key][key] = value
783
+
784
+ def get_spatial_cache(self, key=None):
785
+ """
786
+ Get a spatial cache.
787
+ """
788
+ scale_key = str(self._scale)
789
+ cur_scale_cache = self._spatial_cache.get(scale_key, {})
790
+ if key is None:
791
+ return cur_scale_cache
792
+ return cur_scale_cache.get(key, None)
793
+
794
+ def __repr__(self) -> str:
795
+ return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
796
+
797
+ def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
798
+ """
799
+ Concatenate a list of sparse tensors.
800
+
801
+ Args:
802
+ inputs (List[SparseTensor]): List of sparse tensors to concatenate.
803
+ """
804
+ if dim == 0:
805
+ start = 0
806
+ coords = []
807
+ for input in inputs:
808
+ coords.append(input.coords.clone())
809
+ coords[-1][:, 0] += start
810
+ start += input.shape[0]
811
+ coords = torch.cat(coords, dim=0)
812
+ feats = torch.cat([input.feats for input in inputs], dim=0)
813
+ output = SparseTensor(
814
+ coords=coords,
815
+ feats=feats,
816
+ )
817
+ else:
818
+ feats = torch.cat([input.feats for input in inputs], dim=dim)
819
+ output = inputs[0].replace(feats)
820
+
821
+ return output
822
+
823
+
824
+ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
825
+ """
826
+ Unbind a sparse tensor along a dimension.
827
+
828
+ Args:
829
+ input (SparseTensor): Sparse tensor to unbind.
830
+ dim (int): Dimension to unbind.
831
+ """
832
+ if dim == 0:
833
+ return [input[i] for i in range(input.shape[0])]
834
+ else:
835
+ feats = input.feats.unbind(dim)
836
+ return [input.replace(f) for f in feats]
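A short sketch of the VarLenTensor/SparseTensor API above (illustrative only, not part of this commit; shapes and values are made up). `from_tensor_list` rewrites the batch column of each coords tensor, element-wise ops operate on `.feats` while preserving coords and layout, and indexing/concatenation work over whole batch items.

import torch
from trellis2.modules.sparse import SparseTensor, sparse_cat

# Two batch items with different numbers of active voxels, 16 channels each
feats = [torch.randn(100, 16), torch.randn(250, 16)]
coords = [
    torch.cat([torch.zeros(n, 1, dtype=torch.int32),
               torch.randint(0, 64, (n, 3), dtype=torch.int32)], dim=1)
    for n in (100, 250)
]

x = SparseTensor.from_tensor_list(feats, coords)   # batch column is rewritten per item
print(x.shape, x.seqlen.tolist())                  # torch.Size([2, 16]) [100, 250]

y = 2.0 * x + 1.0              # element-wise ops act on .feats, layout is kept
first = x[0]                   # indexing selects whole batch items
stacked = sparse_cat([x, x])   # batch dimension becomes 4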
trellis2/modules/sparse/config.py ADDED
@@ -0,0 +1,43 @@
1
+ from typing import *
2
+
3
+ CONV = 'flex_gemm'
4
+ DEBUG = False
5
+ ATTN = 'flash_attn'
6
+
7
+ def __from_env():
8
+ import os
9
+
10
+ global CONV
11
+ global DEBUG
12
+ global ATTN
13
+
14
+ env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND')
15
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
+ env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND')
17
+ if env_sparse_attn_backend is None:
18
+ env_sparse_attn_backend = os.environ.get('ATTN_BACKEND')
19
+
20
+ if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']:
21
+ CONV = env_sparse_conv_backend
22
+ if env_sparse_debug is not None:
23
+ DEBUG = env_sparse_debug == '1'
24
+ if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3']:
25
+ ATTN = env_sparse_attn_backend
26
+
27
+ print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}")
28
+
29
+
30
+ __from_env()
31
+
32
+
33
+ def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']):
34
+ global CONV
35
+ CONV = backend
36
+
37
+ def set_debug(debug: bool):
38
+ global DEBUG
39
+ DEBUG = debug
40
+
41
+ def set_attn_backend(backend: Literal['xformers', 'flash_attn']):
42
+ global ATTN
43
+ ATTN = backend
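Backend selection sketch (illustrative, not part of this commit): the module reads the environment once at import time, and the setters can switch backends afterwards. It assumes `config` is importable from the sparse package, as the `from .. import config` statements in the attention and conv modules suggest.

# At launch time:
#   SPARSE_CONV_BACKEND=spconv SPARSE_ATTN_BACKEND=xformers SPARSE_DEBUG=1 python app.py

# Or programmatically, before building any modules:
from trellis2.modules.sparse import config as sparse_config

sparse_config.set_conv_backend('spconv')     # 'none', 'spconv', 'torchsparse', 'flex_gemm'
sparse_config.set_attn_backend('xformers')   # 'xformers' or 'flash_attn'
sparse_config.set_debug(True)                # enables the extra shape/layout assertions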
trellis2/modules/sparse/conv/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .conv import SparseConv3d, SparseInverseConv3d
2
+ from . import config
trellis2/modules/sparse/conv/config.py ADDED
@@ -0,0 +1,3 @@
1
+ SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
2
+ FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk'
3
+ FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size
trellis2/modules/sparse/conv/conv.py ADDED
@@ -0,0 +1,30 @@
1
+ from .. import config
2
+ import importlib
3
+ import torch
4
+ import torch.nn as nn
5
+ from .. import SparseTensor
6
+
7
+
8
+ _backends = {}
9
+
10
+
11
+ class SparseConv3d(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
13
+ super(SparseConv3d, self).__init__()
14
+ if config.CONV not in _backends:
15
+ _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
16
+ _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)
17
+
18
+ def forward(self, x: SparseTensor) -> SparseTensor:
19
+ return _backends[config.CONV].sparse_conv3d_forward(self, x)
20
+
21
+
22
+ class SparseInverseConv3d(nn.Module):
23
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
24
+ super(SparseInverseConv3d, self).__init__()
25
+ if config.CONV not in _backends:
26
+ _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
27
+ _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key)
28
+
29
+ def forward(self, x: SparseTensor) -> SparseTensor:
30
+ return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x)
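Usage sketch (illustrative, not part of this commit): SparseConv3d resolves the configured backend lazily when a layer is first built, so a backend package only needs to be importable when it is actually selected. Under the default flex_gemm backend this is a submanifold convolution, so output coordinates equal input coordinates and the neighbor map is cached on the input tensor.

import torch
from trellis2.modules.sparse import SparseTensor
from trellis2.modules.sparse.conv import SparseConv3d

xyz = torch.unique(torch.randint(0, 64, (500, 3), dtype=torch.int32), dim=0)
coords = torch.cat([torch.zeros(xyz.shape[0], 1, dtype=torch.int32), xyz], dim=1).cuda()
x = SparseTensor(feats=torch.randn(coords.shape[0], 32, device='cuda'), coords=coords)

conv = SparseConv3d(32, 64, kernel_size=3).cuda()
y = conv(x)   # same coords, 64 output channels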
trellis2/modules/sparse/conv/conv_flex_gemm.py ADDED
@@ -0,0 +1,68 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import SparseTensor
5
+ from . import config
6
+ import flex_gemm
7
+ from flex_gemm.ops.spconv import sparse_submanifold_conv3d
8
+
9
+
10
+ def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
11
+ assert stride == 1 and (padding is None), 'The flex_gemm backend currently only supports submanifold sparse convolution (stride=1, padding=None)'
12
+
13
+ self.in_channels = in_channels
14
+ self.out_channels = out_channels
15
+ self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
16
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
17
+ self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3
18
+
19
+ self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
20
+ if bias:
21
+ self.bias = nn.Parameter(torch.empty(out_channels))
22
+ else:
23
+ self.register_parameter("bias", None)
24
+
25
+ # initialize parameters
26
+ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
27
+ if self.bias is not None:
28
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
29
+ if fan_in != 0:
30
+ bound = 1 / math.sqrt(fan_in)
31
+ torch.nn.init.uniform_(self.bias, -bound, bound)
32
+
33
+ # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
34
+ self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
35
+
36
+
37
+ def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
38
+ flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO)
39
+ flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO)
40
+
41
+ # check if neighbor map is already computed
42
+ Co, Kd, Kh, Kw, Ci = self.weight.shape
43
+ neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
44
+ neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
45
+
46
+ out, neighbor_cache_ = sparse_submanifold_conv3d(
47
+ x.feats,
48
+ x.coords,
49
+ torch.Size([*x.shape, *x.spatial_shape]),
50
+ self.weight,
51
+ self.bias,
52
+ neighbor_cache,
53
+ self.dilation
54
+ )
55
+
56
+ if neighbor_cache is None:
57
+ x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
58
+
59
+ out = x.replace(out)
60
+ return out
61
+
62
+
63
+ def sparse_inverse_conv3d_init(self, *args, **kwargs):
64
+ raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
65
+
66
+
67
+ def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
68
+ raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
trellis2/modules/sparse/conv/conv_spconv.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+ from . import config
5
+ import spconv.pytorch as spconv
6
+
7
+
8
+ def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
9
+ algo = None
10
+ if config.SPCONV_ALGO == 'native':
11
+ algo = spconv.ConvAlgo.Native
12
+ elif config.SPCONV_ALGO == 'implicit_gemm':
13
+ algo = spconv.ConvAlgo.MaskImplicitGemm
14
+ if stride == 1 and (padding is None):
15
+ self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
16
+ else:
17
+ self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
18
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
19
+ self.padding = padding
20
+
21
+
22
+ def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
23
+ spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
24
+ new_data = self.conv(x.data)
25
+ new_shape = [x.shape[0], self.conv.out_channels]
26
+ new_layout = None if spatial_changed else x.layout
27
+
28
+ if spatial_changed and (x.shape[0] != 1):
29
+ # spconv with non-1 stride breaks the per-batch contiguity of the output tensor, so sort by the coords
30
+ fwd = new_data.indices[:, 0].argsort()
31
+ bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
32
+ sorted_feats = new_data.features[fwd]
33
+ sorted_coords = new_data.indices[fwd]
34
+ unsorted_data = new_data
35
+ new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
36
+
37
+ out = SparseTensor(
38
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
39
+ scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
40
+ spatial_cache=x._spatial_cache,
41
+ )
42
+
43
+ if spatial_changed and (x.shape[0] != 1):
44
+ out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
45
+ out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
46
+
47
+ return out
48
+
49
+
50
+ def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
51
+ self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
52
+ self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
53
+
54
+
55
+ def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
56
+ spatial_changed = any(s != 1 for s in self.stride)
57
+ if spatial_changed:
58
+ # recover the original spconv order
59
+ data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
60
+ bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
61
+ data = data.replace_feature(x.feats[bwd])
62
+ else:
63
+ data = x.data
64
+
65
+ new_data = self.conv(data)
66
+ new_shape = [x.shape[0], self.conv.out_channels]
67
+ new_layout = None if spatial_changed else x.layout
68
+ out = SparseTensor(
69
+ new_data, shape=torch.Size(new_shape), layout=new_layout,
70
+ scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
71
+ spatial_cache=x._spatial_cache,
72
+ )
73
+ return out
trellis2/modules/sparse/conv/conv_torchsparse.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+ import torchsparse
5
+
6
+
7
+ def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
8
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
9
+
10
+
11
+ def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
12
+ out = self.conv(x.data)
13
+ new_shape = [x.shape[0], self.conv.out_channels]
14
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
15
+ out._spatial_cache = x._spatial_cache
16
+ out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
17
+ return out
18
+
19
+
20
+ def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
21
+ self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
22
+
23
+
24
+ def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
25
+ out = self.conv(x.data)
26
+ new_shape = [x.shape[0], self.conv.out_channels]
27
+ out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
28
+ out._spatial_cache = x._spatial_cache
29
+ out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)])
30
+ return out
trellis2/modules/sparse/linear.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import VarLenTensor
4
+
5
+ __all__ = [
6
+ 'SparseLinear'
7
+ ]
8
+
9
+
10
+ class SparseLinear(nn.Linear):
11
+ def __init__(self, in_features, out_features, bias=True):
12
+ super(SparseLinear, self).__init__(in_features, out_features, bias)
13
+
14
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
15
+ return input.replace(super().forward(input.feats))
trellis2/modules/sparse/nonlinearity.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import VarLenTensor
4
+
5
+ __all__ = [
6
+ 'SparseReLU',
7
+ 'SparseSiLU',
8
+ 'SparseGELU',
9
+ 'SparseActivation'
10
+ ]
11
+
12
+
13
+ class SparseReLU(nn.ReLU):
14
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
15
+ return input.replace(super().forward(input.feats))
16
+
17
+
18
+ class SparseSiLU(nn.SiLU):
19
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
20
+ return input.replace(super().forward(input.feats))
21
+
22
+
23
+ class SparseGELU(nn.GELU):
24
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
25
+ return input.replace(super().forward(input.feats))
26
+
27
+
28
+ class SparseActivation(nn.Module):
29
+ def __init__(self, activation: nn.Module):
30
+ super().__init__()
31
+ self.activation = activation
32
+
33
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
34
+ return input.replace(self.activation(input.feats))
35
+
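These wrappers simply apply the underlying nn layer to `.feats`, so they compose directly in an nn.Sequential over sparse tensors. A minimal sketch (illustrative, not part of this commit):

import torch
import torch.nn as nn
from trellis2.modules.sparse import SparseTensor
from trellis2.modules.sparse.linear import SparseLinear
from trellis2.modules.sparse.nonlinearity import SparseSiLU

mlp = nn.Sequential(SparseLinear(32, 128), SparseSiLU(), SparseLinear(128, 32))

coords = torch.cat([torch.zeros(400, 1, dtype=torch.int32),
                    torch.randint(0, 64, (400, 3), dtype=torch.int32)], dim=1)
x = SparseTensor(feats=torch.randn(400, 32), coords=coords)
y = mlp(x)   # each layer transforms .feats and leaves coords/layout untouched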
trellis2/modules/sparse/norm.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..utils import manual_cast
4
+ from . import VarLenTensor
5
+ from . import config
6
+
7
+ __all__ = [
8
+ 'SparseGroupNorm',
9
+ 'SparseLayerNorm',
10
+ 'SparseGroupNorm32',
11
+ 'SparseLayerNorm32',
12
+ ]
13
+
14
+
15
+ class SparseGroupNorm(nn.GroupNorm):
16
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
17
+ super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
18
+
19
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
20
+ nfeats = torch.zeros_like(input.feats)
21
+ for k in range(input.shape[0]):
22
+ bfeats = input.feats[input.layout[k]]
23
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
24
+ bfeats = super().forward(bfeats)
25
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
26
+ nfeats[input.layout[k]] = bfeats
27
+ return input.replace(nfeats)
28
+
29
+
30
+ class SparseLayerNorm(nn.LayerNorm):
31
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
32
+ super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
33
+
34
+ def forward(self, input: VarLenTensor) -> VarLenTensor:
35
+ nfeats = torch.zeros_like(input.feats)
36
+ for k in range(input.shape[0]):
37
+ bfeats = input.feats[input.layout[k]]
38
+ bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
39
+ bfeats = super().forward(bfeats)
40
+ bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
41
+ nfeats[input.layout[k]] = bfeats
42
+ return input.replace(nfeats)
43
+
44
+
45
+ class SparseGroupNorm32(SparseGroupNorm):
46
+ """
47
+ A GroupNorm layer that converts to float32 before the forward pass.
48
+ """
49
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
50
+ x_dtype = x.dtype
51
+ x = manual_cast(x, torch.float32)
52
+ o = super().forward(x)
53
+ return manual_cast(o, x_dtype)
54
+
55
+
56
+ class SparseLayerNorm32(SparseLayerNorm):
57
+ """
58
+ A LayerNorm layer that converts to float32 before the forward pass.
59
+ """
60
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
61
+ x_dtype = x.dtype
62
+ x = manual_cast(x, torch.float32)
63
+ o = super().forward(x)
64
+ return manual_cast(o, x_dtype)
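A minimal sketch of the float32 norm wrappers (illustrative, not part of this commit; `manual_cast` is assumed to behave as its use in the code above implies, casting a VarLenTensor's features to the requested dtype):

import torch
from trellis2.modules.sparse import SparseTensor
from trellis2.modules.sparse.norm import SparseGroupNorm32

coords = torch.cat([torch.zeros(300, 1, dtype=torch.int32),
                    torch.randint(0, 64, (300, 3), dtype=torch.int32)], dim=1)
x = SparseTensor(feats=torch.randn(300, 64, dtype=torch.float16), coords=coords)

norm = SparseGroupNorm32(num_groups=8, num_channels=64)
y = norm(x)   # statistics are computed in float32, the output is cast back to float16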
trellis2/modules/sparse/spatial/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .basic import *
2
+ from .spatial2channel import *
trellis2/modules/sparse/spatial/basic.py ADDED
@@ -0,0 +1,109 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import SparseTensor
5
+
6
+ __all__ = [
7
+ 'SparseDownsample',
8
+ 'SparseUpsample',
9
+ ]
10
+
11
+
12
+ class SparseDownsample(nn.Module):
13
+ """
14
+ Downsample a sparse tensor by a factor of `factor`.
15
+ Implemented as mean or max pooling, selected by `mode`.
16
+ """
17
+ def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'):
18
+ super(SparseDownsample, self).__init__()
19
+ self.factor = factor
20
+ self.mode = mode
21
+ assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'
22
+
23
+ def forward(self, x: SparseTensor) -> SparseTensor:
24
+ cache = x.get_spatial_cache(f'downsample_{self.factor}')
25
+ if cache is None:
26
+ DIM = x.coords.shape[-1] - 1
27
+
28
+ coord = list(x.coords.unbind(dim=-1))
29
+ for i in range(DIM):
30
+ coord[i+1] = coord[i+1] // self.factor
31
+
32
+ MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
33
+ OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
34
+ code = sum([c * o for c, o in zip(coord, OFFSET)])
35
+ code, idx = code.unique(return_inverse=True)
36
+
37
+ new_coords = torch.stack(
38
+ [code // OFFSET[0]] +
39
+ [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
40
+ dim=-1
41
+ )
42
+ else:
43
+ new_coords, idx = cache
44
+
45
+ new_feats = torch.scatter_reduce(
46
+ torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
47
+ dim=0,
48
+ index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
49
+ src=x.feats,
50
+ reduce=self.mode if self.mode != 'max' else 'amax',  # torch.scatter_reduce expects 'amax' for max pooling
51
+ include_self=False,
52
+ )
53
+ out = SparseTensor(new_feats, new_coords, x._shape)
54
+ out._scale = tuple([s * self.factor for s in x._scale])
55
+ out._spatial_cache = x._spatial_cache
56
+
57
+ if cache is None:
58
+ x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
59
+ out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
60
+ out.register_spatial_cache(f'shape', torch.Size(MAX))
61
+ if self.training:
62
+ subidx = x.coords[:, 1:] % self.factor
63
+ subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
64
+ subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
65
+ subdivision[idx, subidx] = True
66
+ out.register_spatial_cache(f'subdivision', subdivision)
67
+
68
+ return out
69
+
70
+
71
+ class SparseUpsample(nn.Module):
72
+ """
73
+ Upsample a sparse tensor by a factor of `factor`.
74
+ Implemented as nearest neighbor interpolation.
75
+ """
76
+ def __init__(
77
+ self, factor: int
78
+ ):
79
+ super(SparseUpsample, self).__init__()
80
+ self.factor = factor
81
+
82
+ def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
83
+ DIM = x.coords.shape[-1] - 1
84
+
85
+ cache = x.get_spatial_cache(f'upsample_{self.factor}')
86
+ if cache is None:
87
+ if subdivision is None:
88
+ raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.')
89
+ else:
90
+ sub = subdivision.feats
91
+ N_leaf = sub.sum(dim=-1)
92
+ subidx = sub.nonzero()[:, -1]
93
+ new_coords = x.coords.clone().detach()
94
+ new_coords[:, 1:] *= self.factor
95
+ new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
96
+ for i in range(DIM):
97
+ new_coords[:, i+1] += subidx // self.factor ** i % self.factor
98
+ idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
99
+ else:
100
+ new_coords, idx = cache
101
+
102
+ new_feats = x.feats[idx]
103
+ out = SparseTensor(new_feats, new_coords, x._shape)
104
+ out._scale = tuple([s / self.factor for s in x._scale])
105
+ if cache is not None: # only keep cache when subdiv following it
106
+ out._spatial_cache = x._spatial_cache
107
+
108
+ return out
109
+
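The key trick in `SparseDownsample.forward` is to flatten each voxel's coarse coordinate into a single integer code, deduplicate it with `unique(return_inverse=True)`, and pool features onto the surviving cells with `scatter_reduce`; the inverse mapping is cached so a later `SparseUpsample` can restore the original coordinates. A toy, dependency-free illustration of that pooling step (values chosen arbitrarily):

```python
import torch

factor = 2
coords = torch.tensor([[0, 0, 0], [0, 0, 1], [2, 3, 5]])   # (N, 3) voxel coordinates
feats = torch.randn(3, 4)

coarse = coords // factor                                   # target coarse cell per voxel
MAX = [2, 2, 3]                                             # coarse grid extents
OFFSET = [MAX[1] * MAX[2], MAX[2], 1]                       # row-major strides
code = sum(coarse[:, i] * OFFSET[i] for i in range(3))      # flattened cell id per voxel
code, idx = code.unique(return_inverse=True)                # one entry per occupied coarse cell

pooled = torch.zeros(code.shape[0], feats.shape[1]).scatter_reduce(
    0, idx.unsqueeze(1).expand(-1, feats.shape[1]), feats,
    reduce='mean', include_self=False,
)
print(pooled.shape)  # torch.Size([2, 4]): the first two voxels fell into the same cell
```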
trellis2/modules/sparse/spatial/spatial2channel.py ADDED
@@ -0,0 +1,93 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import SparseTensor
5
+
6
+
7
+ class SparseSpatial2Channel(nn.Module):
8
+ """
9
+ Downsample a sparse tensor by a factor of `factor`.
10
+ Implemented as rearranging its features from spatial to channel.
11
+ """
12
+ def __init__(self, factor: int = 2):
13
+ super(SparseSpatial2Channel, self).__init__()
14
+ self.factor = factor
15
+
16
+ def forward(self, x: SparseTensor) -> SparseTensor:
17
+ DIM = x.coords.shape[-1] - 1
18
+ cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
19
+ if cache is None:
20
+ coord = list(x.coords.unbind(dim=-1))
21
+ for i in range(DIM):
22
+ coord[i+1] = coord[i+1] // self.factor
23
+ subidx = x.coords[:, 1:] % self.factor
24
+ subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
25
+
26
+ MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
27
+ OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
28
+ code = sum([c * o for c, o in zip(coord, OFFSET)])
29
+ code, idx = code.unique(return_inverse=True)
30
+
31
+ new_coords = torch.stack(
32
+ [code // OFFSET[0]] +
33
+ [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
34
+ dim=-1
35
+ )
36
+ else:
37
+ new_coords, idx, subidx = cache
38
+
39
+ new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
40
+ new_feats[idx * self.factor ** DIM + subidx] = x.feats
41
+
42
+ out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
43
+ out._scale = tuple([s * self.factor for s in x._scale])
44
+ out._spatial_cache = x._spatial_cache
45
+
46
+ if cache is None:
47
+ x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
48
+ out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
49
+ out.register_spatial_cache(f'shape', torch.Size(MAX))
50
+ if self.training:
51
+ subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
52
+ subdivision[idx, subidx] = True
53
+ out.register_spatial_cache(f'subdivision', subdivision)
54
+
55
+ return out
56
+
57
+
58
+ class SparseChannel2Spatial(nn.Module):
59
+ """
60
+ Upsample a sparse tensor by a factor of `factor`.
61
+ Implemented as rearranging its features from channel to spatial.
62
+ """
63
+ def __init__(self, factor: int = 2):
64
+ super(SparseChannel2Spatial, self).__init__()
65
+ self.factor = factor
66
+
67
+ def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
68
+ DIM = x.coords.shape[-1] - 1
69
+
70
+ cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
71
+ if cache is None:
72
+ if subdivision is None:
73
+ raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
74
+ else:
75
+ sub = subdivision.feats # [N, self.factor ** DIM]
76
+ N_leaf = sub.sum(dim=-1) # [N]
77
+ subidx = sub.nonzero()[:, -1]
78
+ new_coords = x.coords.clone().detach()
79
+ new_coords[:, 1:] *= self.factor
80
+ new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
81
+ for i in range(DIM):
82
+ new_coords[:, i+1] += subidx // self.factor ** i % self.factor
83
+ idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
84
+ else:
85
+ new_coords, idx, subidx = cache
86
+
87
+ x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
88
+ new_feats = x_feats[idx * self.factor ** DIM + subidx]
89
+ out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
90
+ out._scale = tuple([s / self.factor for s in x._scale])
91
+ if cache is not None: # only keep cache when subdiv following it
92
+ out._spatial_cache = x._spatial_cache
93
+ return out
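SparseSpatial2Channel is the sparse analogue of a 3D pixel-unshuffle: for `factor=2` in 3D each coarse site packs up to `2**3 = 8` child voxels into the channel dimension (absent children stay zero), and SparseChannel2Spatial undoes it via the cached mapping or an explicit subdivision mask. A dependency-free sketch of the packing arithmetic used above (`idx * factor**DIM + subidx`), with made-up values:

```python
import torch

factor, DIM = 2, 3
# Two coarse cells; three voxels, the first two belonging to cell 0, the third to cell 1.
idx = torch.tensor([0, 0, 1])        # coarse cell index per voxel
subidx = torch.tensor([0, 3, 7])     # child slot within the 2x2x2 cell per voxel
feats = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)

packed = torch.zeros(2 * factor ** DIM, 4)      # (num_cells * 8, C) scratch layout
packed[idx * factor ** DIM + subidx] = feats    # place each voxel in its child slot
packed = packed.reshape(2, -1)                  # (num_cells, C * 8) channel-packed features
print(packed.shape)  # torch.Size([2, 32])
```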
trellis2/modules/sparse/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .blocks import *
2
+ from .modulated import *
trellis2/modules/sparse/transformer/blocks.py ADDED
@@ -0,0 +1,145 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import VarLenTensor, SparseTensor
5
+ from ..linear import SparseLinear
6
+ from ..nonlinearity import SparseGELU
7
+ from ..attention import SparseMultiHeadAttention
8
+ from ...norm import LayerNorm32
9
+
10
+
11
+ class SparseFeedForwardNet(nn.Module):
12
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
13
+ super().__init__()
14
+ self.mlp = nn.Sequential(
15
+ SparseLinear(channels, int(channels * mlp_ratio)),
16
+ SparseGELU(approximate="tanh"),
17
+ SparseLinear(int(channels * mlp_ratio), channels),
18
+ )
19
+
20
+ def forward(self, x: VarLenTensor) -> VarLenTensor:
21
+ return self.mlp(x)
22
+
23
+
24
+ class SparseTransformerBlock(nn.Module):
25
+ """
26
+ Sparse Transformer block (MSA + FFN).
27
+ """
28
+ def __init__(
29
+ self,
30
+ channels: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ attn_mode: Literal["full", "swin"] = "full",
34
+ window_size: Optional[int] = None,
35
+ shift_window: Optional[Tuple[int, int, int]] = None,
36
+ use_checkpoint: bool = False,
37
+ use_rope: bool = False,
38
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
39
+ qk_rms_norm: bool = False,
40
+ qkv_bias: bool = True,
41
+ ln_affine: bool = False,
42
+ ):
43
+ super().__init__()
44
+ self.use_checkpoint = use_checkpoint
45
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
46
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
47
+ self.attn = SparseMultiHeadAttention(
48
+ channels,
49
+ num_heads=num_heads,
50
+ attn_mode=attn_mode,
51
+ window_size=window_size,
52
+ shift_window=shift_window,
53
+ qkv_bias=qkv_bias,
54
+ use_rope=use_rope,
55
+ rope_freq=rope_freq,
56
+ qk_rms_norm=qk_rms_norm,
57
+ )
58
+ self.mlp = SparseFeedForwardNet(
59
+ channels,
60
+ mlp_ratio=mlp_ratio,
61
+ )
62
+
63
+ def _forward(self, x: SparseTensor) -> SparseTensor:
64
+ h = x.replace(self.norm1(x.feats))
65
+ h = self.attn(h)
66
+ x = x + h
67
+ h = x.replace(self.norm2(x.feats))
68
+ h = self.mlp(h)
69
+ x = x + h
70
+ return x
71
+
72
+ def forward(self, x: SparseTensor) -> SparseTensor:
73
+ if self.use_checkpoint:
74
+ return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
75
+ else:
76
+ return self._forward(x)
77
+
78
+
79
+ class SparseTransformerCrossBlock(nn.Module):
80
+ """
81
+ Sparse Transformer cross-attention block (MSA + MCA + FFN).
82
+ """
83
+ def __init__(
84
+ self,
85
+ channels: int,
86
+ ctx_channels: int,
87
+ num_heads: int,
88
+ mlp_ratio: float = 4.0,
89
+ attn_mode: Literal["full", "swin"] = "full",
90
+ window_size: Optional[int] = None,
91
+ shift_window: Optional[Tuple[int, int, int]] = None,
92
+ use_checkpoint: bool = False,
93
+ use_rope: bool = False,
94
+ qk_rms_norm: bool = False,
95
+ qk_rms_norm_cross: bool = False,
96
+ qkv_bias: bool = True,
97
+ ln_affine: bool = False,
98
+ ):
99
+ super().__init__()
100
+ self.use_checkpoint = use_checkpoint
101
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
102
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
103
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
104
+ self.self_attn = SparseMultiHeadAttention(
105
+ channels,
106
+ num_heads=num_heads,
107
+ type="self",
108
+ attn_mode=attn_mode,
109
+ window_size=window_size,
110
+ shift_window=shift_window,
111
+ qkv_bias=qkv_bias,
112
+ use_rope=use_rope,
113
+ qk_rms_norm=qk_rms_norm,
114
+ )
115
+ self.cross_attn = SparseMultiHeadAttention(
116
+ channels,
117
+ ctx_channels=ctx_channels,
118
+ num_heads=num_heads,
119
+ type="cross",
120
+ attn_mode="full",
121
+ qkv_bias=qkv_bias,
122
+ qk_rms_norm=qk_rms_norm_cross,
123
+ )
124
+ self.mlp = SparseFeedForwardNet(
125
+ channels,
126
+ mlp_ratio=mlp_ratio,
127
+ )
128
+
129
+ def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
130
+ h = x.replace(self.norm1(x.feats))
131
+ h = self.self_attn(h)
132
+ x = x + h
133
+ h = x.replace(self.norm2(x.feats))
134
+ h = self.cross_attn(h, context)
135
+ x = x + h
136
+ h = x.replace(self.norm3(x.feats))
137
+ h = self.mlp(h)
138
+ x = x + h
139
+ return x
140
+
141
+ def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
142
+ if self.use_checkpoint:
143
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
144
+ else:
145
+ return self._forward(x, context)
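Since a SparseTensor keeps all voxel features in one flat `(N, C)` matrix, the blocks above apply LayerNorm directly to `x.feats` and wrap the result back via `x.replace(...)`; the residual adds are plain element-wise ops on those features. For reference, the same pre-norm MSA + FFN structure in dense form, using only built-in PyTorch modules:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Dense analogue of SparseTransformerBlock: pre-norm attention + FFN with residuals."""
    def __init__(self, channels: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(channels, int(channels * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(channels * mlp_ratio), channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # self-attention branch
        x = x + self.mlp(self.norm2(x))                     # feed-forward branch
        return x

print(PreNormBlock(64, 8)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```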
trellis2/modules/sparse/transformer/modulated.py ADDED
@@ -0,0 +1,166 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import VarLenTensor, SparseTensor
5
+ from ..attention import SparseMultiHeadAttention
6
+ from ...norm import LayerNorm32
7
+ from .blocks import SparseFeedForwardNet
8
+
9
+
10
+ class ModulatedSparseTransformerBlock(nn.Module):
11
+ """
12
+ Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13
+ """
14
+ def __init__(
15
+ self,
16
+ channels: int,
17
+ num_heads: int,
18
+ mlp_ratio: float = 4.0,
19
+ attn_mode: Literal["full", "swin"] = "full",
20
+ window_size: Optional[int] = None,
21
+ shift_window: Optional[Tuple[int, int, int]] = None,
22
+ use_checkpoint: bool = False,
23
+ use_rope: bool = False,
24
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
25
+ qk_rms_norm: bool = False,
26
+ qkv_bias: bool = True,
27
+ share_mod: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.use_checkpoint = use_checkpoint
31
+ self.share_mod = share_mod
32
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
33
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
34
+ self.attn = SparseMultiHeadAttention(
35
+ channels,
36
+ num_heads=num_heads,
37
+ attn_mode=attn_mode,
38
+ window_size=window_size,
39
+ shift_window=shift_window,
40
+ qkv_bias=qkv_bias,
41
+ use_rope=use_rope,
42
+ rope_freq=rope_freq,
43
+ qk_rms_norm=qk_rms_norm,
44
+ )
45
+ self.mlp = SparseFeedForwardNet(
46
+ channels,
47
+ mlp_ratio=mlp_ratio,
48
+ )
49
+ if not share_mod:
50
+ self.adaLN_modulation = nn.Sequential(
51
+ nn.SiLU(),
52
+ nn.Linear(channels, 6 * channels, bias=True)
53
+ )
54
+ else:
55
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
56
+
57
+ def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
58
+ if self.share_mod:
59
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
60
+ else:
61
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
62
+ h = x.replace(self.norm1(x.feats))
63
+ h = h * (1 + scale_msa) + shift_msa
64
+ h = self.attn(h)
65
+ h = h * gate_msa
66
+ x = x + h
67
+ h = x.replace(self.norm2(x.feats))
68
+ h = h * (1 + scale_mlp) + shift_mlp
69
+ h = self.mlp(h)
70
+ h = h * gate_mlp
71
+ x = x + h
72
+ return x
73
+
74
+ def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
75
+ if self.use_checkpoint:
76
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
77
+ else:
78
+ return self._forward(x, mod)
79
+
80
+
81
+ class ModulatedSparseTransformerCrossBlock(nn.Module):
82
+ """
83
+ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
84
+ """
85
+ def __init__(
86
+ self,
87
+ channels: int,
88
+ ctx_channels: int,
89
+ num_heads: int,
90
+ mlp_ratio: float = 4.0,
91
+ attn_mode: Literal["full", "swin"] = "full",
92
+ window_size: Optional[int] = None,
93
+ shift_window: Optional[Tuple[int, int, int]] = None,
94
+ use_checkpoint: bool = False,
95
+ use_rope: bool = False,
96
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
97
+ qk_rms_norm: bool = False,
98
+ qk_rms_norm_cross: bool = False,
99
+ qkv_bias: bool = True,
100
+ share_mod: bool = False,
101
+
102
+ ):
103
+ super().__init__()
104
+ self.use_checkpoint = use_checkpoint
105
+ self.share_mod = share_mod
106
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
107
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
108
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
109
+ self.self_attn = SparseMultiHeadAttention(
110
+ channels,
111
+ num_heads=num_heads,
112
+ type="self",
113
+ attn_mode=attn_mode,
114
+ window_size=window_size,
115
+ shift_window=shift_window,
116
+ qkv_bias=qkv_bias,
117
+ use_rope=use_rope,
118
+ rope_freq=rope_freq,
119
+ qk_rms_norm=qk_rms_norm,
120
+ )
121
+ self.cross_attn = SparseMultiHeadAttention(
122
+ channels,
123
+ ctx_channels=ctx_channels,
124
+ num_heads=num_heads,
125
+ type="cross",
126
+ attn_mode="full",
127
+ qkv_bias=qkv_bias,
128
+ qk_rms_norm=qk_rms_norm_cross,
129
+ )
130
+ self.mlp = SparseFeedForwardNet(
131
+ channels,
132
+ mlp_ratio=mlp_ratio,
133
+ )
134
+ if not share_mod:
135
+ self.adaLN_modulation = nn.Sequential(
136
+ nn.SiLU(),
137
+ nn.Linear(channels, 6 * channels, bias=True)
138
+ )
139
+ else:
140
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
141
+
142
+ def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
143
+ if self.share_mod:
144
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
145
+ else:
146
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
147
+ h = x.replace(self.norm1(x.feats))
148
+ h = h * (1 + scale_msa) + shift_msa
149
+ h = self.self_attn(h)
150
+ h = h * gate_msa
151
+ x = x + h
152
+ h = x.replace(self.norm2(x.feats))
153
+ h = self.cross_attn(h, context)
154
+ x = x + h
155
+ h = x.replace(self.norm3(x.feats))
156
+ h = h * (1 + scale_mlp) + shift_mlp
157
+ h = self.mlp(h)
158
+ h = h * gate_mlp
159
+ x = x + h
160
+ return x
161
+
162
+ def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
163
+ if self.use_checkpoint:
164
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
165
+ else:
166
+ return self._forward(x, mod, context)
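The adaLN path in both modulated blocks projects the conditioning vector to six per-channel vectors: shift, scale and gate for the attention branch, and the same trio for the MLP branch (DiT-style). A dense toy sketch of that arithmetic, using nothing beyond plain PyTorch (the attention and MLP calls are stubbed out):

```python
import torch
import torch.nn as nn

channels = 8
cond = torch.randn(2, channels)                          # conditioning vector per sample
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels))

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(cond).chunk(6, dim=1)

x = torch.randn(2, channels)                             # one token per sample, for brevity
norm = nn.LayerNorm(channels, elementwise_affine=False)

h = norm(x) * (1 + scale_msa) + shift_msa                # modulate before the attention branch
attn_out = h                                             # stand-in for self-attention
x = x + gate_msa * attn_out                              # gated residual update

h = norm(x) * (1 + scale_mlp) + shift_mlp                # modulate before the MLP branch
x = x + gate_mlp * h                                     # MLP likewise stubbed out
print(x.shape)  # torch.Size([2, 8])
```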
trellis2/modules/spatial.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+
4
+ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
5
+ """
6
+ 3D pixel shuffle.
7
+ """
8
+ B, C, H, W, D = x.shape
9
+ C_ = C // scale_factor**3
10
+ x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
11
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
12
+ x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
13
+ return x
14
+
15
+
16
+ def patchify(x: torch.Tensor, patch_size: int):
17
+ """
18
+ Patchify a tensor.
19
+
20
+ Args:
21
+ x (torch.Tensor): (N, C, *spatial) tensor
22
+ patch_size (int): Patch size
23
+ """
24
+ DIM = x.dim() - 2
25
+ for d in range(2, DIM + 2):
26
+ assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
27
+
28
+ x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
29
+ x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
30
+ x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
31
+ return x
32
+
33
+
34
+ def unpatchify(x: torch.Tensor, patch_size: int):
35
+ """
36
+ Unpatchify a tensor.
37
+
38
+ Args:
39
+ x (torch.Tensor): (N, C, *spatial) tensor
40
+ patch_size (int): Patch size
41
+ """
42
+ DIM = x.dim() - 2
43
+ assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
44
+
45
+ x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
46
+ x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
47
+ x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
48
+ return x
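`patchify` folds every `patch_size`-sized spatial block into the channel dimension (channels grow by `patch_size ** DIM`) and `unpatchify` reverses it exactly. A quick round-trip check, assuming the repo is importable as `trellis2`:

```python
import torch
from trellis2.modules.spatial import patchify, unpatchify

x = torch.randn(1, 4, 8, 8, 8)        # (N, C, D, H, W)
p = patchify(x, patch_size=2)         # -> (1, 32, 4, 4, 4): 4 * 2**3 channels
y = unpatchify(p, patch_size=2)       # -> (1, 4, 8, 8, 8)
assert torch.equal(x, y)              # the two ops are exact inverses
print(p.shape)
```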
trellis2/modules/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .blocks import *
2
+ from .modulated import *
trellis2/modules/transformer/blocks.py ADDED
@@ -0,0 +1,186 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+
7
+
8
+ class AbsolutePositionEmbedder(nn.Module):
9
+ """
10
+ Embeds spatial positions into vector representations.
11
+ """
12
+ def __init__(self, channels: int, in_channels: int = 3):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.in_channels = in_channels
16
+ self.freq_dim = channels // in_channels // 2
17
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
18
+ self.freqs = 1.0 / (10000 ** self.freqs)
19
+
20
+ def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
21
+ """
22
+ Create sinusoidal position embeddings.
23
+
24
+ Args:
25
+ x: a 1-D Tensor of N indices
26
+
27
+ Returns:
28
+ an (N, D) Tensor of positional embeddings.
29
+ """
30
+ self.freqs = self.freqs.to(x.device)
31
+ out = torch.outer(x, self.freqs)
32
+ out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
33
+ return out
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ """
37
+ Args:
38
+ x (torch.Tensor): (N, D) tensor of spatial positions
39
+ """
40
+ N, D = x.shape
41
+ assert D == self.in_channels, "Input dimension must match number of input channels"
42
+ embed = self._sin_cos_embedding(x.reshape(-1))
43
+ embed = embed.reshape(N, -1)
44
+ if embed.shape[1] < self.channels:
45
+ embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
46
+ return embed
47
+
48
+
49
+ class FeedForwardNet(nn.Module):
50
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
51
+ super().__init__()
52
+ self.mlp = nn.Sequential(
53
+ nn.Linear(channels, int(channels * mlp_ratio)),
54
+ nn.GELU(approximate="tanh"),
55
+ nn.Linear(int(channels * mlp_ratio), channels),
56
+ )
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return self.mlp(x)
60
+
61
+
62
+ class TransformerBlock(nn.Module):
63
+ """
64
+ Transformer block (MSA + FFN).
65
+ """
66
+ def __init__(
67
+ self,
68
+ channels: int,
69
+ num_heads: int,
70
+ mlp_ratio: float = 4.0,
71
+ attn_mode: Literal["full", "windowed"] = "full",
72
+ window_size: Optional[int] = None,
73
+ shift_window: Optional[int] = None,
74
+ use_checkpoint: bool = False,
75
+ use_rope: bool = False,
76
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
77
+ qk_rms_norm: bool = False,
78
+ qkv_bias: bool = True,
79
+ ln_affine: bool = True,
80
+ ):
81
+ super().__init__()
82
+ self.use_checkpoint = use_checkpoint
83
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
84
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
85
+ self.attn = MultiHeadAttention(
86
+ channels,
87
+ num_heads=num_heads,
88
+ attn_mode=attn_mode,
89
+ window_size=window_size,
90
+ shift_window=shift_window,
91
+ qkv_bias=qkv_bias,
92
+ use_rope=use_rope,
93
+ rope_freq=rope_freq,
94
+ qk_rms_norm=qk_rms_norm,
95
+ )
96
+ self.mlp = FeedForwardNet(
97
+ channels,
98
+ mlp_ratio=mlp_ratio,
99
+ )
100
+
101
+ def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
102
+ h = self.norm1(x)
103
+ h = self.attn(h, phases=phases)
104
+ x = x + h
105
+ h = self.norm2(x)
106
+ h = self.mlp(h)
107
+ x = x + h
108
+ return x
109
+
110
+ def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
111
+ if self.use_checkpoint:
112
+ return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False)
113
+ else:
114
+ return self._forward(x, phases)
115
+
116
+
117
+ class TransformerCrossBlock(nn.Module):
118
+ """
119
+ Transformer cross-attention block (MSA + MCA + FFN).
120
+ """
121
+ def __init__(
122
+ self,
123
+ channels: int,
124
+ ctx_channels: int,
125
+ num_heads: int,
126
+ mlp_ratio: float = 4.0,
127
+ attn_mode: Literal["full", "windowed"] = "full",
128
+ window_size: Optional[int] = None,
129
+ shift_window: Optional[Tuple[int, int, int]] = None,
130
+ use_checkpoint: bool = False,
131
+ use_rope: bool = False,
132
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
133
+ qk_rms_norm: bool = False,
134
+ qk_rms_norm_cross: bool = False,
135
+ qkv_bias: bool = True,
136
+ ln_affine: bool = False,
137
+ ):
138
+ super().__init__()
139
+ self.use_checkpoint = use_checkpoint
140
+ self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
141
+ self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
142
+ self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
143
+ self.self_attn = MultiHeadAttention(
144
+ channels,
145
+ num_heads=num_heads,
146
+ type="self",
147
+ attn_mode=attn_mode,
148
+ window_size=window_size,
149
+ shift_window=shift_window,
150
+ qkv_bias=qkv_bias,
151
+ use_rope=use_rope,
152
+ rope_freq=rope_freq,
153
+ qk_rms_norm=qk_rms_norm,
154
+ )
155
+ self.cross_attn = MultiHeadAttention(
156
+ channels,
157
+ ctx_channels=ctx_channels,
158
+ num_heads=num_heads,
159
+ type="cross",
160
+ attn_mode="full",
161
+ qkv_bias=qkv_bias,
162
+ qk_rms_norm=qk_rms_norm_cross,
163
+ )
164
+ self.mlp = FeedForwardNet(
165
+ channels,
166
+ mlp_ratio=mlp_ratio,
167
+ )
168
+
169
+ def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
170
+ h = self.norm1(x)
171
+ h = self.self_attn(h, phases=phases)
172
+ x = x + h
173
+ h = self.norm2(x)
174
+ h = self.cross_attn(h, context)
175
+ x = x + h
176
+ h = self.norm3(x)
177
+ h = self.mlp(h)
178
+ x = x + h
179
+ return x
180
+
181
+ def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
182
+ if self.use_checkpoint:
183
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False)
184
+ else:
185
+ return self._forward(x, context, phases)
186
+
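`AbsolutePositionEmbedder` splits the channel budget evenly across the input axes (sin/cos per axis) and zero-pads whatever is left over, so `channels` does not have to be an exact multiple of `2 * in_channels`. A small shape check, assuming the import path below resolves and the attention submodule imports cleanly on your setup:

```python
import torch
from trellis2.modules.transformer.blocks import AbsolutePositionEmbedder

embedder = AbsolutePositionEmbedder(channels=1024, in_channels=3)
coords = torch.randint(0, 64, (5, 3)).float()   # five (x, y, z) positions
emb = embedder(coords)
print(emb.shape)  # torch.Size([5, 1024]); the last 4 dims are zero padding here
```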
trellis2/modules/transformer/modulated.py ADDED
@@ -0,0 +1,165 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+ from .blocks import FeedForwardNet
7
+
8
+
9
+ class ModulatedTransformerBlock(nn.Module):
10
+ """
11
+ Transformer block (MSA + FFN) with adaptive layer norm conditioning.
12
+ """
13
+ def __init__(
14
+ self,
15
+ channels: int,
16
+ num_heads: int,
17
+ mlp_ratio: float = 4.0,
18
+ attn_mode: Literal["full", "windowed"] = "full",
19
+ window_size: Optional[int] = None,
20
+ shift_window: Optional[Tuple[int, int, int]] = None,
21
+ use_checkpoint: bool = False,
22
+ use_rope: bool = False,
23
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
24
+ qk_rms_norm: bool = False,
25
+ qkv_bias: bool = True,
26
+ share_mod: bool = False,
27
+ ):
28
+ super().__init__()
29
+ self.use_checkpoint = use_checkpoint
30
+ self.share_mod = share_mod
31
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
32
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
33
+ self.attn = MultiHeadAttention(
34
+ channels,
35
+ num_heads=num_heads,
36
+ attn_mode=attn_mode,
37
+ window_size=window_size,
38
+ shift_window=shift_window,
39
+ qkv_bias=qkv_bias,
40
+ use_rope=use_rope,
41
+ rope_freq=rope_freq,
42
+ qk_rms_norm=qk_rms_norm,
43
+ )
44
+ self.mlp = FeedForwardNet(
45
+ channels,
46
+ mlp_ratio=mlp_ratio,
47
+ )
48
+ if not share_mod:
49
+ self.adaLN_modulation = nn.Sequential(
50
+ nn.SiLU(),
51
+ nn.Linear(channels, 6 * channels, bias=True)
52
+ )
53
+ else:
54
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
55
+
56
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
57
+ if self.share_mod:
58
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
59
+ else:
60
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
61
+ h = self.norm1(x)
62
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
63
+ h = self.attn(h, phases=phases)
64
+ h = h * gate_msa.unsqueeze(1)
65
+ x = x + h
66
+ h = self.norm2(x)
67
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
68
+ h = self.mlp(h)
69
+ h = h * gate_mlp.unsqueeze(1)
70
+ x = x + h
71
+ return x
72
+
73
+ def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
74
+ if self.use_checkpoint:
75
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False)
76
+ else:
77
+ return self._forward(x, mod, phases)
78
+
79
+
80
+ class ModulatedTransformerCrossBlock(nn.Module):
81
+ """
82
+ Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
83
+ """
84
+ def __init__(
85
+ self,
86
+ channels: int,
87
+ ctx_channels: int,
88
+ num_heads: int,
89
+ mlp_ratio: float = 4.0,
90
+ attn_mode: Literal["full", "windowed"] = "full",
91
+ window_size: Optional[int] = None,
92
+ shift_window: Optional[Tuple[int, int, int]] = None,
93
+ use_checkpoint: bool = False,
94
+ use_rope: bool = False,
95
+ rope_freq: Tuple[float, float] = (1.0, 10000.0),
96
+ qk_rms_norm: bool = False,
97
+ qk_rms_norm_cross: bool = False,
98
+ qkv_bias: bool = True,
99
+ share_mod: bool = False,
100
+ ):
101
+ super().__init__()
102
+ self.use_checkpoint = use_checkpoint
103
+ self.share_mod = share_mod
104
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
105
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
106
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
107
+ self.self_attn = MultiHeadAttention(
108
+ channels,
109
+ num_heads=num_heads,
110
+ type="self",
111
+ attn_mode=attn_mode,
112
+ window_size=window_size,
113
+ shift_window=shift_window,
114
+ qkv_bias=qkv_bias,
115
+ use_rope=use_rope,
116
+ rope_freq=rope_freq,
117
+ qk_rms_norm=qk_rms_norm,
118
+ )
119
+ self.cross_attn = MultiHeadAttention(
120
+ channels,
121
+ ctx_channels=ctx_channels,
122
+ num_heads=num_heads,
123
+ type="cross",
124
+ attn_mode="full",
125
+ qkv_bias=qkv_bias,
126
+ qk_rms_norm=qk_rms_norm_cross,
127
+ )
128
+ self.mlp = FeedForwardNet(
129
+ channels,
130
+ mlp_ratio=mlp_ratio,
131
+ )
132
+ if not share_mod:
133
+ self.adaLN_modulation = nn.Sequential(
134
+ nn.SiLU(),
135
+ nn.Linear(channels, 6 * channels, bias=True)
136
+ )
137
+ else:
138
+ self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
139
+
140
+ def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
141
+ if self.share_mod:
142
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
143
+ else:
144
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
145
+ h = self.norm1(x)
146
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
147
+ h = self.self_attn(h, phases=phases)
148
+ h = h * gate_msa.unsqueeze(1)
149
+ x = x + h
150
+ h = self.norm2(x)
151
+ h = self.cross_attn(h, context)
152
+ x = x + h
153
+ h = self.norm3(x)
154
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
155
+ h = self.mlp(h)
156
+ h = h * gate_mlp.unsqueeze(1)
157
+ x = x + h
158
+ return x
159
+
160
+ def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
161
+ if self.use_checkpoint:
162
+ return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
163
+ else:
164
+ return self._forward(x, mod, context, phases)
165
+
trellis2/modules/utils.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..modules import sparse as sp
4
+
5
+ MIX_PRECISION_MODULES = (
6
+ nn.Conv1d,
7
+ nn.Conv2d,
8
+ nn.Conv3d,
9
+ nn.ConvTranspose1d,
10
+ nn.ConvTranspose2d,
11
+ nn.ConvTranspose3d,
12
+ nn.Linear,
13
+ sp.SparseConv3d,
14
+ sp.SparseInverseConv3d,
15
+ sp.SparseLinear,
16
+ )
17
+
18
+
19
+ def convert_module_to_f16(l):
20
+ """
21
+ Convert primitive modules to float16.
22
+ """
23
+ if isinstance(l, MIX_PRECISION_MODULES):
24
+ for p in l.parameters():
25
+ p.data = p.data.half()
26
+
27
+
28
+ def convert_module_to_f32(l):
29
+ """
30
+ Convert primitive modules to float32, undoing convert_module_to_f16().
31
+ """
32
+ if isinstance(l, MIX_PRECISION_MODULES):
33
+ for p in l.parameters():
34
+ p.data = p.data.float()
35
+
36
+
37
+ def convert_module_to(l, dtype):
38
+ """
39
+ Convert primitive modules to the given dtype.
40
+ """
41
+ if isinstance(l, MIX_PRECISION_MODULES):
42
+ for p in l.parameters():
43
+ p.data = p.data.to(dtype)
44
+
45
+
46
+ def zero_module(module):
47
+ """
48
+ Zero out the parameters of a module and return it.
49
+ """
50
+ for p in module.parameters():
51
+ p.detach().zero_()
52
+ return module
53
+
54
+
55
+ def scale_module(module, scale):
56
+ """
57
+ Scale the parameters of a module and return it.
58
+ """
59
+ for p in module.parameters():
60
+ p.detach().mul_(scale)
61
+ return module
62
+
63
+
64
+ def modulate(x, shift, scale):
65
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
66
+
67
+
68
+ def manual_cast(tensor, dtype):
69
+ """
70
+ Cast if autocast is not enabled.
71
+ """
72
+ if not torch.is_autocast_enabled():
73
+ return tensor.type(dtype)
74
+ return tensor
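`zero_module` is the standard trick for zero-initializing the last projection of a conditioning branch (for example the adaLN `nn.Linear` layers earlier in this diff) so that every block starts out as an identity mapping, and `modulate` applies the resulting shift/scale per token. A self-contained sketch of that interplay:

```python
import torch
import torch.nn as nn

def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module

proj = zero_module(nn.Linear(16, 2 * 16))                  # emits shift and scale, both zero at init
cond = torch.randn(4, 16)
shift, scale = proj(cond).chunk(2, dim=1)

x = torch.randn(4, 10, 16)                                 # (batch, tokens, channels)
out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)    # same arithmetic as modulate()
assert torch.allclose(out, x)                              # identity until training moves proj
```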
trellis2/pipelines/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ import importlib
2
+
3
+ __attributes = {
4
+ "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
5
+ "Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade",
6
+ "Trellis2ImageToTexturePipeline": "trellis2_image_to_tex",
7
+ }
8
+
9
+ __submodules = ['samplers', 'rembg']
10
+
11
+ __all__ = list(__attributes.keys()) + __submodules
12
+
13
+ def __getattr__(name):
14
+ if name not in globals():
15
+ if name in __attributes:
16
+ module_name = __attributes[name]
17
+ module = importlib.import_module(f".{module_name}", __name__)
18
+ globals()[name] = getattr(module, name)
19
+ elif name in __submodules:
20
+ module = importlib.import_module(f".{name}", __name__)
21
+ globals()[name] = module
22
+ else:
23
+ raise AttributeError(f"module {__name__} has no attribute {name}")
24
+ return globals()[name]
25
+
26
+
27
+ def from_pretrained(path: str):
28
+ """
29
+ Load a pipeline from a model folder or a Hugging Face model hub.
30
+
31
+ Args:
32
+ path: The path to the model. Can be either local path or a Hugging Face model name.
33
+ """
34
+ import os
35
+ import json
36
+ is_local = os.path.exists(f"{path}/pipeline.json")
37
+
38
+ if is_local:
39
+ config_file = f"{path}/pipeline.json"
40
+ else:
41
+ from huggingface_hub import hf_hub_download
42
+ config_file = hf_hub_download(path, "pipeline.json")
43
+
44
+ with open(config_file, 'r') as f:
45
+ config = json.load(f)
46
+ return globals()[config['name']].from_pretrained(path)
47
+
48
+
49
+ # For PyLance
50
+ if __name__ == '__main__':
51
+ from . import samplers, rembg
52
+ from .trellis_image_to_3d import TrellisImageTo3DPipeline
53
+ from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
54
+ from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline
55
+ from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline
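The lazy `__getattr__` keeps `import trellis2.pipelines` cheap, while `from_pretrained` reads `pipeline.json` (locally or from the Hugging Face Hub) and dispatches to the pipeline class named there. A usage sketch; the path below is a placeholder, not a published model id:

```python
from trellis2 import pipelines

# Either a local folder containing pipeline.json, or a Hugging Face repo id.
pipe = pipelines.from_pretrained("path/to/model-or-hf-repo")
pipe.cuda()
```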
trellis2/pipelines/base.py ADDED
@@ -0,0 +1,70 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from .. import models
5
+
6
+
7
+ class Pipeline:
8
+ """
9
+ A base class for pipelines.
10
+ """
11
+ def __init__(
12
+ self,
13
+ models: Optional[Dict[str, nn.Module]] = None,
14
+ ):
15
+ if models is None:
16
+ return
17
+ self.models = models
18
+ for model in self.models.values():
19
+ model.eval()
20
+
21
+ @staticmethod
22
+ def from_pretrained(path: str) -> "Pipeline":
23
+ """
24
+ Load a pretrained model.
25
+ """
26
+ import os
27
+ import json
28
+ is_local = os.path.exists(f"{path}/pipeline.json")
29
+
30
+ if is_local:
31
+ config_file = f"{path}/pipeline.json"
32
+ else:
33
+ from huggingface_hub import hf_hub_download
34
+ config_file = hf_hub_download(path, "pipeline.json")
35
+
36
+ with open(config_file, 'r') as f:
37
+ args = json.load(f)['args']
38
+
39
+ _models = {}
40
+ for k, v in args['models'].items():
41
+ try:
42
+ _models[k] = models.from_pretrained(f"{path}/{v}")
43
+ except Exception as e:
44
+ _models[k] = models.from_pretrained(v)
45
+
46
+ new_pipeline = Pipeline(_models)
47
+ new_pipeline._pretrained_args = args
48
+ return new_pipeline
49
+
50
+ @property
51
+ def device(self) -> torch.device:
52
+ if hasattr(self, '_device'):
53
+ return self._device
54
+ for model in self.models.values():
55
+ if hasattr(model, 'device'):
56
+ return model.device
57
+ for model in self.models.values():
58
+ if hasattr(model, 'parameters'):
59
+ return next(model.parameters()).device
60
+ raise RuntimeError("No device found.")
61
+
62
+ def to(self, device: torch.device) -> None:
63
+ for model in self.models.values():
64
+ model.to(device)
65
+
66
+ def cuda(self) -> None:
67
+ self.to(torch.device("cuda"))
68
+
69
+ def cpu(self) -> None:
70
+ self.to(torch.device("cpu"))
trellis2/pipelines/rembg/BiRefNet.py ADDED
@@ -0,0 +1,42 @@
1
+ from typing import *
2
+ from transformers import AutoModelForImageSegmentation
3
+ import torch
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+
7
+
8
+ class BiRefNet:
9
+ def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"):
10
+ self.model = AutoModelForImageSegmentation.from_pretrained(
11
+ model_name, trust_remote_code=True
12
+ )
13
+ self.model.eval()
14
+ self.transform_image = transforms.Compose(
15
+ [
16
+ transforms.Resize((1024, 1024)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
19
+ ]
20
+ )
21
+
22
+ def to(self, device: str):
23
+ self.model.to(device)
24
+
25
+ def cuda(self):
26
+ self.model.cuda()
27
+
28
+ def cpu(self):
29
+ self.model.cpu()
30
+
31
+ def __call__(self, image: Image.Image) -> Image.Image:
32
+ image_size = image.size
33
+ input_images = self.transform_image(image).unsqueeze(0).to(next(self.model.parameters()).device)  # follow the model's device instead of hard-coding CUDA
34
+ # Prediction
35
+ with torch.no_grad():
36
+ preds = self.model(input_images)[-1].sigmoid().cpu()
37
+ pred = preds[0].squeeze()
38
+ pred_pil = transforms.ToPILImage()(pred)
39
+ mask = pred_pil.resize(image_size)
40
+ image.putalpha(mask)
41
+ return image
42
+
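A rough usage sketch of the wrapper above (the file paths are placeholders; the first call downloads the `ZhengPeng7/BiRefNet` weights, and the 1024x1024 preprocessing expects an RGB input):

```python
from PIL import Image
from trellis2.pipelines.rembg import BiRefNet

rembg = BiRefNet()
rembg.cuda()                                    # run segmentation on the GPU
image = Image.open("input.png").convert("RGB")
rgba = rembg(image)                             # same image with a predicted alpha channel
rgba.save("output.png")
```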
trellis2/pipelines/rembg/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .BiRefNet import *