Transitions
src.jaxnasium.algorithms.utils.Transition
Container for (possibly batches of) transitions. Provides functionality for creating minibatches and some utilities for multi-agent reinforcement learning.
make_minibatches(key: PRNGKeyArray, n_minibatches: int, n_epochs: int = 1, n_batch_axis: int = 1) -> Transition
Creates shuffled minibatches from the transition. Returns a copy of the transition with each leaf reshaped to (n_minibatches, ...).
This function first flattens the transition over the leading n_batch_axis axes. This is useful if your data hasn't been flattened yet and is still structured as (rollout_length, num_envs, ...), where num_envs is the number of parallel environments.
If n_epochs > 1, it creates n_epochs copies of the minibatches and stacks them so that there is a single leading axis to scan over during training.
Arguments:
- key: JAX PRNG key for randomization.
- n_minibatches: Number of minibatches to create.
- n_epochs: Number of copies of the minibatches to stack.
- n_batch_axis: Number of leading batch axes to flatten over. Default is 1 (already flattened).
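The flatten, shuffle, and reshape steps described above can be sketched in plain JAX. This is a minimal illustration on a dictionary of arrays, not the library's implementation: the function name `sketch_make_minibatches` is hypothetical, and it assumes that each epoch gets its own shuffle and that the batch size divides evenly by `n_minibatches`.

```python
import jax
import jax.numpy as jnp


def sketch_make_minibatches(batch, key, n_minibatches, n_epochs=1, n_batch_axis=1):
    """Hypothetical stand-in for illustration, not the library's code."""
    # Merge the leading `n_batch_axis` axes of every leaf into one batch axis.
    flat = jax.tree_util.tree_map(
        lambda x: x.reshape((-1,) + x.shape[n_batch_axis:]), batch
    )
    batch_size = jax.tree_util.tree_leaves(flat)[0].shape[0]
    assert batch_size % n_minibatches == 0, "batch must split evenly"

    def one_epoch(epoch_key):
        # One permutation is shared by all leaves so samples stay aligned.
        perm = jax.random.permutation(epoch_key, batch_size)
        return jax.tree_util.tree_map(
            lambda x: x[perm].reshape((n_minibatches, -1) + x.shape[1:]), flat
        )

    # Assumption: every epoch is shuffled with its own key.
    epochs = [one_epoch(k) for k in jax.random.split(key, n_epochs)]
    # Concatenate along the minibatch axis: a single leading axis to scan over.
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), *epochs)


# Unflattened rollout data shaped (rollout_length=8, num_envs=4, ...).
rollout = {
    "observation": jnp.arange(8 * 4 * 3, dtype=jnp.float32).reshape(8, 4, 3),
    "reward": jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4),
}
minibatches = sketch_make_minibatches(
    rollout, jax.random.PRNGKey(0), n_minibatches=4, n_epochs=2, n_batch_axis=2
)
print(minibatches["observation"].shape)  # (8, 8, 3)
print(minibatches["reward"].shape)  # (8, 8)
```

With 32 samples, 4 minibatches, and 2 epochs, the leading axis has length 8 (4 minibatches per epoch times 2 epochs), and each minibatch holds 8 samples.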
view_transposed: PyTree[Transition]
property
For single-agent settings, this will do nothing and return the original transition.
For multi-agent settings, the original transition is a Transition of PyTrees, e.g.
Transition(observation={a1: ..., a2: ...}, action={a1: ..., a2: ...}, ...)
The transposed transition is a PyTree of Transitions, e.g.
{a1: Transition(observation=..., action=..., ...), a2: Transition(observation=..., action=..., ...), ...}
This is useful for multi-agent environments where we want a single Transition object per agent.
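The transposition can be illustrated with a small stand-in. `ToyTransition` and `sketch_view_transposed` are hypothetical names used only here, and the sketch assumes per-agent data is stored in dictionaries keyed by agent name; the real class may carry more fields and support other container types.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class ToyTransition(NamedTuple):
    """Hypothetical stand-in with only three fields."""

    observation: Any
    action: Any
    reward: Any


def sketch_view_transposed(t):
    # Single-agent case: fields are plain arrays, so nothing is transposed.
    if not isinstance(t.reward, dict):
        return t
    # Multi-agent case: build one transition per agent key.
    return {
        agent: ToyTransition(*(field[agent] for field in t))
        for agent in t.reward
    }


multi = ToyTransition(
    observation={"a1": jnp.zeros((5, 3)), "a2": jnp.ones((5, 2))},
    action={"a1": jnp.zeros(5), "a2": jnp.ones(5)},
    reward={"a1": jnp.zeros(5), "a2": jnp.ones(5)},
)
per_agent = sketch_view_transposed(multi)
print(sorted(per_agent))  # ['a1', 'a2']
print(per_agent["a2"].observation.shape)  # (5, 2)
```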
structure: PyTreeDef
property
Returns the top-level structure of the transition objects (using reward as a reference). This is either PyTreeDef(*) for single agents or PyTreeDef((*, x num_agents)) for multi-agent environments. Useful for unflattening Transition.flat.properties back to the original structure.
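As a rough illustration of what such a structure looks like and how it can restore flattened data, the sketch below uses only `jax.tree_util`. The reward values are made up, and representing the multi-agent reward as a tuple with one entry per agent is an assumption for the example.

```python
import jax
import jax.numpy as jnp

single_reward = jnp.zeros(5)
multi_reward = (jnp.zeros(5), jnp.ones(5))  # assumed: one entry per agent

single_def = jax.tree_util.tree_structure(single_reward)
multi_def = jax.tree_util.tree_structure(multi_reward)
print(single_def.num_leaves, multi_def.num_leaves)  # 1 2

# A flat list with one entry per agent can be restored to the agent structure.
per_agent_values = [jnp.full(5, 0.1), jnp.full(5, 0.2)]
restored = jax.tree_util.tree_unflatten(multi_def, per_agent_values)
print(type(restored).__name__, len(restored))  # tuple 2
```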
view_flat: Transition
property
Returns a flattened version of the transition. Where possible, this is a jnp.stack of the leaves. Otherwise, it returns a list of leaves.
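One way to read "where possible" is that leaves are stacked when they all share a shape, and returned as a list otherwise. The sketch below follows that reading; it is an assumption about the behaviour, and `sketch_flatten_field` is a hypothetical helper rather than part of the library.

```python
import jax
import jax.numpy as jnp


def sketch_flatten_field(field):
    """Hypothetical helper: stack leaves if shapes match, else return a list."""
    leaves = jax.tree_util.tree_leaves(field)
    if len({leaf.shape for leaf in leaves}) == 1:
        return jnp.stack(leaves)  # new leading axis with one entry per leaf
    return leaves


rewards = {"a1": jnp.zeros(5), "a2": jnp.ones(5)}  # same shape per agent
observations = {"a1": jnp.zeros((5, 3)), "a2": jnp.ones((5, 2))}  # shapes differ

flat_rewards = sketch_flatten_field(rewards)
flat_observations = sketch_flatten_field(observations)
print(flat_rewards.shape)  # (2, 5)
print(type(flat_observations).__name__, len(flat_observations))  # list 2
```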