Environment
src.jaxnasium.Environment
Base environment class for JAX-compatible environments. Create your environment by subclassing this.
step and reset should typically not be overridden: they handle the auto-reset logic and delegate to step_env and reset_env, where the environment-specific logic should be implemented.
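To illustrate the subclassing pattern, here is a minimal plain-Python sketch. The class and environment names are hypothetical stand-ins (the real jaxnasium.Environment is JAX-based and its step/reset also perform auto-reset); only the delegation structure mirrors the documented API.

```python
# Minimal sketch of the subclassing pattern: step/reset delegate to
# step_env/reset_env, which subclasses implement.
from abc import ABC, abstractmethod


class EnvironmentSketch(ABC):
    """Stand-in for jaxnasium.Environment (hypothetical, simplified)."""

    def reset(self, key):
        # The real base class adds auto-reset bookkeeping around this call.
        return self.reset_env(key)

    def step(self, key, state, action):
        # The real base class also handles auto-reset after this call.
        return self.step_env(key, state, action)

    @abstractmethod
    def reset_env(self, key): ...

    @abstractmethod
    def step_env(self, key, state, action): ...


class CountdownEnv(EnvironmentSketch):
    """Toy environment: the state counts down; the episode terminates at 0."""

    def reset_env(self, key):
        state = 3
        return state, state  # (initial observation, initial state)

    def step_env(self, key, state, action):
        new_state = state - 1
        terminated = new_state == 0
        # (obs, reward, terminated, truncated, info), new state
        return (new_state, 1.0, terminated, False, {}), new_state


env = CountdownEnv()
obs, state = env.reset(key=None)
timestep, state = env.step(key=None, state=state, action=0)
```

Only reset_env and step_env carry environment-specific logic; the base class owns the generic plumbing.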
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
Steps the environment forward with the given action and performs an auto-reset when necessary.
Additionally, this function inserts the original observation (before auto-resetting) into
the info dictionary so that truncated episodes can be bootstrapped correctly (info={"_TERMINAL_OBSERVATION": obs, ...}).
This function should typically not be overridden. Instead, the environment-specific logic
should be implemented in the step_env method.
Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.
Arguments:
key: JAX PRNG key.
state: Current state of the environment.
action: Action to take in the environment.
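Why the "_TERMINAL_OBSERVATION" key matters: on truncation the returned observation already belongs to the next episode, so a learner must bootstrap from the pre-reset observation instead. A sketch, where td_target and value_fn are hypothetical names (not part of jaxnasium):

```python
# Sketch: computing a TD target that bootstraps correctly on truncation,
# using the "_TERMINAL_OBSERVATION" info key documented above.
def td_target(reward, terminated, truncated, info, next_obs, value_fn, gamma=0.99):
    if truncated and not terminated:
        # Episode was cut off, not finished: bootstrap from the observation
        # *before* the auto-reset, which step() stored in info.
        bootstrap_obs = info["_TERMINAL_OBSERVATION"]
        return reward + gamma * value_fn(bootstrap_obs)
    if terminated:
        return reward  # true terminal state: no bootstrapping
    return reward + gamma * value_fn(next_obs)


value_fn = lambda obs: 10.0  # dummy critic for illustration
target = td_target(1.0, False, True, {"_TERMINAL_OBSERVATION": 7},
                   next_obs=0, value_fn=value_fn)
```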
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
abstractmethod
Defines the environment-specific step logic, i.e. updates the state of the environment according to the transition function.
Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.
Arguments:
key: JAX PRNG key.
state: Current state of the environment.
action: Action to take in the environment.
reset(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
Resets the environment to an initial state and returns the initial observation.
Environment-specific logic is defined in the reset_env method. Typically, this function
should not be overridden.
Returns the initial observation and the initial state of the environment.
Arguments:
key: JAX PRNG key.
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
abstractmethod
Defines the environment-specific reset logic.
Returns the initial observation and the initial state of the environment.
Arguments:
key: JAX PRNG key.
observation_space: Space | PyTree[Space]
abstractmethod
property
Defines the space of possible observations from the environment.
For multi-agent environments, this should be a PyTree of spaces.
See jaxnasium.spaces for more information on how to define (composite) observation spaces.
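To make "a PyTree of spaces" concrete, here is a sketch using hypothetical stand-ins for the space classes (the real ones live in jaxnasium.spaces and differ in detail): a composite observation space is just a nested container whose leaves are spaces.

```python
import random

# Hypothetical stand-ins for jaxnasium.spaces classes, for illustration only.
class DiscreteSketch:
    def __init__(self, n):
        self.n = n

    def sample(self, rng):
        return rng.randrange(self.n)


class BoxSketch:
    def __init__(self, low, high):
        self.low, self.high = low, high

    def sample(self, rng):
        return rng.uniform(self.low, self.high)


# A composite observation space as a PyTree (here: a dict) of spaces.
observation_space = {
    "position": BoxSketch(-1.0, 1.0),
    "inventory": DiscreteSketch(5),
}

rng = random.Random(0)
obs = {name: space.sample(rng) for name, space in observation_space.items()}
```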
sample_observation(key: PRNGKeyArray) -> TObservation
Convenience method to sample a random observation from the environment's observation space.
While one could use self.observation_space.sample(key), this method additionally works
on composite observation spaces.
action_space: Space | PyTree[Space]
abstractmethod
property
Defines the space of valid actions for the environment.
For multi-agent environments, this should be a PyTree of spaces.
See jaxnasium.spaces for more information on how to define (composite) action spaces.
sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]
Convenience method to sample a random action from the environment's action space.
While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
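Why sampling a composite space needs more than space.sample(key): the sampler has to recurse over the PyTree and sample each leaf. A plain-Python sketch of that idea (leaf "spaces" are reduced to callables here; the real method walks a PyTree of jaxnasium spaces):

```python
import random

# Sketch of sampling over a composite (nested) action space, as
# sample_action does. Leaves are zero-setup sampler callables for brevity.
def tree_sample(space_tree, rng):
    if isinstance(space_tree, dict):
        return {k: tree_sample(v, rng) for k, v in space_tree.items()}
    if isinstance(space_tree, (list, tuple)):
        return type(space_tree)(tree_sample(v, rng) for v in space_tree)
    return space_tree(rng)  # leaf: sample from this space


action_space = {
    "move": lambda rng: rng.randrange(4),   # e.g. a Discrete(4)-like leaf
    "gripper": (lambda rng: rng.random(),   # e.g. two Box-like leaves
                lambda rng: rng.random()),
}

rng = random.Random(0)
action = tree_sample(action_space, rng)
```

The result mirrors the structure of the space: a dict with a scalar under "move" and a 2-tuple under "gripper".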
auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]
Auto-resets the environment when the episode is terminated or truncated.
Given the timestep and state returned by step_env, this function resets the environment and returns the new timestep and state whenever the episode is terminated or truncated, inserting the original observation into info so that truncated episodes can be bootstrapped correctly.
Arguments:
key: JAX PRNG key.
timestep_step: The timestep returned by the step_env method.
state_step: The state returned by the step_env method.
Returns: A tuple of the new timestep and state. When the episode is terminated or truncated, the state and observation are reset to a new initial state and observation, and the original observation is inserted into info.
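The mechanics can be sketched in plain Python (the real method uses JAX control flow rather than an if statement; auto_reset_sketch and reset_fn are hypothetical names):

```python
# Plain-Python sketch of the auto-reset logic described above.
def auto_reset_sketch(timestep_step, state_step, reset_fn, key):
    obs, reward, terminated, truncated, info = timestep_step
    if terminated or truncated:
        # Preserve the pre-reset observation so learners can bootstrap.
        info = {**info, "_TERMINAL_OBSERVATION": obs}
        obs, state = reset_fn(key)
    else:
        state = state_step
    return (obs, reward, terminated, truncated, info), state


reset_fn = lambda key: ("fresh_obs", "fresh_state")  # dummy reset
timestep, state = auto_reset_sketch(
    ("last_obs", 0.0, True, False, {}), "old_state", reset_fn, key=None)
```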
multi_agent: bool
property
Indicates if the environment is a multi-agent environment.
Infers this via the _multi_agent property. If not set, assumes single-agent.
agent_structure: PyTreeDef
property
Returns the structure of the agent space. In single-agent environments this is simply a leaf PyTreeDef. For multi-agent environments it is a PyTreeDef with one leaf per agent (num_agents leaves).
TimeStep
src.jaxnasium.TimeStep
A container for the output of an environment's step function.
(timestep, state = env.step(...)).
This class follows the standard Gymnasium API: an (obs, reward, terminated, truncated, info) tuple.
Arguments:
observation: The environment state representation provided to the agent. Can be an Array or a PyTree of arrays. When using action masking, the observation should be of type AgentObservation.
reward: The reward signal from the previous action, indicating performance. Can be a scalar Array or a PyTree of reward Arrays (in the case of multi-agent environments).
terminated: Boolean flag indicating whether the episode has ended due to reaching a terminal state (e.g., goal reached, game over).
truncated: Boolean flag indicating whether the episode ended due to external factors (e.g., reaching max steps, timeout).
info: Dictionary containing any additional information about the environment step.
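The five fields above can be sketched as a NamedTuple (the real TimeStep may be implemented differently, e.g. as a PyTree-registered class; this only mirrors the documented field order):

```python
from typing import Any, NamedTuple

# Sketch of the TimeStep fields, in Gymnasium order.
class TimeStepSketch(NamedTuple):
    observation: Any
    reward: Any
    terminated: bool
    truncated: bool
    info: dict


ts = TimeStepSketch(observation=[0.0, 1.0], reward=1.0,
                    terminated=False, truncated=False, info={})
# Unpacks like the Gymnasium tuple:
obs, reward, terminated, truncated, info = ts
```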
Observation container (optional)
src.jaxnasium.AgentObservation
A container for the observation of a single agent, with optional action masking.
Using this container is typically optional. However, algorithms in
jaxnasium.algorithms expect observations to be wrapped in this type when
action masking is enabled.
Arguments:
observation: The observation of the agent.
action_mask: The action mask of the agent. A boolean array of the same shape as the action space.
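How the mask is typically consumed: illegal actions get a logit of -inf before the argmax/softmax, so they can never be selected. A plain-Python sketch (AgentObservationSketch and masked_argmax are hypothetical names, not the jaxnasium implementation):

```python
from dataclasses import dataclass
from typing import Any

# Sketch of the AgentObservation container and one common way an
# action mask is consumed by a policy.
@dataclass
class AgentObservationSketch:
    observation: Any
    action_mask: Any  # boolean, one entry per action


def masked_argmax(logits, action_mask):
    # Masked-out actions are assigned -inf so they cannot win the argmax.
    masked = [l if m else float("-inf") for l, m in zip(logits, action_mask)]
    return max(range(len(masked)), key=lambda i: masked[i])


agent_obs = AgentObservationSketch(
    observation=[0.2, 0.8, -0.1],
    action_mask=[True, False, True],  # action 1 is currently illegal
)
best = masked_argmax([1.0, 9.0, 2.0], agent_obs.action_mask)
```

Even though action 1 has the highest raw logit, the mask forces the policy to pick action 2.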