
Distributions

src.jaxnasium.algorithms.utils.DistraxContainer

Container for distrax distributions, which may be nested as a PyTree.

sample(*, seed)
log_prob(value)
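Here is a minimal sketch of how a container like this might work. It assumes the container stores an arbitrary PyTree of distribution objects, splits the PRNG key once per leaf in sample, and maps log_prob leaf by leaf. To keep the example runnable without distrax, Normal is a hypothetical stand-in that uses the same sample(*, seed) and log_prob(value) call shape. NestedDistributions is also a hypothetical name, not the library's class. The docs do not say whether DistraxContainer sums the per-leaf log-probabilities, so this sketch returns them as a PyTree.

```python
import jax
import jax.numpy as jnp


class Normal:
    """Hypothetical stand-in for a distrax distribution (elementwise Gaussian)."""

    def __init__(self, loc, scale):
        self.loc = jnp.asarray(loc)
        self.scale = jnp.asarray(scale)

    def sample(self, *, seed):
        eps = jax.random.normal(seed, self.loc.shape)
        return self.loc + self.scale * eps

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)


def _is_dist(x):
    # Treat anything exposing sample/log_prob as a leaf, so tree utilities
    # do not recurse into the distribution's own parameters.
    return hasattr(x, "sample") and hasattr(x, "log_prob")


class NestedDistributions:
    """Hypothetical sketch of a container for a PyTree of distributions."""

    def __init__(self, distributions):
        self.distributions = distributions

    def sample(self, *, seed):
        # One independent key per distribution leaf.
        leaves, treedef = jax.tree_util.tree_flatten(self.distributions, is_leaf=_is_dist)
        keys = jax.random.split(seed, len(leaves))
        samples = [d.sample(seed=k) for d, k in zip(leaves, keys)]
        return jax.tree_util.tree_unflatten(treedef, samples)

    def log_prob(self, value):
        # `value` must have the same PyTree structure as `distributions`.
        return jax.tree_util.tree_map(
            lambda d, v: d.log_prob(v), self.distributions, value, is_leaf=_is_dist
        )


dists = NestedDistributions(
    {"move": Normal(0.0, 1.0), "aim": Normal(jnp.zeros(2), jnp.ones(2))}
)
sample = dists.sample(seed=jax.random.PRNGKey(0))
log_probs = dists.log_prob(sample)  # dict with the same keys as `dists`
```

The is_leaf predicate matters when the leaves are real distrax distributions. Those objects can themselves be registered as PyTrees, and without the predicate tree_map would recurse into their parameters.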

src.jaxnasium.algorithms.utils.TanhNormalFactory(low, high) -> Callable[..., TanhNormal]
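TanhNormalFactory binds the bounds low and high and returns a callable that builds TanhNormal distributions. The sketch below assumes a common construction for bounded continuous actions. It samples u from Normal(loc, scale), squashes u with tanh, and rescales the result linearly from (-1, 1) to (low, high). The matching log_prob applies the change-of-variables correction log p(a) = log N(u) - log(1 - tanh(u)^2) - log((high - low) / 2). ScaledTanhNormal and tanh_normal_factory are hypothetical names, and the (loc, scale) constructor arguments are an assumption. The real TanhNormal may differ in parameterisation and clipping.

```python
import jax
import jax.numpy as jnp


class ScaledTanhNormal:
    """Hypothetical sketch: Normal(loc, scale) squashed by tanh, rescaled to [low, high]."""

    def __init__(self, loc, scale, low, high):
        self.loc = jnp.asarray(loc)
        self.scale = jnp.asarray(scale)
        self.low = jnp.asarray(low)
        self.high = jnp.asarray(high)

    def _half_range(self):
        return (self.high - self.low) / 2.0

    def sample(self, *, seed):
        u = self.loc + self.scale * jax.random.normal(seed, self.loc.shape)
        # tanh maps to (-1, 1); the affine map sends that interval to (low, high).
        return self.low + (jnp.tanh(u) + 1.0) * self._half_range()

    def log_prob(self, value, eps=1e-6):
        # Undo the affine rescaling, clip away from +/-1, then invert tanh.
        y = (value - self.low) / self._half_range() - 1.0
        y = jnp.clip(y, -1.0 + eps, 1.0 - eps)
        u = jnp.arctanh(y)
        z = (u - self.loc) / self.scale
        base = -0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi)
        # log(1 - tanh(u)^2) written stably as 2 * (log 2 - u - softplus(-2u)).
        log_det_tanh = 2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u))
        return base - log_det_tanh - jnp.log(self._half_range())


def tanh_normal_factory(low, high):
    """Hypothetical analogue of TanhNormalFactory: bind bounds, return a constructor."""

    def make(loc, scale):
        return ScaledTanhNormal(loc, scale, low, high)

    return make


make = tanh_normal_factory(-2.0, 2.0)
dist = make(jnp.zeros(3), jnp.ones(3))
action = dist.sample(seed=jax.random.PRNGKey(1))  # each entry lies in [-2, 2]
```

Returning a constructor lets a policy network output only loc and scale while the action bounds stay fixed by the environment. This is presumably why the factory takes low and high up front.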