Tree
The jaxnasium.tree package provides convenience pytree functions, used throughout the RL algorithms, that are not available in the higher-level libraries it relies on (equinox / jax / optax).
src.jaxnasium.tree._tree
tree_mean(tree)
Computes the global mean of the leaves of a pytree.
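The docstring does not say how leaves of different sizes are weighted. The sketch below assumes "global mean" means every scalar element of every leaf counts equally. `tree_mean_sketch` is a hypothetical name, and the code illustrates that assumption rather than jaxnasium's implementation.

```python
import jax
import jax.numpy as jnp


def tree_mean_sketch(tree):
    """Hypothetical sketch: mean over every element of every leaf.

    Assumption: each scalar element is weighted equally, so larger leaves
    contribute more than smaller ones.
    """
    leaves = jax.tree.leaves(tree)
    total = sum(jnp.sum(leaf) for leaf in leaves)  # sum of all elements
    count = sum(jnp.size(leaf) for leaf in leaves)  # number of elements
    return total / count


tree = {"a": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(6.0)}
print(tree_mean_sketch(tree))  # (1 + 2 + 3 + 6) / 4 = 3.0
```

Under the alternative reading (a mean of per-leaf means), the same tree would give (2 + 6) / 2 = 4.0. The two conventions only agree when all leaves have the same size.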
tree_get_first(tree: PyTree, key: str) -> Any
Get the first value from a pytree with the given key.
Like optax.tree.get() but returns the first value found in case
of multiple matches instead of raising an error.
Arguments:
tree: A pytree.
key: A string key.
Returns: The first value from the pytree with the given key.
Raises: KeyError: If the key is not found in the pytree.
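The following is a minimal sketch of first-match lookup. It assumes that matching is done on each leaf's own dict key or attribute name, which is the last entry of its key path. `tree_get_first_sketch` is hypothetical, and unlike optax it never returns whole subtrees. "First" here means first in JAX's flattening order, which sorts dict keys.

```python
from jax.tree_util import DictKey, tree_flatten_with_path


def tree_get_first_sketch(tree, key: str):
    """Hypothetical sketch: return the first leaf whose own key equals `key`.

    Simplifications (assumptions, not the library's behavior): only leaves
    are matched, never whole subtrees, and only dict keys and attribute
    names (e.g. NamedTuple fields) are compared.
    """
    for path, leaf in tree_flatten_with_path(tree)[0]:
        if not path:
            continue  # the tree itself is a single leaf with no key
        last = path[-1]
        # DictKey stores the dict key in .key; GetAttrKey stores the attribute in .name
        name = last.key if isinstance(last, DictKey) else getattr(last, "name", None)
        if name == key:
            return leaf
    raise KeyError(f"Key {key!r} not found in pytree.")


params = {"actor": {"lr": 1e-3}, "critic": {"lr": 3e-4, "gamma": 0.99}}
print(tree_get_first_sketch(params, "lr"))  # 0.001
```

"actor" sorts before "critic", so the lookup returns the actor's learning rate even though two leaves are named "lr".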
tree_map_one_level(fn: Callable, tree, *rest)
Apply jax.tree.map over only the first level of a pytree: fn is called on each immediate child of the root (a whole subtree) rather than on individual leaves.
Arguments:
fn: A function to map over the pytree.
tree: A pytree.
*rest: Additional pytrees to map over.
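One way to restrict jax.tree.map to the first level uses an `is_leaf` predicate that reports every node except the root as a leaf. The sketch below shows this with a hypothetical `tree_map_one_level_sketch`. It is an assumption about the approach, not necessarily how jaxnasium does it.

```python
import jax
import jax.numpy as jnp


def tree_map_one_level_sketch(fn, tree, *rest):
    """Hypothetical sketch: apply fn to each immediate child of the root.

    Every node except the root itself is reported as a leaf, so jax.tree.map
    stops after one level and fn receives whole subtrees.
    """
    return jax.tree.map(fn, tree, *rest, is_leaf=lambda node: node is not tree)


tree = {"obs": jnp.ones((2, 3)), "info": {"a": jnp.zeros(2), "b": jnp.zeros(4)}}
# fn sees the whole "info" dict, not its two arrays separately.
counts = tree_map_one_level_sketch(lambda sub: len(jax.tree.leaves(sub)), tree)
# counts == {"obs": 1, "info": 2}
```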
tree_map_distribution(fn: Callable, tree, *rest)
Map a function with distrax.Distribution instances marked as leaves.
Additionally, if one of the inputs is a DistraxContainer, the function
is applied to the distribution attribute of the DistraxContainer.
Arguments:
fn: A function to map over the pytree.
tree: A pytree.
*rest: Additional pytrees to map over.
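The sketch below shows the leaf-marking idea without depending on distrax. `FakeNormal` is a hypothetical stand-in for a distrax.Distribution, and the DistraxContainer unwrapping step is not reproduced. A NamedTuple is itself a pytree, so a plain jax.tree.map would descend into it and hand fn its `loc` and `scale` arrays separately. Passing an `is_leaf` predicate makes fn receive the whole distribution object instead.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FakeNormal(NamedTuple):
    """Hypothetical stand-in for a distrax.Distribution (NamedTuples are pytrees)."""

    loc: jax.Array
    scale: jax.Array


def tree_map_distribution_sketch(fn, tree, *rest):
    """Map fn over tree, treating every FakeNormal as a single leaf."""
    return jax.tree.map(fn, tree, *rest, is_leaf=lambda x: isinstance(x, FakeNormal))


policy = {
    "move": FakeNormal(loc=jnp.zeros(2), scale=jnp.ones(2)),
    "jump": FakeNormal(loc=jnp.ones(1), scale=jnp.ones(1)),
}
# Without is_leaf there are 4 array leaves; with it, fn sees 2 distributions.
means = tree_map_distribution_sketch(lambda d: d.loc, policy)
print(means["move"])  # [0. 0.]
```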
tree_stack(pytrees: PyTree, *, axis=0) -> PyTree
Stack corresponding leaves of pytrees along the specified axis.
Interprets the root node's immediate children as a batch of N pytrees that all
share the same structure. For each leaf, stacks the N leaves along axis using
jnp.stack. This does not traverse deeper than one level when determining what to stack.
Arguments:
pytrees: A pytree whose root has N immediate children. Each child must have the same pytree structure. Corresponding leaves must be array-like and have identical shapes and dtypes (compatible with jnp.stack).
axis: Axis along which to insert the new dimension of size N in each stacked leaf (default=0).
Returns:
A pytree with the same structure as a single direct-child element of pytrees, where each
leaf is the stack of the corresponding leaves across all elements, with a new
dimension of size N inserted at axis.
Example:
>>> trees = (
... [jnp.array([1, 2]), jnp.array(4)],
... [jnp.array([5, 5]), jnp.array(3)],
... )
>>> tree_stack(trees, axis=0)
[Array([[1, 2], [5, 5]], dtype=int32), Array([4, 3], dtype=int32)]
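Here is a minimal sketch of the stacking behavior. It collects the root's children with a one-level `is_leaf` predicate, then uses `jax.tree.map` to zip the N children leaf by leaf. `tree_stack_sketch` is a hypothetical name, and the code reproduces the example above.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(pytrees, *, axis=0):
    """Hypothetical sketch: stack the root's immediate children leaf-by-leaf."""
    # Flatten one level only: every node below the root counts as a leaf,
    # so `children` is the list of N same-structured pytrees.
    children = jax.tree.leaves(pytrees, is_leaf=lambda node: node is not pytrees)
    # Zip the N children leaf-by-leaf and stack each group along `axis`.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *children)


trees = (
    [jnp.array([1, 2]), jnp.array(4)],
    [jnp.array([5, 5]), jnp.array(3)],
)
stacked = tree_stack_sketch(trees, axis=0)
print(stacked[0].shape, stacked[1].shape)  # (2, 2) (2,)
```

The same `axis` applies to every leaf. For that reason, a nonzero axis only works if it is valid for the lowest-rank leaf (here the 0-D leaves only admit axis=0).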
tree_unstack(tree, *, axis=0, structure: Optional[PyTreeDef] = None)
Inverse of tree_stack: split a pytree whose leaves were stacked along axis
into N separate pytrees.
If structure is provided (e.g., from eqx.tree_flatten_one_level),
the list of N pytrees is immediately placed back into that container and
returned as a single pytree.
Arguments:
tree: A pytree whose leaves are array-like and all share the same size N along axis.
axis: The axis that carries the size-N dimension in each leaf (default=0).
structure: Optional PyTreeDef. If provided, the list of N pytrees is immediately placed back into that container and returned as a single pytree.
Returns:
If structure is None: a list of N pytrees.
Otherwise: a single pytree produced by unflattening structure with those N pytrees.
Example:
>>> trees = (
... [jnp.array([1, 2]), jnp.array(4)],
... [jnp.array([5, 5]), jnp.array(3)],
... )
>>> batched = tree_stack(trees, axis=0)
>>> tree_unstack(batched, axis=0)
[[Array([1, 2], dtype=int32), Array(4, dtype=int32)],
[Array([5, 5], dtype=int32), Array(3, dtype=int32)]]
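A matching sketch of the inverse follows, using a hypothetical `tree_unstack_sketch`. To avoid depending on equinox, the `structure` argument is built here with `jax.tree.flatten` on a 2-tuple rather than with eqx.tree_flatten_one_level. In this sketch, any PyTreeDef with exactly N leaves behaves the same way.

```python
import jax
import jax.numpy as jnp


def tree_unstack_sketch(tree, *, axis=0, structure=None):
    """Hypothetical sketch: split every leaf along `axis` into N pytrees."""
    leaves, treedef = jax.tree.flatten(tree)
    n = leaves[0].shape[axis]  # assumes all leaves share this size along `axis`
    pytrees = [
        jax.tree.unflatten(treedef, [jnp.take(leaf, i, axis=axis) for leaf in leaves])
        for i in range(n)
    ]
    if structure is None:
        return pytrees
    # Re-wrap the N pytrees in the container described by `structure`.
    return jax.tree.unflatten(structure, pytrees)


batched = {"a": jnp.array([[1, 2], [5, 5]]), "b": jnp.array([4, 3])}
parts = tree_unstack_sketch(batched)  # list of 2 dicts
_, pair = jax.tree.flatten((0, 0))  # PyTreeDef for a 2-tuple
as_tuple = tree_unstack_sketch(batched, structure=pair)  # tuple of 2 dicts
```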
tree_concatenate(trees: PyTree) -> Array
Concatenate the leaves of a pytree into a single 1D array.
Arguments:
trees: A pytree whose leaves are array-like and all 1D or 0D.
Returns: A 1D array containing the concatenated leaves of the pytree.
Example:
>>> tree = {'a': jnp.array([1, 2]), 'b': jnp.array(3)}
>>> tree_concatenate(tree)
Array([1, 2, 3], dtype=int32)
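The following minimal sketch assumes that leaves are concatenated in JAX's flattening order, where dict keys are sorted. `tree_concatenate_sketch` is hypothetical. Scalar leaves are first promoted to length-1 arrays, since `jnp.concatenate` cannot join 0-D arrays.

```python
import jax
import jax.numpy as jnp


def tree_concatenate_sketch(trees):
    """Hypothetical sketch: flatten every 0-D/1-D leaf into one 1-D array."""
    # atleast_1d turns scalar (0-D) leaves into length-1 arrays so that
    # jnp.concatenate can join them with the 1-D leaves.
    return jnp.concatenate([jnp.atleast_1d(leaf) for leaf in jax.tree.leaves(trees)])


tree = {"a": jnp.array([1, 2]), "b": jnp.array(3)}
print(tree_concatenate_sketch(tree))  # [1 2 3]
```

Because of the key sorting, writing the same dict as {"b": ..., "a": ...} produces the same array.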