
Environment

src.jaxnasium.Environment

Base environment class for JAX-compatible environments. Create your own environment by subclassing this class.

The step and reset methods should typically not be overridden, as they only wrap the auto-reset logic. Implement the environment-specific logic in the step_env and reset_env methods instead.
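The division of labor between the public and the environment-specific methods can be sketched with a stand-in base class. SketchEnvironment and CounterEnv below are hypothetical names, not part of jaxnasium; a plain tuple stands in for TimeStep, and the stand-in's step omits the auto-reset that the real step performs.

```python
import abc

import jax
import jax.numpy as jnp


class SketchEnvironment(abc.ABC):
    """Hypothetical stand-in for the documented base class (not jaxnasium's code).

    Only the delegation is shown: step/reset forward to the abstract
    step_env/reset_env. The real base class additionally auto-resets.
    """

    def reset(self, key):
        return self.reset_env(key)

    def step(self, key, state, action):
        return self.step_env(key, state, action)

    @abc.abstractmethod
    def reset_env(self, key): ...

    @abc.abstractmethod
    def step_env(self, key, state, action): ...


class CounterEnv(SketchEnvironment):
    """Toy environment: the state is a step counter; episodes truncate after max_steps."""

    max_steps = 5

    def reset_env(self, key):
        state = jnp.array(0)
        return state.astype(jnp.float32), state  # (observation, state)

    def step_env(self, key, state, action):
        new_state = state + 1
        obs = new_state.astype(jnp.float32)
        reward = jnp.asarray(action, dtype=jnp.float32)
        terminated = jnp.array(False)
        truncated = new_state >= self.max_steps
        info = {}
        # A plain (obs, reward, terminated, truncated, info) tuple stands in for TimeStep.
        return (obs, reward, terminated, truncated, info), new_state


env = CounterEnv()
key = jax.random.PRNGKey(0)
obs, state = env.reset(key)
timestep, state = env.step(key, state, 1)
```

A subclass of the real Environment would follow the same shape: implement reset_env, step_env, and the abstract observation_space and action_space properties, and let training code call step and reset.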

step(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]

Steps the environment forward with the given action and performs an auto-reset when the episode has terminated or been truncated. Additionally, this function inserts the original observation (before auto-resetting) into the info dictionary (info={"_TERMINAL_OBSERVATION": obs, ...}) so that value estimates can be bootstrapped correctly on truncated episodes.

This function should typically not be overridden. Instead, the environment-specific logic should be implemented in the step_env method.

Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.

Arguments:

  • key: JAX PRNG key.
  • state: Current state of the environment.
  • action: Action to take in the environment.
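
To illustrate why the pre-reset observation matters, here is a sketch of a one-step TD target computed from a step's outputs. The critic value_fn and the helper td_target are hypothetical, and the sketch assumes info carries "_TERMINAL_OBSERVATION" on every step.

```python
import jax.numpy as jnp


def value_fn(obs):
    """Hypothetical critic: any function mapping an observation to a scalar value."""
    return jnp.sum(obs)


def td_target(reward, terminated, info, gamma=0.99):
    # After an auto-reset, timestep.observation already belongs to the next
    # episode, so bootstrap from the pre-reset observation stored in info.
    next_obs = info["_TERMINAL_OBSERVATION"]
    # Terminated: the return really ends here, so drop the bootstrap term.
    # Truncated (or still running): keep bootstrapping from next_obs.
    not_terminal = 1.0 - terminated.astype(jnp.float32)
    return reward + gamma * not_terminal * value_fn(next_obs)


info = {"_TERMINAL_OBSERVATION": jnp.array([1.0, 2.0])}
target = td_target(jnp.array(1.0), jnp.array(False), info, gamma=0.5)
```

Masking by terminated alone is the point of the exercise: a truncated episode did not really end, so its target still includes the discounted value of the last observation.
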

step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState] abstractmethod

Defines the environment-specific step logic, i.e., this is where the state of the environment is updated according to its transition function.

Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.

Arguments:

  • key: JAX PRNG key.
  • state: Current state of the environment.
  • action: Action to take in the environment.

reset(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]

Resets the environment to an initial state and returns the initial observation. Environment-specific logic is defined in the reset_env method. Typically, this function should not be overridden.

Returns the initial observation and the initial state of the environment.

Arguments:

  • key: JAX PRNG key.
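
Because resetting depends only on a PRNG key, a batch of environments can be initialized by vmapping over split keys. The sketch uses a hypothetical toy reset_env function in place of a real environment; with an actual environment one would expect jax.vmap(env.reset)(keys) to behave analogously, assuming reset is a pure function of its key.

```python
import jax
import jax.numpy as jnp


def reset_env(key):
    """Hypothetical reset_env: sample a random 2-D start position."""
    position = jax.random.uniform(key, (2,), minval=-1.0, maxval=1.0)
    return position, position  # (observation, state); identical for this toy env


num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
# vmapping over a batch of keys yields a batch of independent
# initial observations and states in a single call.
batched_obs, batched_state = jax.vmap(reset_env)(keys)
```
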

reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState] abstractmethod

Defines the environment-specific reset logic.

Returns the initial observation and the initial state of the environment.

Arguments:

  • key: JAX PRNG key.

observation_space: Space | PyTree[Space] abstractmethod property

Defines the space of possible observations from the environment. For multi-agent environments, this should be a PyTree of spaces. See jaxnasium.spaces for more information on how to define (composite) observation spaces.

sample_observation(key: PRNGKeyArray) -> TObservation

Convenience method to sample a random observation from the environment's observation space. While one could use self.observation_space.sample(key), this method additionally works on composite observation spaces.

action_space: Space | PyTree[Space] abstractmethod property

Defines the space of valid actions for the environment. For multi-agent environments, this should be a PyTree of spaces. See jaxnasium.spaces for more information on how to define (composite) action spaces.

sample_action(key: PRNGKeyArray) -> PyTree[Real[Array, ...]]

Convenience method to sample a random action from the environment's action space. While one could use self.action_space.sample(key), this method additionally works on composite action spaces.
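One plausible way a sampler can support composite spaces is to flatten the PyTree of spaces, give each leaf its own independent key, and rebuild the structure. This is a sketch, not jaxnasium's implementation; the Discrete class and the sample_composite helper are hypothetical stand-ins that only assume each space has a sample(key) method.

```python
import jax


class Discrete:
    """Hypothetical stand-in for a space with a sample(key) method."""

    def __init__(self, n):
        self.n = n

    def sample(self, key):
        return jax.random.randint(key, (), 0, self.n)


def sample_composite(space_tree, key):
    # Plain classes are not registered as PyTree nodes, so each space is a leaf.
    # Flatten, draw one independent key per leaf, sample, and rebuild.
    leaves, treedef = jax.tree_util.tree_flatten(space_tree)
    keys = jax.random.split(key, len(leaves))
    samples = [space.sample(k) for space, k in zip(leaves, keys)]
    return jax.tree_util.tree_unflatten(treedef, samples)


action_space = {"agent_0": Discrete(3), "agent_1": Discrete(5)}
action = sample_composite(action_space, jax.random.PRNGKey(0))
```

Splitting the key per leaf matters: reusing one key for every sub-space would make the samples of same-sized sub-spaces identical.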

auto_reset(key: PRNGKeyArray, timestep_step: TimeStep, state_step: TEnvState) -> Tuple[TimeStep, TEnvState]

Auto-resets the environment when the episode is terminated or truncated.

Given the timestep and state produced by step_env, this function resets the environment and returns the new timestep and state whenever the episode is terminated or truncated. It inserts the original observation into info so that value estimates can be bootstrapped correctly on truncated episodes.

Arguments:

  • key: JAX PRNG key.
  • timestep_step: The timestep returned by the step_env method.
  • state_step: The state returned by the step_env method.

Returns: A tuple of the new timestep and state. When the episode is terminated or truncated, the state and observation are replaced by a fresh initial state and observation, and the original observation is stored in info so that value estimates can be bootstrapped correctly on truncated episodes.
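A common branch-free way to implement such an auto-reset in JAX is to compute the reset unconditionally and select between the reset and stepped values with jnp.where. The helper below is a hypothetical sketch of that technique, not jaxnasium's code; it assumes scalar terminated/truncated flags (single-agent) and a plain tuple in place of TimeStep.

```python
import jax
import jax.numpy as jnp


def auto_reset_sketch(reset_fn, key, timestep, state):
    """Hypothetical branch-free auto-reset.

    timestep is an (obs, reward, terminated, truncated, info) tuple and
    reset_fn(key) returns (initial_obs, initial_state).
    """
    obs, reward, terminated, truncated, info = timestep
    done = jnp.logical_or(terminated, truncated)
    # The reset is always computed; jnp.where then picks per leaf, which
    # avoids Python control flow on traced values under jit.
    reset_obs, reset_state = reset_fn(key)

    def select(reset_leaf, step_leaf):
        return jnp.where(done, reset_leaf, step_leaf)

    new_obs = jax.tree_util.tree_map(select, reset_obs, obs)
    new_state = jax.tree_util.tree_map(select, reset_state, state)
    # Keep the pre-reset observation for bootstrapping on truncation.
    info = {**info, "_TERMINAL_OBSERVATION": obs}
    return (new_obs, reward, terminated, truncated, info), new_state


def reset_fn(key):
    return jnp.zeros(2), jnp.array(0)


ts = (jnp.array([3.0, 4.0]), jnp.array(1.0), jnp.array(False), jnp.array(True), {})
(new_obs, _, _, _, new_info), new_state = auto_reset_sketch(
    reset_fn, jax.random.PRNGKey(0), ts, jnp.array(7)
)
```

Selecting with jnp.where instead of an if-statement keeps the function traceable, so it can run inside jax.jit or jax.lax.scan over many steps.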

multi_agent: bool property

Indicates if the environment is a multi-agent environment.

This is inferred from the _multi_agent property. If it is not set, the environment is assumed to be single-agent.
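The documented inference rule can be mimicked in a few lines. The sketch below is hypothetical (is_multi_agent and both classes are illustrative); it suggests that a custom multi-agent environment would opt in by setting _multi_agent = True on the subclass, though the exact mechanism in jaxnasium may differ.

```python
class SingleAgentEnv:
    """Hypothetical environment that never sets _multi_agent."""


class TwoAgentEnv:
    """Hypothetical environment that opts in to multi-agent mode."""

    _multi_agent = True


def is_multi_agent(env):
    # Read _multi_agent if present; otherwise default to single-agent.
    return bool(getattr(env, "_multi_agent", False))
```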

agent_structure: PyTreeDef property

Returns the structure of the agent space. In single-agent environments this is simply PyTreeDef(*), a single leaf. For multi-agent environments it has one leaf per agent, e.g. PyTreeDef((*, *, ...)) with num_agents leaves when the agents are stored in a tuple.
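These structures can be reproduced directly with jax.tree_util. The sketch below shows the single- and multi-agent cases and one way such a structure is useful: building one PRNG key per agent. The tuple-of-three-agents layout is an illustrative assumption.

```python
import jax

# A single-agent value is one leaf, so its structure is PyTreeDef(*).
single = jax.tree_util.tree_structure(0.0)

# A multi-agent value with one entry per agent (here a tuple of 3 agents)
# has one leaf per agent.
multi = jax.tree_util.tree_structure((0.0, 0.0, 0.0))

# Given such a structure, per-agent values can be built with tree_unflatten,
# e.g. one independent PRNG key per agent.
keys = jax.random.split(jax.random.PRNGKey(0), multi.num_leaves)
per_agent_keys = jax.tree_util.tree_unflatten(multi, list(keys))
```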

Timestep

src.jaxnasium.TimeStep

A container for the output of an environment's step function, as in timestep, state = env.step(...).

This class follows the Gymnasium standard API, in which a step produces an (obs, reward, terminated, truncated, info) tuple.

Arguments:

  • observation: The environment state representation provided to the agent. Can be an Array or a PyTree of arrays. When using action masking, the observation should be of type AgentObservation.
  • reward: The reward signal for the previous action, indicating performance. Can be a scalar Array or a PyTree of reward Arrays (in the case of multi-agent environments).
  • terminated: Boolean flag indicating whether the episode has ended due to reaching a terminal state (e.g., goal reached, game over).
  • truncated: Boolean flag indicating whether the episode ended due to external factors (e.g., reaching max steps, timeout).
  • info: Dictionary containing any additional information about the environment step.
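To make the field order and the terminated/truncated distinction concrete, the sketch below uses a hypothetical NamedTuple stand-in with the same five fields. Whether TimeStep itself supports tuple unpacking is an assumption not confirmed by this page.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class TimeStepSketch(NamedTuple):
    """Hypothetical stand-in with the same five fields as TimeStep."""

    observation: Any
    reward: Any
    terminated: Any
    truncated: Any
    info: dict


ts = TimeStepSketch(
    observation=jnp.zeros(3),
    reward=jnp.array(1.0),
    terminated=jnp.array(False),
    truncated=jnp.array(True),  # e.g. a step limit was reached
    info={},
)

# Gymnasium-style unpacking into the five fields.
obs, reward, terminated, truncated, info = ts
# The episode is over if either flag is set, but only `terminated` means
# the task itself ended; truncation is imposed from outside.
done = jnp.logical_or(terminated, truncated)
```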

Observation container (optional)

src.jaxnasium.AgentObservation

A container for the observation of a single agent, with optional action masking.

Typically, this container is optional. However, the algorithms in jaxnasium.algorithms expect observations to be wrapped in this type when action masking is enabled.

Arguments:

  • observation: The observation of the agent.
  • action_mask: The action mask of the agent. A boolean array of the same shape as the action space.
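
The following sketch shows how an action mask is typically consumed when sampling from a discrete policy: invalid actions receive a logit of -inf before sampling. AgentObservationSketch and masked_sample are hypothetical stand-ins, and the sketch assumes a one-dimensional boolean mask over a discrete action set; it is not the implementation used by jaxnasium.algorithms.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class AgentObservationSketch(NamedTuple):
    """Hypothetical stand-in with the same two fields as AgentObservation."""

    observation: jax.Array
    action_mask: jax.Array  # bool, one entry per discrete action


def masked_sample(key, logits, obs):
    # Invalid actions get -inf logits, so categorical never selects them.
    masked_logits = jnp.where(obs.action_mask, logits, -jnp.inf)
    return jax.random.categorical(key, masked_logits)


obs = AgentObservationSketch(
    observation=jnp.zeros(4),
    action_mask=jnp.array([False, True, False, True]),
)
logits = jnp.zeros(4)  # a uniform policy before masking
action = masked_sample(jax.random.PRNGKey(0), logits, obs)
```

Masking the logits rather than renormalizing probabilities keeps the operation numerically stable and compatible with log-probability computations.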