Spaces

src.jymkit._wrappers.Space

The base class for all spaces in JymKit. Instead of using this class directly, use the subclasses Box, Discrete, and MultiDiscrete. Composite spaces can be created by simply combining spaces in an arbitrary PyTree. For example, a tuple of Box spaces can be created as follows:

from jymkit import Box

box1 = Box(low=0, high=1, shape=(3,))
box2 = Box(low=0, high=1, shape=(4,))
box3 = Box(low=0, high=1, shape=(5,))
composite_space = (box1, box2, box3)

JymKit algorithms assume that a multi-agent environment exposes its spaces as such a composite, where the first level of the PyTree is the agent dimension. For example, a multi-agent environment observation space may look like this:

from jymkit import Box, MultiDiscrete

agent1_obs = Box(low=0, high=1, shape=(3,))
agent2_obs = Box(low=0, high=1, shape=(4,))
agent3_obs = MultiDiscrete(nvec=[2, 3])
env_obs_space = {
    "agent1": agent1_obs,
    "agent2": agent2_obs,
    "agent3": agent3_obs,
}

Spaces are purposefully not registered as PyTree nodes, so JAX tree utilities treat each space as a single leaf instead of flattening its fields.
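To illustrate why leaving spaces unregistered is convenient, the sketch below flattens a dictionary of spaces and samples once per agent. `ToySpace` is a hypothetical stand-in written for this example and is not a JymKit class; only the `jax.tree_util` and `jax.random` calls are real JAX APIs.

```python
from dataclasses import dataclass

import jax


@dataclass(frozen=True)
class ToySpace:
    """Hypothetical stand-in for a space; NOT part of JymKit."""

    shape: tuple

    def sample(self, rng):
        # Uniform draw in [0, 1) with the requested shape.
        return jax.random.uniform(rng, self.shape)


# First level of the PyTree is the agent dimension.
spaces = {"agent1": ToySpace((3,)), "agent2": ToySpace((4,))}

# An unregistered dataclass is opaque to JAX, so each space is one leaf.
leaves, treedef = jax.tree_util.tree_flatten(spaces)

# One independent key per leaf, then rebuild the original structure.
keys = jax.random.split(jax.random.PRNGKey(0), len(leaves))
samples = jax.tree_util.tree_unflatten(
    treedef, [space.sample(key) for space, key in zip(leaves, keys)]
)
```

Had the spaces been registered PyTree nodes, `tree_flatten` would have descended into their fields, and the per-space `sample` call would no longer be reachable from the leaves.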

src.jymkit.Box dataclass

The standard Box space for continuous action/observation spaces.

Arguments:

  • low (int / Array[int]): The lower bound of the space.
  • high (int / Array[int]): The upper bound of the space.
  • shape: The shape of the space.
  • dtype: The data type of the space. Default is jnp.float32.

sample(rng: PRNGKeyArray) -> Array

Sample a random action uniformly from the set of continuous choices.
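As a rough illustration of uniform sampling over a box in plain JAX, the sketch below draws from the half-open interval [low, high). The helper `sample_box` is hypothetical and is not JymKit's implementation; its parameters simply mirror the arguments listed above.

```python
import jax
import jax.numpy as jnp


def sample_box(rng, low, high, shape, dtype=jnp.float32):
    """Hypothetical helper (not JymKit code): uniform draw from [low, high)."""
    return jax.random.uniform(rng, shape=shape, dtype=dtype, minval=low, maxval=high)


rng = jax.random.PRNGKey(0)
box_sample = sample_box(rng, low=0, high=1, shape=(3,))
```

Passing the same key twice yields the same sample, so callers would normally split the key with `jax.random.split` before each draw.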

src.jymkit.Discrete dataclass

The standard discrete space for discrete action/observation spaces.

Arguments:

  • n (int): The number of discrete actions.
  • dtype: The data type of the space. Default is jnp.int16.

sample(rng: PRNGKeyArray) -> Int[Array, '']

Sample a random action uniformly from the set of discrete choices.
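A minimal sketch of what a uniform discrete draw can look like in plain JAX, assuming the valid actions are the integers 0 to n - 1. The helper `sample_discrete` is hypothetical and not taken from JymKit.

```python
import jax
import jax.numpy as jnp


def sample_discrete(rng, n, dtype=jnp.int16):
    """Hypothetical helper (not JymKit code): scalar integer in [0, n)."""
    return jax.random.randint(rng, shape=(), minval=0, maxval=n, dtype=dtype)


discrete_sample = sample_discrete(jax.random.PRNGKey(0), n=4)
```

The empty `shape=()` produces a scalar array, which matches the `Int[Array, '']` annotation above.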

src.jymkit.MultiDiscrete dataclass

The standard multi-discrete space for action/observation spaces made up of several discrete dimensions.

Arguments:

  • nvec (Array[int]): The number of discrete actions for each dimension.
  • dtype: The data type of the space. Default is jnp.int16.

sample(rng: PRNGKeyArray) -> Int[Array, '']

Sample a random action uniformly from the set of discrete choices.
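For comparison, a multi-discrete draw can be sketched in plain JAX by giving each dimension its own upper bound. The helper `sample_multi_discrete` is hypothetical and not JymKit's implementation. It assumes each dimension i takes values 0 to nvec[i] - 1 and returns one integer per entry of `nvec`, which may differ from the shape JymKit actually returns.

```python
import jax
import jax.numpy as jnp


def sample_multi_discrete(rng, nvec, dtype=jnp.int16):
    """Hypothetical helper (not JymKit code): one integer in [0, nvec[i]) per dimension."""
    nvec = jnp.asarray(nvec)
    # maxval is broadcast against shape, so every dimension gets its own bound.
    return jax.random.randint(rng, shape=nvec.shape, minval=0, maxval=nvec, dtype=dtype)


multi_sample = sample_multi_discrete(jax.random.PRNGKey(0), nvec=[2, 3])
```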