Spaces:
Runtime error
Runtime error
| import logging | |
| import jax.numpy as jnp | |
| import flax.linen as nn | |
| logger = logging.getLogger(__name__) | |
class Encoder(nn.Module):
    '''
    Converts N hidden tokens into N separate latent codes.

    Attributes:
        latent_token_size: width of each output latent token (the number of
            features of the Dense projection).
        n_latent_tokens: number of leading sequence positions kept as latent
            tokens; positions at or beyond this index are discarded.
    '''
    latent_token_size: int
    n_latent_tokens: int

    # @nn.compact is required: the Dense submodule is created inline in this
    # method, which Flax Linen only allows inside setup() or a @compact method.
    @nn.compact
    def __call__(self, encoding):
        '''
        Project the first `n_latent_tokens` positions of `encoding` into
        tanh-bounded latent codes.

        Args:
            encoding: array indexed as (batch, sequence, hidden); the slice
                below requires rank 3. Presumably the encoder's hidden states
                -- TODO confirm against caller.

        Returns:
            Array of shape (batch, min(n_latent_tokens, sequence),
            latent_token_size) whose elements lie in [-1, 1].
        '''
        # Keep only the leading latent positions *before* projecting. Dense acts
        # independently on each position's last axis, so slicing first gives the
        # same values and the same parameter shapes while skipping the matmul
        # for every discarded position.
        raw_tokens = encoding[:, : self.n_latent_tokens, :]
        raw_latent_code = nn.Dense(self.latent_token_size)(raw_tokens)
        # jnp.tanh is element-wise: each scalar is squashed independently; it
        # does not normalise per token or across the batch.
        latent_code = jnp.tanh(raw_latent_code)
        return latent_code  # (batch, latent_tokens_per_sequence, latent_token_dim)
# Lookup table from encoder variant name to its Flax module class; '' is the
# only registered name and maps to Encoder.
VAE_ENCODER_MODELS = {'': Encoder}