# Scrape residue from hosting page (kept for provenance): Spaces: Sleeping
import torch.nn as nn

#################################################################################
#                               Length Estimator                                #
#################################################################################
class LengthEstimator(nn.Module):
    """MLP head that maps a text embedding to length logits.

    The network is a plain feed-forward stack: three hidden stages of
    Linear -> LayerNorm -> LeakyReLU(0.2) with widths 512 -> 256 -> 128
    (dropout 0.2 after the first two stages), followed by a final Linear
    projection to ``output_size``. All Linear weights are re-initialized
    from N(0, 0.02) with zero bias; LayerNorm is reset to weight=1, bias=0.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        width = 512
        stages = [
            nn.Linear(input_size, width),
            nn.LayerNorm(width),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(width, width // 2),
            nn.LayerNorm(width // 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(width // 2, width // 4),
            nn.LayerNorm(width // 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(width // 4, output_size),
        ]
        self.output = nn.Sequential(*stages)
        # Overwrite PyTorch's default initialization with the scheme above.
        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        # The branches are disjoint: a module is never both a LayerNorm
        # and a Linear/Embedding, so branch order does not matter.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, text_emb):
        """Return raw (unnormalized) length logits for ``text_emb``."""
        return self.output(text_emb)