kencwt commited on
Commit
15d68f9
·
1 Parent(s): 83dc22d

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -210
inference.py DELETED
@@ -1,210 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Motif-Video 2B — Text-to-Video inference.
3
-
4
- GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
5
- Requires: torch, diffusers (with MotifVideoPipeline), transformers>=5.5.4,
6
- accelerate, ftfy, einops, sentencepiece, regex
7
-
8
- Uses Adaptive Projected Guidance (APG) and DPMSolver++ scheduler by default.
9
- """
10
-
11
- import argparse
12
-
13
- import torch
14
- from diffusers import (
15
- AdaptiveProjectedGuidance,
16
- DPMSolverMultistepScheduler,
17
- MotifVideoPipeline,
18
- )
19
- from diffusers.utils import export_to_video
20
-
21
- _DEFAULT_NEGATIVE_PROMPT = (
22
- "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, "
23
- "broadcast graphics, UI elements, random letters, frozen pose, rigid, "
24
- "static expression, jerky motion, mechanical motion, discontinuous motion, "
25
- "flat framing, depthless, dull lighting, monotone, crushed shadows, "
26
- "blown-out highlights, shifting background, fading background, poor continuity, "
27
- "identity drift, deformation, flickering, ghosting, smearing, duplication, "
28
- "mutated proportions, inconsistent clothing, flat colors, desaturated, "
29
- "tonally compressed, poor background separation, exposure shift, "
30
- "uneven brightness, color balance shift"
31
- )
32
-
33
-
34
- def parse_args():
35
- parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V)")
36
- parser.add_argument(
37
- "--model-path",
38
- type=str,
39
- default="Motif-Technologies/Motif-Video-2B",
40
- help="HuggingFace model ID or local checkpoint path",
41
- )
42
- parser.add_argument(
43
- "--prompt",
44
- type=str,
45
- default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
46
- help="Text prompt for video generation",
47
- )
48
- parser.add_argument(
49
- "--negative-prompt",
50
- type=str,
51
- default=_DEFAULT_NEGATIVE_PROMPT,
52
- help="Negative prompt",
53
- )
54
- parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
55
- parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
56
- parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
57
- parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
58
- parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
59
- parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
60
- parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
61
- parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
62
- parser.add_argument(
63
- "--dtype",
64
- type=str,
65
- default="bfloat16",
66
- choices=["float16", "bfloat16", "float32"],
67
- help="Model dtype",
68
- )
69
- parser.add_argument(
70
- "--use-sage-attention",
71
- action="store_true",
72
- help="Enable SageAttention for ~2x faster attention (requires: pip install sageattention>=2.1.1 from GitHub source)",
73
- )
74
- return parser.parse_args()
75
-
76
-
77
- def _enable_sage_attention(transformer):
78
- """Patch transformer attention to use SageAttention.
79
-
80
- Only patches _compute_attention (self-attention path). Cross-attention
81
- uses _handle_cross_attention_mode which calls F.sdpa directly and is
82
- unaffected by this patch.
83
-
84
- Mask handling follows motif-models dispatch_optimized_attention pattern:
85
- - mask=None: sage directly
86
- - mask with uniform active length: slice active region -> sage -> pad back
87
- - mask with non-uniform active length: SDPA fallback
88
- """
89
- from sageattention import sageattn
90
- from diffusers.models.transformers.transformer_motif_video import MotifVideoAttnProcessor2_0
91
-
92
- _orig_compute = MotifVideoAttnProcessor2_0._compute_attention
93
-
94
- def _sage_compute(self, query, key, value, attention_mask):
95
- if attention_mask is None:
96
- out = sageattn(
97
- query.contiguous(), key.contiguous(), value.contiguous(),
98
- tensor_layout="HND", is_causal=False,
99
- )
100
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
101
- return out
102
-
103
- # Find active token count from mask (shape: [B, 1, 1, S])
104
- padding_indices = attention_mask.sum(dim=-1).long().flatten()
105
- common_padding_index = padding_indices[0]
106
- is_uniform = (padding_indices == common_padding_index).all()
107
-
108
- if not is_uniform:
109
- return _orig_compute(self, query, key, value, attention_mask)
110
-
111
- active_len = common_padding_index.item()
112
- S = query.shape[2]
113
-
114
- if active_len == S:
115
- out = sageattn(
116
- query.contiguous(), key.contiguous(), value.contiguous(),
117
- tensor_layout="HND", is_causal=False,
118
- )
119
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
120
- return out
121
-
122
- # Slice to active region, run sage, pad back
123
- q_a = query[:, :, :active_len, :].contiguous()
124
- k_a = key[:, :, :active_len, :].contiguous()
125
- v_a = value[:, :, :active_len, :].contiguous()
126
-
127
- out_a = sageattn(q_a, k_a, v_a, tensor_layout="HND", is_causal=False)
128
-
129
- out = query.new_zeros(query.shape)
130
- out[:, :, :active_len, :] = out_a
131
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
132
- return out
133
-
134
- MotifVideoAttnProcessor2_0._compute_attention = _sage_compute
135
- transformer.to(memory_format=torch.channels_last_3d)
136
- print("[SageAttention] Enabled (patched _compute_attention + channels_last_3d)")
137
-
138
-
139
- def main():
140
- args = parse_args()
141
-
142
- dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
143
- torch_dtype = dtype_map[args.dtype]
144
-
145
- print(f"[T2V] Loading model from: {args.model_path}")
146
-
147
- guider = AdaptiveProjectedGuidance(
148
- guidance_scale=args.guidance_scale,
149
- adaptive_projected_guidance_rescale=12.0,
150
- adaptive_projected_guidance_momentum=0.1,
151
- use_original_formulation=True,
152
- normalization_dims="spatial",
153
- )
154
-
155
- pipe = MotifVideoPipeline.from_pretrained(
156
- args.model_path,
157
- torch_dtype=torch_dtype,
158
- guider=guider,
159
- )
160
-
161
- # Replace scheduler with DPMSolver++ for faster convergence and better quality.
162
- # Subclass ignores pipeline-supplied sigmas (PR branch always passes them)
163
- # and uses its own flow-matching sigma schedule instead.
164
- class _FlowDPMSolver(DPMSolverMultistepScheduler):
165
- def set_timesteps(self, num_inference_steps=None, device=None,
166
- sigmas=None, mu=None, timesteps=None):
167
- if sigmas is not None and num_inference_steps is None:
168
- num_inference_steps = len(sigmas)
169
- super().set_timesteps(
170
- num_inference_steps=num_inference_steps,
171
- device=device, timesteps=timesteps,
172
- )
173
-
174
- pipe.scheduler = _FlowDPMSolver(
175
- num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
176
- algorithm_type="dpmsolver++",
177
- solver_order=2,
178
- prediction_type="flow_prediction",
179
- use_flow_sigmas=True,
180
- flow_shift=15.0,
181
- )
182
-
183
- # Offload model components to CPU between uses to reduce peak VRAM
184
- pipe.enable_model_cpu_offload()
185
-
186
- if args.use_sage_attention:
187
- _enable_sage_attention(pipe.transformer)
188
-
189
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
190
-
191
- print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
192
- output = pipe(
193
- prompt=args.prompt,
194
- negative_prompt=args.negative_prompt,
195
- height=args.height,
196
- width=args.width,
197
- num_frames=args.num_frames,
198
- num_inference_steps=args.num_inference_steps,
199
- frame_rate=args.fps,
200
- use_linear_quadratic_schedule=False,
201
- generator=generator,
202
- )
203
-
204
- video_frames = output.frames[0]
205
- export_to_video(video_frames, args.output, fps=args.fps)
206
- print(f"Video saved to: {args.output}")
207
-
208
-
209
- if __name__ == "__main__":
210
- main()