
Environment

src.jymkit.Environment

Base class for JAX-compatible environments. Create your environment by subclassing it.

step and reset should typically not be overridden, as they mainly wrap the environment-specific logic (step additionally handles auto-resetting). Instead, the environment-specific logic should be implemented in the step_env and reset_env methods.
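
The split between the wrapper methods (step, reset) and the abstract methods (step_env, reset_env) can be sketched in plain Python. This is a simplified, hypothetical stand-in, not jymkit's implementation: SketchEnvironment, CountdownEnv and this TimeStep NamedTuple are illustrative names, an integer seed for random.Random stands in for a JAX PRNG key, and the auto-reset uses an ordinary Python if, whereas a jit-compatible version would presumably have to select between the stepped and reset values with array operations.

```python
import abc
import random
from typing import Any, NamedTuple, Tuple


class TimeStep(NamedTuple):
    # Hypothetical stand-in with the (observation, reward, terminated, truncated, info) layout.
    observation: Any
    reward: float
    terminated: bool
    truncated: bool
    info: dict


class SketchEnvironment(abc.ABC):
    """Hypothetical base class: subclasses implement only step_env and reset_env."""

    def reset(self, key: int) -> Tuple[Any, Any]:
        # Thin wrapper around the environment-specific reset logic.
        return self.reset_env(key)

    def step(self, key: int, state: Any, action: Any) -> Tuple[TimeStep, Any]:
        # Derive two sub-keys from one key (stand-in for jax.random.split).
        rng = random.Random(key)
        step_key, reset_key = rng.getrandbits(32), rng.getrandbits(32)
        timestep, state = self.step_env(step_key, state, action)
        if timestep.terminated or timestep.truncated:
            # Auto-reset: keep the final observation in info, then start a new episode.
            reset_obs, state = self.reset_env(reset_key)
            info = {**timestep.info, "_TERMINAL_OBSERVATION": timestep.observation}
            timestep = timestep._replace(observation=reset_obs, info=info)
        return timestep, state

    @abc.abstractmethod
    def step_env(self, key: int, state: Any, action: Any) -> Tuple[TimeStep, Any]:
        """Environment-specific transition logic."""

    @abc.abstractmethod
    def reset_env(self, key: int) -> Tuple[Any, Any]:
        """Environment-specific initial observation and state."""


class CountdownEnv(SketchEnvironment):
    """Toy environment: the state counts down and the episode terminates at zero."""

    def reset_env(self, key: int) -> Tuple[int, int]:
        return 3, 3  # (observation, state): steps remaining

    def step_env(self, key: int, state: int, action: Any) -> Tuple[TimeStep, int]:
        state = state - 1
        return TimeStep(state, 1.0, state == 0, False, {}), state
```

With this structure a subclass only describes its own dynamics. For example, stepping CountdownEnv three times after reset ends the episode: the third step returns terminated=True together with the reset observation 3, and the final observation 0 is available as info["_TERMINAL_OBSERVATION"].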

step(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]

Steps the environment forward with the given action and automatically resets it when the episode has ended. Because the returned observation is then the first observation of the new episode, this function also inserts the original observation (before auto-resetting) into the info dictionary, so that value estimates can be bootstrapped correctly on truncated episodes: info={"_TERMINAL_OBSERVATION": obs, ...}.

This function should typically not be overridden. Instead, the environment-specific logic should be implemented in the step_env method.

Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.

Arguments:

  • key: JAX PRNG key.
  • state: Current state of the environment.
  • action: Action to take in the environment.

step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState] abstractmethod

Defines the environment-specific step logic, i.e., how the state of the environment is updated according to its transition function.

Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.

Arguments:

  • key: JAX PRNG key.
  • state: Current state of the environment.
  • action: Action to take in the environment.
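
As an illustration of a transition function, the sketch below implements a hypothetical one-dimensional random walk that could serve as the body of step_env in a subclass: action 1 moves right, any other action moves left, with probability SLIP_PROB the move is reversed, the episode terminates when the position reaches GOAL, and it is truncated after MAX_STEPS steps. All names and constants are assumptions for illustration, and random.Random seeded with key stands in for sampling with a JAX PRNG key.

```python
import random
from typing import NamedTuple, Tuple


class TimeStep(NamedTuple):
    # Hypothetical stand-in with the (observation, reward, terminated, truncated, info) layout.
    observation: int
    reward: float
    terminated: bool
    truncated: bool
    info: dict


class WalkState(NamedTuple):
    position: int
    t: int  # number of steps taken in the current episode


GOAL = 3          # hypothetical goal position (terminal state)
MAX_STEPS = 10    # hypothetical time limit (truncation)
SLIP_PROB = 0.1   # hypothetical probability that a move is reversed


def step_env(key: int, state: WalkState, action: int) -> Tuple[TimeStep, WalkState]:
    """Transition function: returns (TimeStep, new_state)."""
    rng = random.Random(key)  # stand-in for sampling with a JAX PRNG key
    move = 1 if action == 1 else -1
    if rng.random() < SLIP_PROB:
        move = -move  # stochastic transition: the move is reversed
    new_state = WalkState(position=state.position + move, t=state.t + 1)
    terminated = new_state.position >= GOAL  # reached a terminal state
    truncated = new_state.t >= MAX_STEPS  # external time limit
    reward = 1.0 if terminated else 0.0
    return TimeStep(new_state.position, reward, terminated, truncated, {}), new_state
```

Because all randomness is derived from key, calling step_env twice with the same key, state and action returns the same result, mirroring the explicit-key style of JAX.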

reset(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]

Resets the environment to an initial state and returns the initial observation. This function should typically not be overridden. Instead, the environment-specific logic should be implemented in the reset_env method.

Returns the initial observation and the initial state of the environment.

Arguments:

  • key: JAX PRNG key.

reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState] abstractmethod

Defines the environment-specific reset logic.

Returns the initial observation and the initial state of the environment.

Arguments:

  • key: JAX PRNG key.
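
A matching reset_env sketch for a hypothetical random-walk environment: the start position is drawn using the key, so the same key always reproduces the same initial state. Here the observation is simply the position; in general the observation returned alongside the state may be only a partial view of it. WalkState and the start range are assumptions, and random.Random again stands in for a JAX PRNG key.

```python
import random
from typing import NamedTuple, Tuple


class WalkState(NamedTuple):
    position: int
    t: int  # number of steps taken in the current episode


def reset_env(key: int) -> Tuple[int, WalkState]:
    """Samples an initial state and returns (initial_observation, initial_state)."""
    rng = random.Random(key)  # stand-in for a JAX PRNG key
    start = rng.randint(-2, 2)  # hypothetical: start uniformly in [-2, 2]
    state = WalkState(position=start, t=0)
    observation = state.position  # the agent observes its position directly
    return observation, state
```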

observation_space: Space | PyTree[Space] abstractmethod property

Defines the space of possible observations from the environment. For multi-agent environments, this should be a PyTree of spaces. See jymkit.spaces for more information on how to define (composite) observation spaces.

action_space: Space | PyTree[Space] abstractmethod property

Defines the space of valid actions for the environment. For multi-agent environments, this should be a PyTree of spaces. See jymkit.spaces for more information on how to define (composite) action spaces.
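
For a multi-agent environment, both properties can return a PyTree of spaces, for instance a dict with one entry per agent. The sketch below uses hypothetical Discrete and Box dataclasses (they are not the classes from jymkit.spaces, whose exact API is not shown here) and checks a PyTree of actions against the matching PyTree of spaces, agent by agent.

```python
from dataclasses import dataclass
from typing import Dict, Sequence


@dataclass(frozen=True)
class Discrete:
    # Hypothetical stand-in: integer actions 0, 1, ..., n - 1.
    n: int

    def contains(self, x) -> bool:
        return isinstance(x, int) and 0 <= x < self.n


@dataclass(frozen=True)
class Box:
    # Hypothetical stand-in: flat vectors with elementwise bounds [low, high].
    low: float
    high: float
    size: int

    def contains(self, x: Sequence[float]) -> bool:
        return len(x) == self.size and all(self.low <= v <= self.high for v in x)


# Multi-agent action space: a PyTree (here a dict) with one space per agent.
action_space: Dict[str, object] = {
    "agent_0": Discrete(4),
    "agent_1": Box(low=-1.0, high=1.0, size=2),
}


def tree_contains(space_tree: Dict[str, object], value_tree: Dict[str, object]) -> bool:
    """Checks a dict of actions against a dict of spaces with the same keys."""
    return space_tree.keys() == value_tree.keys() and all(
        space_tree[k].contains(value_tree[k]) for k in space_tree
    )
```

A real PyTree may nest deeper than one dict level; a library such as jax.tree_util would normally handle the traversal instead of this hand-written loop.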

TimeStep

src.jymkit.TimeStep

A container for the output of an environment's step function, as in timestep, state = env.step(...).

This class follows the Gymnasium standard API, with fields in the order of the (obs, reward, terminated, truncated, info) tuple.

Attributes:

  • observation (Num[Array, ...] | PyTree[Num[Array, ...]] | PyTree[AgentObservation]): The environment state representation provided to the agent. Can be an Array or a PyTree of Arrays. When using action masking, the observation should be of type AgentObservation.
  • reward (Float[Array, ...] | PyTree[Float[Array, ...]]): The reward signal resulting from the previous action. Can be a scalar Array or, in multi-agent environments, a PyTree of reward Arrays.
  • terminated (Bool[Array, ...] | PyTree[Bool[Array, ...]]): Boolean flag indicating whether the episode has ended because a terminal state was reached (e.g., goal reached, game over).
  • truncated (Bool[Array, ...] | PyTree[Bool[Array, ...]]): Boolean flag indicating whether the episode has ended due to external factors (e.g., reaching the maximum number of steps, a timeout).
  • info (dict): Dictionary containing any additional information about the environment step.
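
The difference between terminated and truncated matters when computing one-step TD targets: after termination there is no future value, whereas after truncation the value of the final observation should still be bootstrapped. Since Environment.step auto-resets, that final observation is found in info["_TERMINAL_OBSERVATION"] rather than in timestep.observation. The sketch below assumes TimeStep unpacks like the Gymnasium tuple described above; the NamedTuple stand-in, value_fn and gamma are hypothetical.

```python
from typing import Any, Callable, NamedTuple


class TimeStep(NamedTuple):
    # Hypothetical stand-in with the documented field order.
    observation: Any
    reward: float
    terminated: bool
    truncated: bool
    info: dict


def td_target(timestep: TimeStep, value_fn: Callable[[Any], float], gamma: float = 0.99) -> float:
    """One-step TD target r + gamma * V(s'), respecting termination and truncation."""
    obs, reward, terminated, truncated, info = timestep  # Gymnasium-style unpacking
    if terminated:
        return reward  # a terminal state has no future value
    if truncated:
        # After auto-reset, obs belongs to the new episode; use the real final observation.
        return reward + gamma * value_fn(info["_TERMINAL_OBSERVATION"])
    return reward + gamma * value_fn(obs)
```

For example, with value_fn = lambda o: 2.0 * o and gamma = 0.5, a truncated step with reward 1.0 and info["_TERMINAL_OBSERVATION"] == 5 yields 1.0 + 0.5 * 10.0 = 6.0, while the same step marked terminated yields just 1.0.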

Observation container (optional)

src.jymkit.AgentObservation

A container for the observation of a single agent, with optional action masking.

This container is typically optional. However, the algorithms in jymkit.algorithms expect observations to be wrapped in this type when action masking is enabled.

Arguments:

  • observation: The observation of the agent.
  • action_mask: The action mask of the agent. A boolean array of the same shape as the action space.
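
An action mask is commonly applied by excluding invalid actions before one is selected. The sketch below shows one widespread approach, replacing the logits of masked actions with -inf for greedy selection and giving them zero probability in a softmax; whether jymkit.algorithms does exactly this is not stated here. AgentObservation is a hypothetical NamedTuple stand-in with the two documented fields, and the sketch assumes a flat discrete action space with at least one allowed action.

```python
import math
from typing import List, NamedTuple


class AgentObservation(NamedTuple):
    # Hypothetical stand-in mirroring the two documented fields.
    observation: List[float]
    action_mask: List[bool]  # True means the action is currently allowed


def masked_greedy_action(logits: List[float], obs: AgentObservation) -> int:
    """Returns the highest-scoring action among those allowed by the mask."""
    masked = [logit if ok else -math.inf for logit, ok in zip(logits, obs.action_mask)]
    return max(range(len(masked)), key=masked.__getitem__)


def masked_softmax(logits: List[float], obs: AgentObservation) -> List[float]:
    """Softmax over the allowed actions; masked actions get probability 0."""
    allowed = [logit for logit, ok in zip(logits, obs.action_mask) if ok]
    m = max(allowed)  # subtract the max for numerical stability (needs >= 1 allowed action)
    exps = [math.exp(logit - m) if ok else 0.0 for logit, ok in zip(logits, obs.action_mask)]
    total = sum(exps)
    return [e / total for e in exps]
```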