Wrappers

src.jymkit.LogWrapper

Log the episode returns and lengths. Modeled after the LogWrapper in PureJaxRL.

This wrapper inserts episode returns and lengths into the info dictionary of the TimeStep object. The returned_episode_returns and returned_episode_lengths are the returns and lengths of the last completed episode.

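After collecting a trajectory of n steps and stacking the info dicts into a single structure (called data below), the returns of the completed episodes can be extracted by using data["returned_episode"] as a boolean mask: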

return_values = jax.tree.map(
    lambda x: x[data["returned_episode"]], data["returned_episode_returns"]
)
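To make the selection concrete, here is a minimal plain-Python sketch with made-up numbers. The keys mirror the ones named above; the values, the helper select_completed, and the list-based stand-in for boolean-mask indexing are illustrative assumptions, not output produced by the wrapper.

```python
# Toy stand-in for the stacked info dicts of a 6-step trajectory.
# All numbers are invented for illustration only.
data = {
    "returned_episode": [False, False, True, False, False, True],
    "returned_episode_returns": [0.0, 0.0, 3.0, 3.0, 3.0, 5.5],
    "returned_episode_lengths": [0, 0, 3, 3, 3, 3],
}


def select_completed(values, mask):
    """Keep only the entries recorded at steps where an episode finished."""
    return [v for v, done in zip(values, mask) if done]


return_values = select_completed(
    data["returned_episode_returns"], data["returned_episode"]
)
length_values = select_completed(
    data["returned_episode_lengths"], data["returned_episode"]
)
print(return_values)  # [3.0, 5.5]
print(length_values)  # [3, 3]
```

Each selected entry corresponds to one finished episode, so averaging return_values gives the mean episode return over the trajectory.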

Arguments:

  • _env: Environment to wrap.
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] property
observation_space: Space | PyTree[Space] property
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

_env: Environment instance-attribute
__getattr__(name)

src.jymkit.VecEnvWrapper

Wrapper to vectorize environments. Simply calls jax.vmap on the reset and step methods of the environment. The number of environments is determined by the leading axis of the inputs to the reset and step methods, just as if you had called jax.vmap directly.

We use a wrapper instead of jax.vmap in each algorithm directly to control where the vectorization happens. This allows other wrappers to act on the vectorized environment, e.g. NormalizeVecObsWrapper and NormalizeVecRewardWrapper.
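The leading-axis convention can be illustrated without JAX. The sketch below is a plain-Python stand-in for what jax.vmap does to a step function; vectorize and step_single are hypothetical names made up for this example, and the real wrapper operates on JAX arrays and PyTrees rather than lists.

```python
def step_single(state, action):
    """Hypothetical single-environment step: returns (reward, next_state)."""
    next_state = state + action
    return float(next_state), next_state


def vectorize(fn):
    """Plain-Python stand-in for jax.vmap: map fn over the leading axis of every input."""

    def batched(*args):
        lengths = {len(a) for a in args}
        if len(lengths) != 1:
            raise ValueError("all inputs must share the same leading axis size")
        outputs = [fn(*items) for items in zip(*args)]
        # Transpose the per-environment output tuples into a tuple of lists.
        return tuple(list(column) for column in zip(*outputs))

    return batched


step_batched = vectorize(step_single)

# Three states and three actions, so three environments are stepped.
rewards, next_states = step_batched([0, 10, 20], [1, 2, 3])
print(rewards)      # [1.0, 12.0, 23.0]
print(next_states)  # [1, 12, 23]
```

Nothing in step_batched fixes the number of environments ahead of time; passing five states and five actions would step five environments.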

step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] property
observation_space: Space | PyTree[Space] property
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

_env: Environment instance-attribute
__getattr__(name)

src.jymkit.NormalizeVecObsWrapper

Normalize the observations of the environment using a running mean and variance. This wrapper acts on vectorized environments, so the environment passed to it should already be wrapped in a VecEnvWrapper.

Arguments:

  • _env: Environment to wrap.
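As a rough illustration of normalizing with a running mean and variance, the plain-Python sketch below keeps batch-updated statistics and applies them to scalar observations. The class RunningMeanVar, the small prior count, and the eps constant are assumptions chosen for this example; the wrapper's actual update rule and state layout may differ.

```python
import math


class RunningMeanVar:
    """Running mean and population variance, updated one batch at a time."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4  # small prior count avoids division by zero

    def update(self, batch):
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        delta = batch_mean - self.mean
        total = self.count + n
        # Merge old and new statistics (parallel-variance formula).
        m2 = (
            self.var * self.count
            + batch_var * n
            + delta**2 * self.count * n / total
        )
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, batch, eps=1e-8):
        return [(x - self.mean) / math.sqrt(self.var + eps) for x in batch]


stats = RunningMeanVar()
observations = [1.0, 2.0, 3.0, 4.0]  # one scalar observation per environment
stats.update(observations)
normalized = stats.normalize(observations)
```

After the update, the normalized batch is approximately zero-mean with unit variance; with more batches the statistics track the whole observation stream rather than a single step.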
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] property
observation_space: Space | PyTree[Space] property
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

_env: Environment instance-attribute
__getattr__(name)

src.jymkit.NormalizeVecRewardWrapper

Normalize the rewards of the environment using a running mean and variance. This wrapper acts on vectorized environments, so the environment passed to it should already be wrapped in a VecEnvWrapper.

Arguments:

  • _env: Environment to wrap.
  • gamma: Discount factor for the rewards.
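One common way to use a discount factor in reward normalization is to track a discounted return per environment and divide rewards by the running standard deviation of that return. The plain-Python sketch below follows this scheme as an assumption; the source does not spell out the wrapper's exact formula, and RewardNormalizer and its fields are hypothetical names.

```python
import math


class RewardNormalizer:
    """Scale rewards by the running std of a discounted return estimate."""

    def __init__(self, num_envs, gamma=0.99, eps=1e-8):
        self.gamma = gamma
        self.eps = eps
        self.returns = [0.0] * num_envs  # one discounted accumulator per env
        self.count, self.mean, self.m2 = 0, 0.0, 0.0  # Welford statistics

    def _update_stats(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def __call__(self, rewards, dones):
        for i, (r, done) in enumerate(zip(rewards, dones)):
            self.returns[i] = self.gamma * self.returns[i] + r
            self._update_stats(self.returns[i])
            if done:
                self.returns[i] = 0.0  # restart the accumulator at episode end
        var = self.m2 / self.count if self.count > 1 else 1.0
        return [r / math.sqrt(var + self.eps) for r in rewards]


normalizer = RewardNormalizer(num_envs=2, gamma=0.5)
first = normalizer([1.0, 3.0], [False, False])
second = normalizer([1.0, 1.0], [True, False])
```

Rewards are only rescaled, not shifted, so their sign is preserved; the accumulator of an environment is cleared whenever its episode ends.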
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] property
observation_space: Space | PyTree[Space] property
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

_env: Environment instance-attribute
__getattr__(name)

src.jymkit.FlattenObservationWrapper

Flatten the observations of the environment.

Flattens each observation in the environment to a single vector. When the observation is a PyTree of arrays, it flattens each array and returns the same PyTree structure with the flattened arrays.

Arguments:

  • _env: Environment to wrap.
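The per-leaf flattening described above can be sketched in plain Python. In this illustration, dicts play the role of the PyTree containers and nested lists play the role of arrays; flatten_leaf and flatten_observation are hypothetical helpers, not functions provided by the library.

```python
def flatten_leaf(array):
    """Flatten an arbitrarily nested list of numbers into one flat list."""
    if isinstance(array, (list, tuple)):
        return [x for item in array for x in flatten_leaf(item)]
    return [array]


def flatten_observation(obs):
    """Flatten every leaf array while keeping the surrounding dict structure."""
    if isinstance(obs, dict):
        return {key: flatten_observation(value) for key, value in obs.items()}
    return flatten_leaf(obs)


obs = {"image": [[1, 2], [3, 4]], "agent": {"position": [[0.5], [1.5]]}}
flat = flatten_observation(obs)
print(flat)  # {'image': [1, 2, 3, 4], 'agent': {'position': [0.5, 1.5]}}
```

The keys and nesting of the observation are unchanged; only the shape of each leaf becomes one-dimensional.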
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] property
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

_env: Environment instance-attribute
__getattr__(name)

src.jymkit.GymnaxWrapper

Wrapper for Gymnax environments to transform them into the Jymkit environment interface.

Arguments:

  • _env: Gymnax environment.
  • handle_truncation: If True, the wrapper will reimplement the autoreset behavior to include truncated information and the terminal_observation in the info dictionary. If False, the wrapper will mirror the Gymnax behavior by ignoring truncations. Default=True.
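The difference between the two settings of handle_truncation can be shown with a toy environment. The sketch below is plain Python under stated assumptions: the dynamics, the constants MAX_STEPS and GOAL, and the returned tuple layout are all invented for illustration and do not reflect the wrapper's real TimeStep structure.

```python
MAX_STEPS = 3  # hypothetical time limit of the toy environment
GOAL = 10      # hypothetical terminal position


def raw_step(state, action):
    """Toy dynamics: state is (position, step_count)."""
    position, steps = state
    return position + action, steps + 1


def step_with_autoreset(state, action, handle_truncation=True):
    position, steps = raw_step(state, action)
    terminated = position >= GOAL
    truncated = steps >= MAX_STEPS and not terminated
    done = terminated or truncated
    info = {}
    if handle_truncation:
        info["truncated"] = truncated
        if done:
            # Keep the last observation of the finished episode before resetting.
            info["terminal_observation"] = position
    # Autoreset: a finished episode immediately restarts from the initial state.
    next_state = (0, 0) if done else (position, steps)
    observation = next_state[0]
    return observation, next_state, done, info


state = (0, 0)
for _ in range(3):
    observation, state, done, info = step_with_autoreset(state, action=1)
print(done, info)  # True {'truncated': True, 'terminal_observation': 3}
```

With handle_truncation=False the same call returns an empty info dict, so a time-limit cutoff cannot be told apart from a true termination. The terminal observation matters because the observation returned after an autoreset already belongs to the next episode.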
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
_multi_agent: bool property

Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.

agent_structure: PyTreeDef property

Returns the structure of the agent space. This is useful for environments with multiple agents.

__getattr__(name)

Utility functions

src.jymkit.is_wrapped(wrapped_env: Environment, wrapper_class: type) -> bool

Check if the environment is wrapped with a specific wrapper class.

src.jymkit.remove_wrapper(wrapped_env: Environment, wrapper_class: type) -> Environment

Remove a specific wrapper class from the environment.
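A plausible way for such utilities to work is to walk the chain of _env attributes. The sketch below is a plain-Python illustration under that assumption: Wrapper, BaseEnv, LogLike, VecLike, is_wrapped_sketch, and remove_wrapper_sketch are hypothetical stand-ins, and the rebuild step assumes each wrapper's constructor takes only the inner environment.

```python
class Wrapper:
    """Hypothetical minimal wrapper: stores the inner env, forwards unknown attributes."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Only called when normal lookup fails, so wrapper attributes win.
        if name == "_env":
            raise AttributeError(name)  # avoid recursion if _env is not set yet
        return getattr(self._env, name)


class BaseEnv:
    name = "toy-env"


class LogLike(Wrapper):
    pass


class VecLike(Wrapper):
    pass


def is_wrapped_sketch(env, wrapper_class):
    """Return True if any wrapper in the chain is an instance of wrapper_class."""
    while isinstance(env, Wrapper):
        if isinstance(env, wrapper_class):
            return True
        env = env._env
    return False


def remove_wrapper_sketch(env, wrapper_class):
    """Rebuild the chain without wrappers of type wrapper_class."""
    if not isinstance(env, Wrapper):
        return env
    inner = remove_wrapper_sketch(env._env, wrapper_class)
    if isinstance(env, wrapper_class):
        return inner
    return type(env)(inner)  # assumes the constructor takes only the inner env


env = LogLike(VecLike(BaseEnv()))
stripped = remove_wrapper_sketch(env, VecLike)
print(is_wrapped_sketch(env, VecLike))       # True
print(is_wrapped_sketch(stripped, VecLike))  # False
print(stripped.name)                         # toy-env
```

The outer LogLike wrapper survives the removal, and attribute forwarding still reaches the base environment through the shortened chain.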