# ACMDM_Motion_Generation/models/ACMDM_ControlNet.py
import torch
import torch.nn as nn
from models.ACMDM import ACMDM
from models.ACMDM import TimestepEmbedder, ACMDMTransBlock, LlamaRMSNorm
from models.ROPE import RopeND
from utils.eval_utils import eval_decorator
from utils.train_utils import lengths_to_mask

#################################################################################
#                               ACMDM+ControlNet                                #
#################################################################################
class ACMDM_ControlNet(ACMDM):
def __init__(self, input_dim, cond_mode, base_checkpoint, latent_dim=256, ff_size=1024, num_layers=8,
num_heads=4, dropout=0.2, clip_dim=512,
diff_model='Flow', cond_drop_prob=0.1, max_length=49,
patch_size=(1, 22), stride_size=(1, 22),
clip_version='ViT-B/32', freeze_base=True, need_base=True, **kargs):
# --------------------------------------------------------------------------
# ACMDM
super().__init__(input_dim, cond_mode, latent_dim=latent_dim, ff_size=ff_size, num_layers=num_layers,
num_heads=num_heads, dropout=dropout, clip_dim=clip_dim,
diff_model=diff_model, cond_drop_prob=cond_drop_prob, max_length=max_length,
patch_size=patch_size, stride_size=stride_size,
clip_version=clip_version, **kargs)
# --------------------------------------------------------------------------
# ControlNet
self.c_t_embedder = TimestepEmbedder(self.latent_dim)
self.c_control_embedder = c_control_embedder(3, self.latent_dim, patch_size=self.patch_size,
stride_size=self.stride_size)
self.c_x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size,
stride=self.stride_size, bias=True)
self.c_y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
self.c_rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
self.ControlNet = nn.ModuleList([
ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.c_rope, qk_norm=True) for _ in
range(num_layers)
])
self.zero_Linear = nn.ModuleList([
nn.Linear(self.latent_dim, self.latent_dim) for _ in range(num_layers)
])
self.initialize_weights_control()
        if need_base:
            # initialize the ControlNet branch from the pretrained base transformer by
            # duplicating every 'ACMDMTransformer.*' weight under a 'ControlNet.*' key
            for key, value in list(base_checkpoint['ema_acmdm'].items()):
                if key.startswith('ACMDMTransformer.'):
                    new_key = key.replace('ACMDMTransformer.', 'ControlNet.')
                    base_checkpoint['ema_acmdm'][new_key] = value.clone()
missing_keys, unexpected_keys = self.load_state_dict(base_checkpoint['ema_acmdm'], strict=False)
assert len(unexpected_keys) == 0
        if self.cond_mode == 'text':
            # reload CLIP: the Xavier re-initialization in initialize_weights_control
            # also re-initializes the CLIP submodule's Linear layers
            print('ReLoading CLIP...')
self.clip_version = clip_version
self.clip_model = self.load_and_freeze_clip(clip_version)
        if freeze_base:
            # freeze the pretrained base ACMDM so only the ControlNet branch is trainable
for param in self.t_embedder.parameters():
param.requires_grad = False
for param in self.x_embedder.parameters():
param.requires_grad = False
for param in self.y_embedder.parameters():
param.requires_grad = False
for param in self.final_layer.parameters():
param.requires_grad = False
for param in self.ACMDMTransformer.parameters():
param.requires_grad = False

    def initialize_weights_control(self):
        # Xavier-initialize every nn.Linear in the model:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
        # Initialize the base model's timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in the base transformer blocks:
for block in self.ACMDMTransformer:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out the base model's output layer:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
        # Initialize the ControlNet timestep embedding MLP:
nn.init.normal_(self.c_t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.c_t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in the ControlNet blocks:
for block in self.ControlNet:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out the control embedder projection and the per-block output linears so the
        # ControlNet branch initially adds nothing to the base model:
        nn.init.constant_(self.c_control_embedder.zero_linear.weight, 0)
nn.init.constant_(self.c_control_embedder.zero_linear.bias, 0)
for block in self.zero_Linear:
nn.init.constant_(block.weight, 0)
nn.init.constant_(block.bias, 0)

    def forward_with_control(self, x, t, conds, attention_mask, cfg1=1.0, cfg2=1.0, control=None, index=None,
force_mask=False):
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            # classifier-free guidance: the batch holds three copies of the same samples,
            # conditioned as (text + control), (no text + control), (text + no control)
            half = x[: len(x) // 3]
            x = torch.cat([half, half, half], dim=0)
        # ControlNet branch
c_t = self.c_t_embedder(t, dtype=x.dtype)
conds = self.mask_cond(conds, force_mask=force_mask)
c_control = self.c_control_embedder(control * index)
if self.training and self.cond_drop_prob > 0.:
mask = torch.bernoulli(torch.ones(c_control.shape[0], device=c_control.device) * self.cond_drop_prob).view(c_control.shape[0], 1, 1)
c_control = c_control * (1. - mask)
if not (cfg1 == 1.0 and cfg2 == 1.0):
c_control = torch.cat([c_control, c_control, torch.zeros_like(c_control)], dim=0)
c_x = self.c_x_embedder(x).flatten(2).transpose(1, 2)
c_y = self.c_y_embedder(conds)
c_y = c_t.unsqueeze(1) + c_y.unsqueeze(1)
c_x = c_x + c_control
c_position_ids = self.position_ids_precompute[:, :c_x.shape[1]]
c_out = []
for c_block, c_linear in zip(self.ControlNet, self.zero_Linear):
c_x = c_block(c_x, c_y, attention_mask, position_ids=c_position_ids)
c_out.append(c_linear(c_x))
        # main (base ACMDM) branch
tt = self.t_embedder(t, dtype=x.dtype)
x = self.x_embedder(x)
x = x.flatten(2).transpose(1, 2)
conds = self.y_embedder(conds)
y = tt.unsqueeze(1) + conds.unsqueeze(1)
position_ids = self.position_ids_precompute[:, :x.shape[1]]
        # merge: add the ControlNet features to the base branch after each transformer block
for block, c in zip(self.ACMDMTransformer, c_out):
x = block(x, y, attention_mask, position_ids=position_ids)
x = x + c
x = self.final_layer(x, y)
if not (cfg1 == 1.0 and cfg2 == 1.0):
cond_eps, uncond_eps1, uncond_eps2 = torch.split(x, len(x) // 3, dim=0)
half_eps = cond_eps + (cfg1-1) * (cond_eps - uncond_eps1) + (cfg2-1) * (cond_eps - uncond_eps2)
x = torch.cat([half_eps, half_eps, half_eps], dim=0)
return x

    def forward_control_loss(self, latents, y, m_lens, original, index, ae, mean_std):
        latents = latents.permute(0, 2, 3, 1)  # (B, d, frames, joints) -> (B, frames, joints, d)
b, l, j, d = latents.shape
device = latents.device
non_pad_mask = lengths_to_mask(m_lens, l)
latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))
target = latents.clone().permute(0, 3, 1, 2).detach()
original = original.clone().detach()
force_mask = False
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(y)
elif self.cond_mode == 'action':
cond_vector = self.enc_action(y).to(device).float()
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
force_mask = True
else:
raise NotImplementedError("Unsupported condition mode!!!")
attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
        # randomly pick which controlled joint(s) each sample uses
        random_indices = torch.randint(0, len(index), (b,)).to(device)
        indexx = torch.tensor(index, device=device)[random_indices]
        # build a sparse spatio-temporal control mask; the raw control signal runs at
        # 4x the temporal resolution of the latents
        mask_seq = torch.zeros((b, 3, l*4, j), device=device)
        for i in range(b):
            seq_num = torch.randint(1, m_lens[i]*4, (1,))
            choose_seq = torch.sort(torch.randperm(m_lens[i]*4)[:seq_num.item()]).values
            mask_seq[i, :, choose_seq, indexx[i]] = 1.0
model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, control=original, index=mask_seq,
force_mask=force_mask, mean_std=mean_std)
if self.diff_model == "Flow":
loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, ae=ae,
model_kwargs=model_kwargs)
else:
t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, t, model_kwargs)
loss = loss_dict["loss"]
loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()
return loss, loss_dict["loss_control"]

    @torch.no_grad()
@eval_decorator
def generate_control(self,
conds,
m_lens,
control,
index,
density,
cond_scale,
temperature=1,
j=22
):
        # control: raw xyz control signal of shape (B, 3, 4*frames, joints);
        # cond_scale: (cfg1, cfg2) = (text guidance scale, control guidance scale)
        device = next(self.parameters()).device
        l = control.shape[2]//4  # the control signal runs at 4x the latent frame rate
b = len(m_lens)
if self.cond_mode == 'text':
with torch.no_grad():
cond_vector = self.encode_text(conds)
elif self.cond_mode == 'action':
cond_vector = self.enc_action(conds).to(device)
elif self.cond_mode == 'uncond':
cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
else:
raise NotImplementedError("Unsupported condition mode!!!")
padding_mask = ~lengths_to_mask(m_lens, l)
noise = torch.randn(b, self.input_dim, l, j).to(device)
control = control.clone()
cfg1 = cond_scale[0]
cfg2 = cond_scale[1]
if not (cfg1 == 1.0 and cfg2 == 1.0):
# (1) with text and with control (2) no text and with control (3) with text and no control
cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector), cond_vector], dim=0)
random_indices = torch.tensor(0, device=device).repeat(b) # no random in inference
indexx = torch.tensor(index, device=device)[random_indices]
        mask_seq = torch.zeros((b, 3, l * 4, j), device=device)
        for i in range(b):
            # a density of 1, 2, or 5 is an absolute number of control keyframes;
            # otherwise density is interpreted as a percentage of the (4x) control frame count
            if density in [1, 2, 5]:
                seq_num = density
            else:
                seq_num = int(m_lens[i] * 4 * density / 100)
            choose_seq = torch.sort(torch.randperm(m_lens[i] * 4)[:seq_num]).values
            mask_seq[i, :, choose_seq, indexx[i]] = 1.0
attention_mask = (~padding_mask).unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, cfg1=cfg1, cfg2=cfg2, index=mask_seq,
control=control)
sample_fn = self.forward_with_control
if not (cfg1 == 1.0 and cfg2 == 1.0):
model_kwargs["attention_mask"] = attention_mask.repeat(3, 1, 1, 1)
noise = torch.cat([noise, noise, noise], dim=0)
if self.diff_model == "Flow":
model_fn = self.gen_diffusion.sample_ode() # default to ode sampling
sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
else:
sampled_token_latent = self.gen_diffusion.p_sample_loop(
sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
progress=False,
temperature=temperature
)
if not (cfg1 == 1.0 and cfg2 == 1.0):
sampled_token_latent, _, _ = sampled_token_latent.chunk(3, dim=0)
sampled_token_latent = sampled_token_latent.permute(0, 2, 3, 1)
latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent),
sampled_token_latent)
return latents.permute(0, 3, 1, 2), mask_seq

#################################################################################
#                                  ACMDM Zoos                                   #
#################################################################################
def acmdm_raw_flow_s_ps22_control(**kwargs):
layer = 8
return ACMDM_ControlNet(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
diff_model="Flow", cond_drop_prob=0.1, max_length=49,
patch_size=(1, 22), stride_size=(1, 22), freeze_base=True, **kwargs)

ACMDM_ControlNet_Models = {
'ACMDM-Flow-S-PatchSize22-ControlNet': acmdm_raw_flow_s_ps22_control,
}
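
# Illustrative usage sketch (not part of the original file): one way a model from the
# zoo above might be constructed. The checkpoint path, input_dim, and cond_mode values
# are placeholders and depend on the actual training setup.
#
#   base_ckpt = torch.load('path/to/base_acmdm_checkpoint.tar', map_location='cpu')
#   model = ACMDM_ControlNet_Models['ACMDM-Flow-S-PatchSize22-ControlNet'](
#       input_dim=4, cond_mode='text', base_checkpoint=base_ckpt)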

#################################################################################
#                              Inner Architectures                              #
#################################################################################
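# DiT/ControlNet-style helpers; note that modulate and zero_module are not referenced
# elsewhere in this file.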
def modulate(x, shift, scale):
return x * (1 + scale) + shift

def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module

class c_control_embedder(nn.Module):
def __init__(
self,
in_features: int,
hidden_features,
patch_size,
stride_size,
) -> None:
super().__init__()
        # a temporal kernel/stride of 4 maps the 4x-rate control signal onto the latent frame rate
        self.patch_embed = nn.Conv2d(in_features, hidden_features, kernel_size=(4, patch_size[1]), stride=(4, stride_size[1]), bias=True)
self.norm = LlamaRMSNorm(hidden_features, eps=1e-6)
self.zero_linear = nn.Linear(hidden_features, hidden_features)

    def forward(self, x):
x = self.patch_embed(x).flatten(2).transpose(1, 2)
x = self.norm(x)
x = self.zero_linear(x)
return x
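

if __name__ == "__main__":
    # Minimal shape sanity check for c_control_embedder (an illustrative sketch only; the
    # batch size, frame count, and joint count below are arbitrary example values).
    embedder = c_control_embedder(3, 512, patch_size=(1, 22), stride_size=(1, 22))
    dummy_control = torch.randn(2, 3, 4 * 49, 22)  # (B, xyz, 4*frames, joints)
    tokens = embedder(dummy_control)
    print(tokens.shape)  # torch.Size([2, 49, 512])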