| from typing import Sequence |
| from functools import partial |
| import jax |
| import jax.numpy as jnp |
| import flax.linen as nn |
|
|
|
|
class DQNNet(nn.Module):
    """Q-network: three ReLU conv layers, optional ReLU dense layers, linear output head.

    Attributes:
        features: layer widths. ``features[0]``, ``features[1]`` and ``features[2]``
            are the output channels of the 8x8/stride-4, 4x4/stride-2 and
            3x3/stride-1 convolutions; every further entry adds a ReLU dense layer
            of that width before the output layer.
        n_actions: number of outputs of the final dense layer (one Q-value per action).
    """

    features: Sequence[int]
    n_actions: int

    @nn.compact
    def __call__(self, x):
        """Compute Q-values for a single state or a batch of states.

        Args:
            x: observation(s); values are divided by 255.0 before the first conv.
                Inputs with fewer than 4 dimensions are treated as one unbatched
                state (a leading batch axis is added via ``ndmin=4``).
                NOTE(review): presumably channels-last ``(H, W, C)`` /
                ``(B, H, W, C)`` uint8 frames — confirm against callers.

        Returns:
            ``(n_actions,)`` for an unbatched input, ``(B, n_actions)`` for a
            batched input (including ``B == 1``).
        """
        initializer = nn.initializers.xavier_uniform()

        # Record this before padding so we later drop only the batch axis we added ourselves.
        unbatched = jnp.ndim(x) < 4
        x = jnp.array(x, ndmin=4) / 255.0

        x = nn.relu(
            nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(x)
        )
        x = nn.relu(
            nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        )
        x = nn.relu(
            nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        )
        x = x.reshape((x.shape[0], -1))

        # Extra entries in ``features`` beyond the three conv widths become hidden dense layers.
        for idx_layer in range(3, len(self.features)):
            x = nn.relu(nn.Dense(self.features[idx_layer], kernel_init=initializer)(x))

        q_values = nn.Dense(self.n_actions, kernel_init=initializer)(x)

        # Unlike a blanket ``jnp.squeeze``, this keeps a genuine batch of size 1 as (1, n_actions).
        return q_values[0] if unbatched else q_values
|
|
|
|
class QNetwork:
    """Wrapper around ``DQNNet`` exposing a jitted greedy-action helper."""

    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        """Build the underlying ``DQNNet``.

        Args:
            features: layer widths forwarded to ``DQNNet``.
            n_actions: number of discrete actions (size of the network's output).
        """
        self.network = DQNNet(features, n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.ndarray:
        """Return the greedy (arg-max Q-value) action for ``state``.

        ``self`` is a static jit argument, so it is hashed rather than traced.

        Args:
            params: network parameters, passed straight to ``self.network.apply``.
            state: a single observation, or a batch of observations.

        Returns:
            A 0-d array holding the action index for a single state, or one index
            per state when the network returns batched Q-values. The dtype is
            ``int8`` when every action index fits in int8 (at most 128 actions),
            otherwise ``int32``.
        """
        q_values = self.network.apply(params, state)
        # Reduce over the action axis only; a flat argmax would mix batch and action indices.
        actions = jnp.argmax(q_values, axis=-1)
        # int8 holds indices up to 127; larger action spaces would wrap to negative values.
        n_actions = q_values.shape[-1]  # a static Python int under jit
        index_dtype = jnp.int8 if n_actions <= 128 else jnp.int32
        return actions.astype(index_dtype)
|
|