Wrappers
src.jymkit.LogWrapper
Log the episode returns and lengths. Modeled after the LogWrapper in PureJaxRL.
This wrapper inserts episode returns and lengths into the info dictionary of the TimeStep object. The returned_episode_returns and returned_episode_lengths entries are the returns and lengths of the last completed episode.
After collecting a trajectory of n steps and gathering all the info dicts (data below), the returns of completed episodes can be extracted via:
return_values = jax.tree.map(
lambda x: x[data["returned_episode"]], data["returned_episode_returns"]
)
Arguments:
_env: Environment to wrap.
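The extraction snippet above can be tried on hand-made data. The following is a minimal sketch, assuming the collected info is a dict of arrays stacked along the time axis; the values are hypothetical and not produced by a real rollout.

```python
import jax
import jax.numpy as jnp

# Hypothetical info data stacked over a 6-step rollout in one environment.
# "returned_episode" marks the steps at which an episode finished.
data = {
    "returned_episode": jnp.array([False, False, True, False, False, True]),
    "returned_episode_returns": jnp.array([0.0, 0.0, 3.0, 3.0, 3.0, 5.0]),
    "returned_episode_lengths": jnp.array([0, 0, 3, 3, 3, 3]),
}

done_mask = data["returned_episode"]

# Keep only the entries recorded at the step where an episode completed.
return_values = jax.tree.map(lambda x: x[done_mask], data["returned_episode_returns"])
length_values = jax.tree.map(lambda x: x[done_mask], data["returned_episode_lengths"])

mean_return = return_values.mean()
```

Boolean-mask indexing yields an array whose size depends on the data, so this step is best done outside of jax.jit, after the rollout has been collected.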
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] (property)
observation_space: Space | PyTree[Space] (property)
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
_env: Environment (instance attribute)
__getattr__(name)
src.jymkit.VecEnvWrapper
Wrapper to vectorize environments.
Simply calls jax.vmap on the reset and step methods of the environment. The number of environments is determined by the leading axis of the inputs to the reset and step methods, as if you called jax.vmap directly.
We use a wrapper instead of calling jax.vmap directly in each algorithm to control where the vectorization happens. This allows other wrappers to act on the vectorized environment, e.g. NormalizeVecObsWrapper and NormalizeVecRewardWrapper.
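To illustrate how the leading axis of the inputs determines the number of environments, here is a minimal sketch that applies jax.vmap to toy reset and step functions. The functions reset and step below are hypothetical stand-ins, not the Jymkit environment interface.

```python
import jax
import jax.numpy as jnp

def reset(key):
    # Toy reset: the observation is simply the initial state.
    state = jax.random.uniform(key, (2,))
    return state, state

def step(key, state, action):
    # Toy dynamics: move by the action plus a little noise.
    noise = 0.01 * jax.random.normal(key, state.shape)
    new_state = state + action + noise
    reward = -jnp.sum(new_state**2)
    return (new_state, reward), new_state

num_envs = 4
# One key per environment: the leading axis (4) sets the batch size.
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)

obs, state = jax.vmap(reset)(keys)
actions = jnp.zeros((num_envs, 2))
(next_obs, reward), state = jax.vmap(step)(keys, state, actions)
```

Every output gains the same leading axis as the inputs, which is what lets wrappers placed around the vectorized environment compute statistics across the batch.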
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] (property)
observation_space: Space | PyTree[Space] (property)
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
_env: Environment (instance attribute)
__getattr__(name)
src.jymkit.NormalizeVecObsWrapper
Normalize the observations of the environment via running mean and variance.
This wrapper acts on vectorized environments and in turn should be wrapped within a VecEnvWrapper.
Arguments:
_env: Environment to wrap.
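Normalization via running mean and variance is commonly done with a batched update of the statistics. The sketch below shows one standard way to combine the statistics of a new batch with the running ones; the function names and the zero initial count are assumptions for illustration, not the wrapper's actual implementation.

```python
import jax.numpy as jnp

def update_running_stats(mean, var, count, batch):
    # Merge the running statistics with those of a new batch of observations
    # (batch has shape [num_envs, ...]).
    batch_mean = jnp.mean(batch, axis=0)
    batch_var = jnp.var(batch, axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - mean
    total = count + batch_count
    new_mean = mean + delta * batch_count / total
    m2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / total
    return new_mean, m2 / total, total

def normalize(obs, mean, var, eps=1e-8):
    # eps avoids division by zero for constant observation dimensions.
    return (obs - mean) / jnp.sqrt(var + eps)

# Two batches from a hypothetical 2-environment rollout.
batch_1 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])

mean, var, count = update_running_stats(jnp.zeros(2), jnp.ones(2), 0.0, batch_1)
mean, var, count = update_running_stats(mean, var, count, batch_2)
normalized = normalize(batch_2, mean, var)
```

After both updates, the running mean and variance match those computed over all four observations at once.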
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] (property)
observation_space: Space | PyTree[Space] (property)
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
_env: Environment (instance attribute)
__getattr__(name)
src.jymkit.NormalizeVecRewardWrapper
Normalize the rewards of the environment via running mean and variance.
This wrapper acts on vectorized environments and in turn should be wrapped within a VecEnvWrapper.
Arguments:
_env: Environment to wrap.
gamma: Discount factor for the rewards.
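A common scheme for reward normalization, and a plausible reason for the gamma argument, is to track a discounted return per environment and scale rewards by the running standard deviation of that return. The sketch below assumes this scheme; the function names are hypothetical and the wrapper's actual implementation may differ.

```python
import jax.numpy as jnp

def update_discounted_return(return_acc, reward, done, gamma):
    # Discount the accumulated return, restart it where an episode ended,
    # then add the new reward.
    return return_acc * gamma * (1.0 - done) + reward

def scale_reward(reward, return_var, eps=1e-8):
    # Rewards are scaled (not shifted) by the std of the discounted return.
    return reward / jnp.sqrt(return_var + eps)

gamma = 0.5
rewards = jnp.ones((3, 2))  # 3 steps, 2 environments
dones = jnp.array([[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]])

return_acc = jnp.zeros(2)
for reward, done in zip(rewards, dones):
    return_acc = update_discounted_return(return_acc, reward, done, gamma)

scaled = scale_reward(rewards[-1], jnp.var(return_acc))
```

In a full implementation, the variance passed to scale_reward would come from running statistics maintained over the whole rollout rather than from a single step.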
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] (property)
observation_space: Space | PyTree[Space] (property)
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
_env: Environment (instance attribute)
__getattr__(name)
src.jymkit.FlattenObservationWrapper
Flatten the observations of the environment.
Flattens each observation in the environment to a single vector. When the observation is a PyTree of arrays, it flattens each array and returns the same PyTree structure with the flattened arrays.
Arguments:
_env: Environment to wrap.
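The per-leaf flattening described above can be sketched with jax.tree.map. This is an illustration of the idea on a hypothetical observation, not the wrapper's source; flatten_observation is a made-up helper name.

```python
import jax
import jax.numpy as jnp

def flatten_observation(obs):
    # Flatten every array leaf to 1-D while keeping the PyTree structure.
    return jax.tree.map(jnp.ravel, obs)

# Hypothetical observation: a dict with an image and a position vector.
obs = {
    "image": jnp.zeros((4, 4, 3)),
    "position": jnp.zeros((2,)),
}

flat_obs = flatten_observation(obs)
```

The result is still a dict with the keys "image" and "position"; only the shape of each leaf changes.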
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
action_space: Space | PyTree[Space] (property)
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
_env: Environment (instance attribute)
__getattr__(name)
src.jymkit.GymnaxWrapper
Wrapper for Gymnax environments to transform them into the Jymkit environment interface.
Arguments:
_env: Gymnax environment.
handle_truncation: If True, the wrapper reimplements the autoreset behavior to include truncated information and the terminal_observation in the info dictionary. If False, the wrapper mirrors the Gymnax behavior by ignoring truncations. Defaults to True.
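To make the idea of an autoreset that reports truncation concrete, here is a toy sketch. The environment, the constants MAX_STEPS and GOAL, and the function names are all hypothetical; only the info keys truncated and terminal_observation are taken from the description above, and the wrapper's real logic may differ.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 3  # hypothetical time limit that causes truncation
GOAL = 10.0    # hypothetical position at which the episode terminates

def reset(key):
    del key  # the toy reset is deterministic
    return {"pos": jnp.zeros(()), "t": jnp.zeros((), dtype=jnp.int32)}

def raw_step(state, action):
    return {"pos": state["pos"] + action, "t": state["t"] + 1}

def autoreset_step(key, state, action):
    stepped = raw_step(state, action)
    terminated = stepped["pos"] >= GOAL
    truncated = (stepped["t"] >= MAX_STEPS) & ~terminated
    done = terminated | truncated

    # On episode end, swap in a fresh state; otherwise keep the stepped one.
    fresh = reset(key)
    new_state = jax.tree.map(lambda r, s: jnp.where(done, r, s), fresh, stepped)

    # The observation before the reset stays available to the learner.
    info = {"truncated": truncated, "terminal_observation": stepped["pos"]}
    return new_state["pos"], new_state, terminated, truncated, info

key = jax.random.PRNGKey(0)
state = reset(key)
for _ in range(MAX_STEPS):
    obs, state, terminated, truncated, info = autoreset_step(key, state, 1.0)
```

Keeping terminated and truncated separate matters for value bootstrapping: after a truncation the learner can still bootstrap from the value of terminal_observation, which is not possible if truncations are ignored.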
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
_multi_agent: bool (property)
Indicates if the environment is a multi-agent environment. For multi-agent environments, include a property multi_agent = True in the subclass.
agent_structure: PyTreeDef (property)
Returns the structure of the agent space. This is useful for environments with multiple agents.
__getattr__(name)
Utility functions
src.jymkit.is_wrapped(wrapped_env: Environment, wrapper_class: type) -> bool
Check if the environment is wrapped with a specific wrapper class.
src.jymkit.remove_wrapper(wrapped_env: Environment, wrapper_class: type) -> Environment
Remove a specific wrapper class from the environment.