# Source: ACMDM_Motion_Generation/models/LengthEstimator.py
# Uploaded by sourxbhh — commit 0f34fb9 ("Add model directory")
import torch.nn as nn
#################################################################################
# Length Estimator #
#################################################################################
class LengthEstimator(nn.Module):
    """MLP head that maps an input embedding to ``output_size`` raw scores.

    Architecture (layer order unchanged so existing checkpoints still load):
    three ``Linear -> LayerNorm -> LeakyReLU(0.2)`` stages of widths
    ``hidden_dim``, ``hidden_dim // 2`` and ``hidden_dim // 4``, with Dropout
    after the first two stages, followed by a final ``Linear`` projection with
    no activation, so the outputs are unnormalized scores (logits).

    Args:
        input_size: Size of the last dimension of the input tensor.
        output_size: Number of output scores. NOTE(review): given the class
            name this is presumably the number of length bins — confirm
            against the caller.
        hidden_dim: Width of the first hidden layer. Defaults to 512, the
            value that was previously hard-coded. Must be >= 4 so that every
            hidden layer (down to ``hidden_dim // 4``) is non-empty.
        dropout: Dropout probability applied after the first two hidden
            stages. Defaults to 0.2, the value that was previously hard-coded.

    Raises:
        ValueError: If ``hidden_dim < 4``.
    """

    def __init__(self, input_size, output_size, hidden_dim=512, dropout=0.2):
        super().__init__()
        if hidden_dim < 4:
            raise ValueError(
                f"hidden_dim must be >= 4 so that hidden_dim // 4 > 0, got {hidden_dim}"
            )
        nd = hidden_dim
        # Keep the module order fixed: state_dict keys are "output.<idx>.<param>",
        # so inserting/reordering layers would break loading old checkpoints.
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(nd // 4, output_size),
        )
        # Override PyTorch's default initialisation for every submodule.
        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        """Initialise a single submodule; invoked on each module via ``apply``.

        * ``Linear`` / ``Embedding`` weights are drawn from N(0, 0.02).
        * ``Linear`` biases are set to zero.
        * ``LayerNorm`` weight is set to 1 and bias to 0 (identity affine).

        Parameters that are ``None`` (e.g. ``Linear(bias=False)`` or a
        ``LayerNorm`` created without an elementwise affine) are skipped
        rather than raising ``AttributeError``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, text_emb):
        """Return scores of shape ``(*text_emb.shape[:-1], output_size)``.

        ``text_emb`` may have any number of leading dimensions; only its last
        dimension must equal ``input_size``. NOTE(review): the name suggests
        a text-encoder embedding — confirm which encoder feeds this head.
        """
        return self.output(text_emb)