Environment
src.jymkit.Environment
Base environment class for JAX-compatible environments. Create your environment by subclassing this class. The step and reset methods should typically not be overridden, as they merely handle the auto-reset logic. Instead, implement the environment-specific logic in the step_env and reset_env methods.
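The split between the public methods and the *_env hooks can be illustrated with a minimal, self-contained sketch. BaseEnv, CounterEnv, and the NamedTuple TimeStep below are hypothetical stand-ins, not jymkit's actual classes, and the auto-reset logic is deliberately left out so that only the delegation pattern is visible.

```python
import abc
from typing import Any, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    # Hypothetical stand-in for jymkit.TimeStep (Gymnasium-style fields).
    observation: Any
    reward: Any
    terminated: Any
    truncated: Any
    info: dict


class BaseEnv(abc.ABC):
    # Hypothetical stand-in for jymkit.Environment: the public methods
    # delegate to the abstract *_env hooks (auto-reset omitted here).
    def reset(self, key):
        return self.reset_env(key)

    def step(self, key, state, action):
        return self.step_env(key, state, action)

    @abc.abstractmethod
    def reset_env(self, key) -> Tuple[Any, Any]: ...

    @abc.abstractmethod
    def step_env(self, key, state, action) -> Tuple[TimeStep, Any]: ...


class CounterEnv(BaseEnv):
    # Toy environment: the state is a counter; the episode ends at 3.
    def reset_env(self, key):
        state = jnp.array(0)
        return state, state  # (observation, state)

    def step_env(self, key, state, action):
        new_state = state + action
        terminated = new_state >= 3
        timestep = TimeStep(new_state, jnp.float32(1.0), terminated, jnp.array(False), {})
        return timestep, new_state


env = CounterEnv()
obs, state = env.reset(jax.random.PRNGKey(0))
timestep, state = env.step(jax.random.PRNGKey(1), state, jnp.array(1))
```

Only the two hooks are implemented in the subclass; the base class owns the public entry points, which is what allows shared logic such as auto-resetting to live in one place.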
step(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
Steps the environment forward with the given action and performs an auto-reset when necessary. Additionally, this function inserts the original observation (before auto-resetting) into the info dictionary so that truncated episodes can be bootstrapped correctly (info={"_TERMINAL_OBSERVATION": obs, ...}).
This function should typically not be overridden. Instead, implement the environment-specific logic in the step_env method.
Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.
Arguments:
key: JAX PRNG key.
state: Current state of the environment.
action: Action to take in the environment.
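The auto-reset behaviour described above can be sketched in plain JAX. This is a hedged sketch, not jymkit's implementation: reset_env, step_env, and step_with_auto_reset are hypothetical functions, and splitting the key into a step key and a reset key is an assumption. Because JAX code is traced, the reset branch is computed on every step and the result is selected with jnp.where rather than a Python if.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    observation: Any
    reward: Any
    terminated: Any
    truncated: Any
    info: dict


def reset_env(key):
    # Hypothetical reset: random start position in [0, 5).
    pos = jax.random.randint(key, (), 0, 5)
    return pos, pos  # (observation, state)


def step_env(key, state, action):
    # Hypothetical transition: move by `action`; terminal at position 5 or beyond.
    new_state = state + action
    terminated = new_state >= 5
    timestep = TimeStep(new_state, jnp.float32(0.0), terminated, jnp.array(False), {})
    return timestep, new_state


def step_with_auto_reset(key, state, action):
    step_key, reset_key = jax.random.split(key)
    timestep, new_state = step_env(step_key, state, action)
    reset_obs, reset_state = reset_env(reset_key)
    done = jnp.logical_or(timestep.terminated, timestep.truncated)
    # Pick reset values where the episode ended, stepped values otherwise.
    state_out = jax.tree_util.tree_map(
        lambda r, s: jnp.where(done, r, s), reset_state, new_state)
    obs_out = jax.tree_util.tree_map(
        lambda r, o: jnp.where(done, r, o), reset_obs, timestep.observation)
    # Keep the pre-reset observation so truncated episodes can be bootstrapped.
    info = {**timestep.info, "_TERMINAL_OBSERVATION": timestep.observation}
    return timestep._replace(observation=obs_out, info=info), state_out


step_fn = jax.jit(step_with_auto_reset)
ts, s = step_fn(jax.random.PRNGKey(0), jnp.array(4), jnp.array(1))
```

Stepping from position 4 with action 1 reaches the terminal position 5, so the returned state and observation come from reset_env, while info["_TERMINAL_OBSERVATION"] still holds 5 and the terminated flag stays set.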
step_env(key: PRNGKeyArray, state: TEnvState, action: PyTree[Real[Array, ...]]) -> Tuple[TimeStep, TEnvState]
abstractmethod
Defines the environment-specific step logic, i.e., this is where the state of the environment is updated according to the transition function.
Returns a TimeStep object (observation, reward, terminated, truncated, info) and the new state.
Arguments:
key: JAX PRNG key.
state: Current state of the environment.
action: Action to take in the environment.
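As an illustration of a transition function, the following sketch implements step_env for a hypothetical one-dimensional environment (LineState, GOAL, and MAX_STEPS are made-up names, and TimeStep is a NamedTuple stand-in). It shows terminated (goal reached) and truncated (step limit hit) being computed separately, as the TimeStep fields expect.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    observation: Any
    reward: Any
    terminated: Any
    truncated: Any
    info: dict


class LineState(NamedTuple):
    pos: jax.Array  # agent position on a 1-D line
    t: jax.Array    # number of steps taken this episode


GOAL, MAX_STEPS = 4, 10


def step_env(key, state: LineState, action):
    # Action 1 moves right, any other action moves left; position clipped to [0, GOAL].
    delta = jnp.where(action == 1, 1, -1)
    pos = jnp.clip(state.pos + delta, 0, GOAL)
    t = state.t + 1
    terminated = pos == GOAL     # true terminal state: goal reached
    truncated = t >= MAX_STEPS   # external cut-off: step limit reached
    reward = jnp.where(terminated, 1.0, 0.0)
    return TimeStep(pos, reward, terminated, truncated, {}), LineState(pos, t)


state = LineState(pos=jnp.array(3), t=jnp.array(0))
timestep, next_state = jax.jit(step_env)(jax.random.PRNGKey(0), state, jnp.array(1))
```

The function is pure and uses only jnp operations, so it can be wrapped in jax.jit or jax.vmap without changes.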
reset(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
Resets the environment to an initial state and returns the initial observation. Environment-specific logic is defined in the reset_env method; typically, this function should not be overridden.
Returns the initial observation and the initial state of the environment.
Arguments:
key: JAX PRNG key.
reset_env(key: PRNGKeyArray) -> Tuple[TObservation, TEnvState]
abstractmethod
Defines the environment-specific reset logic.
Returns the initial observation and the initial state of the environment.
Arguments:
key: JAX PRNG key.
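A reset_env typically derives a random initial state from the key. The sketch below reuses the hypothetical LineState toy example and, assuming reset_env is a pure function of the key, uses jax.vmap to reset a batch of independent environments at once; none of these names come from jymkit.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LineState(NamedTuple):
    pos: jax.Array  # agent position on a 1-D line
    t: jax.Array    # number of steps taken this episode


GOAL = 4


def reset_env(key):
    # Sample a random start position strictly before the goal.
    pos = jax.random.randint(key, (), 0, GOAL)
    state = LineState(pos=pos, t=jnp.array(0))
    return pos, state  # (initial observation, initial state)


# One key per environment; vmap returns batched observations and states.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
obs, states = jax.vmap(reset_env)(keys)
```

Outputs that do not depend on the key, such as the step counter t, are broadcast along the batch axis by vmap, so every field of the batched state has a leading dimension of 8.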
observation_space: Space | PyTree[Space]
abstractmethod
property
Defines the space of possible observations from the environment. For multi-agent environments, this should be a PyTree of spaces. See jymkit.spaces for more information on how to define (composite) observation spaces.
action_space: Space | PyTree[Space]
abstractmethod
property
Defines the space of valid actions for the environment. For multi-agent environments, this should be a PyTree of spaces. See jymkit.spaces for more information on how to define (composite) action spaces.
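For multi-agent environments, a PyTree of spaces can be as simple as a dictionary keyed by agent name. The Discrete class below is a hypothetical stand-in with only an n field and a sample method; consult jymkit.spaces for the real space types and their methods. The sketch samples one action per agent by flattening the space PyTree and splitting the key once per leaf.

```python
from dataclasses import dataclass
from typing import Dict

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Discrete:
    # Hypothetical stand-in for a discrete space; not jymkit.spaces.
    n: int

    def sample(self, key):
        return jax.random.randint(key, (), 0, self.n)


# Multi-agent action space as a PyTree (here a dict) keyed by agent name.
action_space: Dict[str, Discrete] = {"agent_0": Discrete(3), "agent_1": Discrete(5)}


def sample_actions(key, spaces):
    # Explicitly treat each Discrete as a leaf when flattening the PyTree.
    leaves, treedef = jax.tree_util.tree_flatten(
        spaces, is_leaf=lambda x: isinstance(x, Discrete))
    keys = jax.random.split(key, len(leaves))
    # Rebuild the same structure with one sampled action per space.
    return jax.tree_util.tree_unflatten(
        treedef, [space.sample(k) for space, k in zip(leaves, keys)])


actions = sample_actions(jax.random.PRNGKey(0), action_space)
```

The returned actions mirror the structure of the space PyTree, which is the shape a multi-agent step would receive as its action argument.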
TimeStep
src.jymkit.TimeStep
A container for the output of an environment's step function (timestep, state = env.step(...)).
This class follows the Gymnasium standard API, mirroring its (obs, reward, terminated, truncated, info) tuple.
Attributes:

| Name | Type | Description |
|---|---|---|
| observation | Num[Array, ...] \| PyTree[Num[Array, ...]] \| PyTree[AgentObservation] | The environment state representation provided to the agent. Can be an Array or a PyTree of arrays. When using action masking, the observation should be of type AgentObservation. |
| reward | Float[Array, ...] \| PyTree[Float[Array, ...]] | The reward signal from the previous action, indicating performance. Can be a scalar Array or a PyTree of reward Arrays (in the case of multi-agent environments). |
| terminated | Bool[Array, ...] \| PyTree[Bool[Array, ...]] | Boolean flag indicating whether the episode has ended due to reaching a terminal state (e.g., goal reached, game over). |
| truncated | Bool[Array, ...] \| PyTree[Bool[Array, ...]] | Boolean flag indicating whether the episode ended due to external factors (e.g., reaching max steps, timeout). |
| info | dict | Dictionary containing any additional information about the environment step. |
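The distinction between terminated and truncated matters mostly when computing bootstrapped value targets. The sketch below computes a one-step TD target from a TimeStep; TimeStep is a NamedTuple stand-in, value_fn is a hypothetical critic, and the code assumes info["_TERMINAL_OBSERVATION"] holds the pre-reset observation, as described for step. Only terminated zeroes the bootstrap term; a truncated episode still bootstraps from the value of its last observation.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class TimeStep(NamedTuple):
    observation: Any
    reward: Any
    terminated: Any
    truncated: Any
    info: dict


def value_fn(obs):
    # Hypothetical critic: the value is simply 2 * observation.
    return 2.0 * obs


def td_target(timestep: TimeStep, gamma: float = 0.9):
    # After an auto-reset, timestep.observation already belongs to the next
    # episode, so bootstrap from the stored pre-reset observation instead.
    next_obs = timestep.info["_TERMINAL_OBSERVATION"]
    # Only a true terminal state zeroes the bootstrap term; truncation does not.
    not_terminal = 1.0 - timestep.terminated.astype(jnp.float32)
    return timestep.reward + gamma * not_terminal * value_fn(next_obs)


ts = TimeStep(
    observation=jnp.array([0.0, 0.0]),           # post-reset observations
    reward=jnp.array([1.0, 1.0]),
    terminated=jnp.array([True, False]),         # env 0 hit a terminal state
    truncated=jnp.array([False, True]),          # env 1 hit the time limit
    info={"_TERMINAL_OBSERVATION": jnp.array([5.0, 5.0])},
)
targets = td_target(ts)
```

Both environments ended their episode on this step, but only the terminated one drops the bootstrap term: its target is just the reward, while the truncated one adds gamma * value_fn(5.0).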
Observation container (optional)
src.jymkit.AgentObservation
A container for the observation of a single agent, with optional action masking.
This container is typically optional. However, algorithms in jymkit.algorithms expect observations to be wrapped in this type when action masking is enabled.
Arguments:
observation: The observation of the agent.
action_mask: The action mask of the agent. A boolean array of the same shape as the action space.
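A common way to use action_mask is to mask the policy logits before sampling. The sketch below uses a NamedTuple stand-in with the same two fields as AgentObservation (not the jymkit class itself) and sets the logits of invalid actions to -inf, which gives them zero probability under jax.random.categorical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class AgentObservation(NamedTuple):
    # Hypothetical stand-in mirroring the two documented fields.
    observation: jax.Array
    action_mask: jax.Array  # True where the action is valid


def masked_sample(key, logits, obs: AgentObservation):
    # Invalid actions get -inf logits, so they can never be sampled.
    masked_logits = jnp.where(obs.action_mask, logits, -jnp.inf)
    return jax.random.categorical(key, masked_logits)


obs = AgentObservation(
    observation=jnp.zeros(4),
    action_mask=jnp.array([False, True, False, True]),  # only actions 1 and 3 valid
)
logits = jnp.zeros(4)  # uniform policy before masking
keys = jax.random.split(jax.random.PRNGKey(0), 100)
actions = jax.vmap(lambda k: masked_sample(k, logits, obs))(keys)
```

Here the mask has shape (4,), matching a four-action discrete space, and every one of the 100 sampled actions is either 1 or 3.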