Wrappers
src.jaxnasium.LogWrapper
Log the episode returns and lengths. Modeled after the LogWrapper in PureJaxRL.
This wrapper inserts episode returns and lengths into the info dictionary of the
TimeStep object. The returned_episode_returns and returned_episode_lengths
entries hold the return and length of the last completed episode.
After rolling out a trajectory of n steps and stacking the info dicts,
the returns of completed episodes may be extracted via:
```python
return_values = jax.tree.map(
    lambda x: x[data["returned_episode"]], data["returned_episode_returns"]
)
```
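For context, a minimal rollout sketch showing how these metrics accumulate over a scan. The make_env() constructor and the TimeStep.info field name are assumptions for illustration, not the confirmed API:

```python
# Illustrative sketch only: make_env() is a hypothetical constructor and
# timestep.info is an assumed TimeStep field.
import jax
import jaxnasium

env = jaxnasium.LogWrapper(make_env())
key = jax.random.PRNGKey(0)
obs, state = env.reset(key)

def rollout_step(carry, _):
    key, state = carry
    key, akey, skey = jax.random.split(key, 3)
    action = env.sample_action(akey)
    timestep, state = env.step(skey, state, action)
    return (key, state), timestep.info  # info carries the logged metrics

(_, _), data = jax.lax.scan(rollout_step, (key, state), length=128)

# Keep only the entries where an episode actually finished:
return_values = jax.tree.map(
    lambda x: x[data["returned_episode"]], data["returned_episode_returns"]
)
```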
Arguments:
_env: Environment to wrap.
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
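As a sketch of why the pre-reset observation matters: on truncation, the value of the final observation must bootstrap the return, whereas on termination it must not. The info key name and the truncated field below are assumptions for illustration only:

```python
# Hypothetical field/key names; consult the actual TimeStep info dict.
import jax.numpy as jnp

def bootstrap_value(timestep, value_fn):
    final_obs = timestep.info["final_observation"]  # assumed key name
    return jnp.where(
        timestep.truncated,   # assumed TimeStep field
        value_fn(final_obs),  # truncated: bootstrap from the reached state
        0.0,                  # terminated: no bootstrap
    )
```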
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
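A small usage sketch, e.g. for smoke-testing an environment (env as in the rollout sketch above):

```python
import jax

key = jax.random.PRNGKey(0)
akey, okey = jax.random.split(key)
action = env.sample_action(akey)     # works on composite action spaces
obs = env.sample_observation(okey)   # works on composite observation spaces
```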
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
reset(key: PRNGKeyArray) -> Tuple[TObservation, LogEnvState]
step(key: PRNGKeyArray, state: LogEnvState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, LogEnvState]
src.jaxnasium.VecEnvWrapper
Wrapper to vectorize environments.
Simply calls jax.vmap on the reset and step methods of the environment.
The number of environments is determined by the leading axis of the
inputs to the reset and step methods, as if you were calling jax.vmap directly.
We use a wrapper instead of calling jax.vmap directly in each algorithm to control
where the vectorization happens. This allows other wrappers to act on the vectorized
environment, e.g. NormalizeVecObsWrapper and NormalizeVecRewardWrapper.
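A minimal sketch, assuming a hypothetical make_env() constructor; the batch size follows from the leading axis of the inputs, exactly as with jax.vmap:

```python
import jax
import jaxnasium

env = jaxnasium.VecEnvWrapper(make_env())            # make_env() is hypothetical
keys = jax.random.split(jax.random.PRNGKey(0), 8)    # 8 parallel environments
obs, states = env.reset(keys)                        # batched along axis 0
step_keys = jax.random.split(jax.random.PRNGKey(1), 8)
actions = jax.vmap(env.sample_action)(step_keys)     # one random action per env
timesteps, states = env.step(step_keys, states, actions)
```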
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
reset(key: PRNGKeyArray) -> Tuple[TObservation, Any]
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
src.jaxnasium.NormalizeVecObsWrapper
Normalize the observations of the environment via running mean and variance.
This wrapper acts on vectorized environments and should therefore wrap a
VecEnvWrapper, i.e. be applied on top of the vectorized environment.
Arguments:
_env: Environment to wrap.
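A minimal ordering sketch (make_env() is hypothetical): normalization is applied on top of the vectorized environment so that the running statistics are shared across the batch:

```python
import jaxnasium

env = make_env()                             # hypothetical constructor
env = jaxnasium.VecEnvWrapper(env)           # vectorize first
env = jaxnasium.NormalizeVecObsWrapper(env)  # then normalize batched observations
```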
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
__check_init__()
update_state_and_get_obs(obs, state: NormalizeVecObsState)
reset(key: PRNGKeyArray) -> Tuple[TObservation, NormalizeVecObsState]
step(key: PRNGKeyArray, state: NormalizeVecObsState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, NormalizeVecObsState]
src.jaxnasium.NormalizeVecRewardWrapper
Normalize the rewards of the environment via running mean and variance.
This wrapper acts on vectorized environments and should therefore wrap a
VecEnvWrapper, i.e. be applied on top of the vectorized environment.
Arguments:
_env: Environment to wrap.
gamma: Discount factor for the rewards.
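A minimal sketch (make_env() is hypothetical; passing gamma to the constructor is assumed from the gamma attribute below). gamma should typically match the agent's discount factor:

```python
import jaxnasium

env = jaxnasium.NormalizeVecRewardWrapper(
    jaxnasium.VecEnvWrapper(make_env()),  # make_env() is hypothetical
    gamma=0.99,                           # assumed constructor argument
)
```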
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
gamma: float = 0.99
class-attribute
instance-attribute
__check_init__()
reset(key: PRNGKeyArray) -> Tuple[TObservation, NormalizeVecRewState]
step(key: PRNGKeyArray, state: NormalizeVecRewState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, NormalizeVecRewState]
src.jaxnasium.FlattenObservationWrapper
Flatten the observations of the environment.
Flattens each observation in the environment to a single vector. When the observation is a PyTree of arrays, it flattens each array and returns the same PyTree structure with the flattened arrays.
Arguments:
_env: Environment to wrap.
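An illustrative sketch of the behavior on a PyTree observation; the shapes and make_env() are hypothetical:

```python
import jax
import jaxnasium

# A dict observation {"pos": (3, 4), "vel": (2,)} would flatten to
# {"pos": (12,), "vel": (2,)}: each leaf is flattened, the structure kept.
env = jaxnasium.FlattenObservationWrapper(make_env())  # make_env() is hypothetical
obs, state = env.reset(jax.random.PRNGKey(0))          # all leaves are now 1-D
```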
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
reset(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, TEnvState]
observation_space: Space
property
src.jaxnasium.TransformRewardWrapper
Transform the rewards of the environment using a given function.
Arguments:
_env: Environment to wrap.
transform_fn: Function to transform the rewards.
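A minimal sketch (make_env() is hypothetical) using reward clipping, a common choice of transform_fn:

```python
import jax.numpy as jnp
import jaxnasium

env = jaxnasium.TransformRewardWrapper(
    make_env(),                                     # hypothetical constructor
    transform_fn=lambda r: jnp.clip(r, -1.0, 1.0),  # clip rewards to [-1, 1]
)
```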
reset(key: PRNGKeyArray) -> Tuple[TObservation, Any]
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
transform_fn: Callable
instance-attribute
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, TEnvState]
src.jaxnasium.ScaleRewardWrapper
Scale the rewards of the environment by a given factor.
Arguments:
_env: Environment to wrap.
scale: Factor to scale the rewards by.
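A minimal sketch (make_env() is hypothetical); as the attributes below show, this is a thin TransformRewardWrapper with transform_fn = lambda r: r * scale:

```python
import jaxnasium

env = jaxnasium.ScaleRewardWrapper(make_env(), scale=0.1)  # rewards multiplied by 0.1
```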
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[int | float | Array]) -> Tuple[TimeStep, TEnvState]
reset(key: PRNGKeyArray) -> Tuple[TObservation, Any]
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space]
property
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
scale: float = scale
instance-attribute
__init__(env: Environment, scale: float = 1.0)
transform_fn = lambda r: r * scale
instance-attribute
src.jaxnasium.DiscreteActionWrapper
Wrapper to convert continuous actions to discrete actions.
Arguments:
_env: Environment to wrap.
num_actions: Number of discrete actions to convert to.
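An illustrative sketch, assuming the wrapper exposes num_actions discrete choices that map onto the continuous action range (make_env() and the keyword usage are assumptions):

```python
import jax
import jaxnasium

env = jaxnasium.DiscreteActionWrapper(make_env(), num_actions=9)  # assumed kwarg
obs, state = env.reset(jax.random.PRNGKey(0))
action = env.sample_action(jax.random.PRNGKey(1))   # integer action in [0, 9)
timestep, state = env.step(jax.random.PRNGKey(2), state, action)
```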
reset(key: PRNGKeyArray) -> Tuple[TObservation, Any]
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
num_actions: int
instance-attribute
step(key: PRNGKeyArray, state: TEnvState, action: int | Int[Array, ' num_actions']) -> Tuple[TimeStep, TEnvState]
action_space: Discrete | MultiDiscrete
property
original_action_space: Space
property
Return the original action space of the environment. This is useful for algorithms that need to know the original action space.
src.jaxnasium.FlattenActionSpaceWrapper
Wrapper to convert (PyTrees of) (multi-)discrete action spaces into a single discrete action space. This grows the action space (significantly for large action spaces), but allows the use of algorithms that only support discrete action spaces.
First flattens each MultiDiscrete action space to a single Discrete action space, then combines any remaining Discrete action spaces into a single Discrete action space.
Arguments:
_env: Environment to wrap.
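An illustrative sketch of the size growth: a MultiDiscrete space with sub-action sizes [3, 4] flattens to a single Discrete space of 3 x 4 = 12 actions, one per combination (make_env() is hypothetical):

```python
import jaxnasium

env = jaxnasium.FlattenActionSpaceWrapper(make_env())  # make_env() is hypothetical
# e.g. an original MultiDiscrete([3, 4]) -> flat Discrete(12)
print(env.action_space)           # the flat Discrete space
print(env.original_action_space)  # the original (multi-)discrete space
```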
reset(key: PRNGKeyArray) -> Tuple[TObservation, Any]
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
observation_space: Space | PyTree[Space]
property
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given a step timestep and state, this function will auto-reset the environment and return the new timestep and state when the episode is terminated or truncated. Inserts the original observation in info to bootstrap correctly on truncated episodes.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state, with the state and observation reset to a new initial state and observation when the episode is terminated or truncated. The original observation is inserted in info to bootstrap correctly on truncated episodes.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
multi_agent: bool
property
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a PyTreeDef(*). However, for multi-agent environments this is a PyTreeDef((*, * x num_agents)), i.e. a tuple with one leaf per agent.
__getattr__(name)
step(key: PRNGKeyArray, state: TEnvState, action: int) -> Tuple[TimeStep, TEnvState]
action_space: Discrete
property
original_action_space: Space
property
Return the original action space of the environment.
Utility functions
src.jaxnasium.is_wrapped(wrapped_env: Environment, wrapper_class: type | str) -> bool
Check if the environment is wrapped with a specific wrapper class.
src.jaxnasium.remove_wrapper(wrapped_env: Environment, wrapper_class: type) -> Environment
Remove a specific wrapper class from the environment.
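A minimal sketch (make_env() is hypothetical) combining both helpers to inspect and strip a wrapper stack:

```python
import jaxnasium

env = jaxnasium.LogWrapper(jaxnasium.VecEnvWrapper(make_env()))  # make_env() is hypothetical
if jaxnasium.is_wrapped(env, jaxnasium.VecEnvWrapper):
    env = jaxnasium.remove_wrapper(env, jaxnasium.VecEnvWrapper)
assert not jaxnasium.is_wrapped(env, jaxnasium.VecEnvWrapper)
```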