
Tree

The jaxnasium.tree package provides convenience pytree functions that are used throughout the RL algorithms but are not available in the higher-level libraries the project already relies on (equinox / jax / optax).

src.jaxnasium.tree._tree

tree_mean(tree)

Computes the global mean of the leaves of a pytree.
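
The docstring does not say how leaves of different sizes are weighted. A minimal sketch, assuming "global mean" means the mean over every element of every leaf (not the mean of per-leaf means); tree_mean_sketch is a hypothetical name, not the library's implementation:

```python
import jax
import jax.numpy as jnp


def tree_mean_sketch(tree):
    # Assumption: "global mean" = sum of all elements / total element count,
    # so a leaf with more elements carries proportionally more weight.
    leaves = jax.tree.leaves(tree)
    total = sum(jnp.sum(leaf) for leaf in leaves)
    count = sum(jnp.size(leaf) for leaf in leaves)
    return total / count


# Three zeros and one 4.0 give 1.0 under this definition
# (a mean of per-leaf means would give 2.0 instead).
print(tree_mean_sketch({"a": jnp.zeros(3), "b": jnp.array(4.0)}))
```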

tree_get_first(tree: PyTree, key: str) -> Any

Get the first value stored under the given key in a pytree. Like optax.tree.get(), but when several entries match, it returns the first one found instead of raising an error.

Arguments:

  • tree: A pytree.
  • key: A string key.

Returns: The first value in the pytree stored under the given key.

Raises: KeyError: If the key is not found in the pytree.
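
The "first match wins" lookup can be sketched with JAX key paths. This is a simplified, hypothetical version (tree_get_first_sketch): it assumes "first" means first in JAX flattening order and only matches leaves whose last path entry is a dict key or attribute name, so unlike a full implementation it never returns a whole subtree.

```python
from typing import Any

import jax
from jax.tree_util import DictKey, GetAttrKey


def tree_get_first_sketch(tree: Any, key: str) -> Any:
    # Simplified assumption: match only leaves whose *last* path entry is a
    # dict key or attribute name equal to `key` (subtrees are not returned).
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:  # traversal follows JAX flattening order
        last = path[-1] if path else None
        if isinstance(last, DictKey) and last.key == key:
            return leaf
        if isinstance(last, GetAttrKey) and last.name == key:
            return leaf
    raise KeyError(f"Key {key!r} not found in the pytree.")


# Both sub-dicts contain "lr"; dict keys are visited in sorted order, so the
# "actor" entry is returned instead of an error being raised.
params = {"critic": {"lr": 3e-4}, "actor": {"lr": 1e-3}}
print(tree_get_first_sketch(params, "lr"))
```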

tree_map_one_level(fn: Callable, tree, *rest)

Like jax.tree.map, but fn is applied to the immediate children of the root node (the first level of the pytree) rather than to the individual leaves.

Arguments:

  • fn: A function to map over the pytree.
  • tree: A pytree.
  • *rest: Additional pytrees to map over.
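
Assuming the semantics above, one way to get a one-level map is an is_leaf predicate that treats every node except the root as a leaf. tree_map_one_level_sketch is a hypothetical name; the library's own implementation may differ (it could, for example, use eqx.tree_flatten_one_level).

```python
import jax
import jax.numpy as jnp


def tree_map_one_level_sketch(fn, tree, *rest):
    # Every node below the root counts as a leaf, so fn receives the root's
    # immediate children (whole subtrees), not the individual array leaves.
    # The trees in *rest must share the root's top-level structure.
    return jax.tree.map(fn, tree, *rest, is_leaf=lambda node: node is not tree)


tree = {"actor": (jnp.ones(2), jnp.ones(3)), "critic": jnp.ones(4)}
# fn sees the "actor" tuple as a single input
print(tree_map_one_level_sketch(lambda child: len(jax.tree.leaves(child)), tree))
```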

tree_map_distribution(fn: Callable, tree, *rest)

Map a function over pytrees while treating distrax.Distribution instances as leaves, so each distribution is passed to fn whole instead of being flattened into its parameters. Additionally, if one of the inputs is a DistraxContainer, fn is applied to that container's distribution attribute.

Arguments:

  • fn: A function to map over the pytree.
  • tree: A pytree.
  • *rest: Additional pytrees to map over.

tree_stack(pytrees: PyTree, *, axis=0) -> PyTree

Stack corresponding leaves of pytrees along the specified axis.

Interprets the root node's immediate children as a batch of N pytrees that all share the same structure. For each leaf position, the N corresponding leaves are stacked along axis using jnp.stack. Only the root's immediate children are treated as batch elements; deeper levels are never split when determining what to stack.

Arguments:

  • pytrees: A pytree whose root has N immediate children. Each child must have the same pytree structure. Corresponding leaves must be array-like and have identical shapes and dtypes (compatible with jnp.stack).
  • axis: Axis along which to insert the new dimension of size N in each stacked leaf (default=0).

Returns: A pytree with the same structure as a single direct-child element of pytrees, where each leaf is the stack of the corresponding leaves across all elements, with a new dimension of size N inserted at axis.

Example:

    >>> import jax.numpy as jnp
    >>> trees = (
    ...     [jnp.array([1, 2]), jnp.array(4)],
    ...     [jnp.array([5, 5]), jnp.array(3)],
    ... )
    >>> tree_stack(trees, axis=0)
    [Array([[1, 2], [5, 5]], dtype=int32), Array([4, 3], dtype=int32)]

Evolved from: link.
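
The stacking logic described above can be sketched in plain JAX: split off only the root, then stack corresponding leaves with jax.tree.map. tree_stack_sketch is a hypothetical name and this is an illustration under those assumptions, not the library's code.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(pytrees, *, axis=0):
    # Split off only the root: its immediate children are the N batch elements.
    children, _ = jax.tree_util.tree_flatten(
        pytrees, is_leaf=lambda node: node is not pytrees
    )
    # Stack leaf i of every child into one array with a new size-N axis.
    return jax.tree.map(lambda *leaves: jnp.stack(leaves, axis=axis), *children)


trees = (
    [jnp.array([1, 2]), jnp.array(4)],
    [jnp.array([5, 5]), jnp.array(3)],
)
stacked = tree_stack_sketch(trees, axis=0)
print(stacked[0].shape, stacked[1].shape)  # (2, 2) (2,)
```

The result has the structure of one child (a list here), which matches the Returns description above.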

tree_unstack(tree, *, axis=0, structure: Optional[PyTreeDef] = None)

Inverse of tree_stack: split a pytree whose leaves were stacked along axis into N separate pytrees.

If structure is provided (e.g., from eqx.tree_flatten_one_level), the list of N pytrees is immediately placed back into that container and returned as a single pytree.

Arguments:

  • tree: A pytree whose leaves are array-like and all share the same size N along axis.
  • axis: The axis that carries the size-N dimension in each leaf (default=0).
  • structure: Optional PyTreeDef. If provided, the list of N pytrees is immediately placed back into that container and returned as a single pytree.

Returns: If structure is None: a list of N pytrees. Otherwise: a single pytree produced by unflattening structure with those N pytrees.

Example:

    >>> import jax.numpy as jnp
    >>> trees = (
    ...     [jnp.array([1, 2]), jnp.array(4)],
    ...     [jnp.array([5, 5]), jnp.array(3)],
    ... )
    >>> batched = tree_stack(trees, axis=0)
    >>> tree_unstack(batched, axis=0)
    [[Array([1, 2], dtype=int32), Array(4, dtype=int32)],
     [Array([5, 5], dtype=int32), Array(3, dtype=int32)]]
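
A sketch of the unstacking step, including the optional structure argument. tree_unstack_sketch is hypothetical; here the one-level PyTreeDef is built with an is_leaf predicate as a stand-in for eqx.tree_flatten_one_level, and every leaf is assumed to be an array with size N along axis.

```python
import jax
import jax.numpy as jnp


def tree_unstack_sketch(tree, *, axis=0, structure=None):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[axis]  # assumes every leaf has size N along `axis`
    # Pytree i takes slice i of every leaf; the size-N axis is removed.
    pytrees = [
        treedef.unflatten([jnp.moveaxis(leaf, axis, 0)[i] for leaf in leaves])
        for i in range(n)
    ]
    if structure is not None:
        # Put the N pytrees back into the original container (e.g. a tuple).
        return jax.tree_util.tree_unflatten(structure, pytrees)
    return pytrees


batched = [jnp.array([[1, 2], [5, 5]]), jnp.array([4, 3])]
parts = tree_unstack_sketch(batched, axis=0)  # list of 2 pytrees

trees = ([jnp.array([1, 2]), jnp.array(4)], [jnp.array([5, 5]), jnp.array(3)])
_, structure = jax.tree_util.tree_flatten(trees, is_leaf=lambda node: node is not trees)
restored = tree_unstack_sketch(batched, structure=structure)  # tuple, like `trees`
```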

tree_concatenate(trees: PyTree) -> Array

Concatenate the leaves of a pytree into a single 1D array.

Arguments:

  • trees: A pytree whose leaves are array-like and each 0D or 1D.

Returns: A 1D array containing the concatenated leaves of the pytree.

Example:

    >>> import jax.numpy as jnp
    >>> tree = {'a': jnp.array([1, 2]), 'b': jnp.array(3)}
    >>> tree_concatenate(tree)
    Array([1, 2, 3], dtype=int32)
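
A minimal sketch of the concatenation, assuming leaves are joined in jax.tree.leaves order; tree_concatenate_sketch is a hypothetical name, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def tree_concatenate_sketch(trees):
    # Promote 0D leaves to shape (1,) so every leaf is 1D, then join them in
    # jax.tree.leaves order (dict keys are visited in sorted order).
    leaves = [jnp.atleast_1d(leaf) for leaf in jax.tree.leaves(trees)]
    return jnp.concatenate(leaves)


print(tree_concatenate_sketch({"b": jnp.array(3), "a": jnp.array([1, 2])}))
```

For leaves of arbitrary rank, jax.flatten_util.ravel_pytree is a related JAX utility that also returns a flat 1D array (together with a function to undo the flattening); the sketch instead mirrors the 0D/1D restriction stated above.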