Spaces:
Runtime error
Runtime error
| import logging | |
| import flax.linen as nn | |
| logger = logging.getLogger(__name__) | |
| class Decoder(nn.Module): | |
| ''' | |
| Converts latent code -> transformer encoding. | |
| ''' | |
| dim_model: int | |
| n_latent_tokens: int | |
| def __call__(self, latent_code): # (batch, latent_tokens_per_sequence, latent_token_dim) | |
| raw_latent_tokens = nn.Dense(self.dim_model)(latent_code) | |
| latent_tokens = nn.LayerNorm()(raw_latent_tokens) | |
| return latent_tokens # (batch, latent_tokens_per_sequence, dim_model) | |
| VAE_DECODER_MODELS = { | |
| '': Decoder, | |
| } | |