
Replay Buffer

src.jaxnasium.algorithms.utils.TransitionBuffer

A buffer for storing transitions, from which batches are sampled uniformly at random. It is implemented as a circular buffer: once the buffer is full, each new transition overwrites the oldest one.
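The circular-buffer behaviour described above can be illustrated with a minimal, self-contained JAX sketch. It is not the TransitionBuffer source: RingBufferState, init and insert are hypothetical names, the storage is a single array rather than a Transition pytree, and the sketch assumes a single insert never writes more than max_size items.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RingBufferState(NamedTuple):
    """Hypothetical minimal buffer state: storage, write pointer and fill level."""
    data: jax.Array      # shape (max_size, *item_shape)
    position: jax.Array  # index of the next slot to write (scalar int)
    size: jax.Array      # number of filled slots, capped at max_size


def init(max_size: int, item_shape: tuple) -> RingBufferState:
    return RingBufferState(
        data=jnp.zeros((max_size, *item_shape)),
        position=jnp.array(0, dtype=jnp.int32),
        size=jnp.array(0, dtype=jnp.int32),
    )


def insert(state: RingBufferState, items: jax.Array) -> RingBufferState:
    """Write a batch of items (leading axis = batch) and return a new state.

    Assumes items.shape[0] <= max_size, so no slot is written twice per call.
    """
    max_size = state.data.shape[0]
    n = items.shape[0]
    # Slot indices wrap around, so the oldest entries are overwritten when full.
    idx = (state.position + jnp.arange(n)) % max_size
    return RingBufferState(
        data=state.data.at[idx].set(items),
        position=(state.position + n) % max_size,
        size=jnp.minimum(state.size + n, max_size),
    )


state = init(max_size=4, item_shape=())
state = insert(state, jnp.array([1.0, 2.0, 3.0]))
state = insert(state, jnp.array([4.0, 5.0]))
# state.data -> [5., 2., 3., 4.]: the value 5.0 overwrote the oldest entry (1.0).
```

Because JAX arrays are immutable, insert returns a new state instead of mutating the old one, which keeps it usable inside jax.jit and jax.lax.scan.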

Arguments:

- max_size: The maximum size of the buffer.
- sample_batch_size: The number of transitions to sample from the buffer.
- data_sample: A sample Transition used to initialize the buffer structure.
- num_batch_axes: The number of batch axes in the transition data. Defaults to 2, which expects the first two axes to be batch dimensions (e.g., (rollout_length, num_envs, ...)).
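A sketch of what num_batch_axes implies, assuming (the documentation does not confirm this) that every index along the leading batch axes is treated as one transition. The hypothetical helper flatten_batch_axes merges those axes into a single leading axis for every leaf of a pytree; a plain dict stands in for Transition.

```python
import math

import jax
import jax.numpy as jnp


def flatten_batch_axes(tree, num_batch_axes: int = 2):
    """Merge the leading num_batch_axes axes of every leaf into one axis.

    Hypothetical helper: (rollout_length, num_envs, ...) -> (rollout_length * num_envs, ...).
    """
    def _flatten(x):
        n = math.prod(x.shape[:num_batch_axes])
        return x.reshape((n, *x.shape[num_batch_axes:]))

    return jax.tree_util.tree_map(_flatten, tree)


# Example: a rollout of 128 steps collected from 8 parallel environments.
rollout = {"obs": jnp.zeros((128, 8, 4)), "reward": jnp.zeros((128, 8))}
flat = flatten_batch_axes(rollout, num_batch_axes=2)
# flat["obs"].shape == (1024, 4); flat["reward"].shape == (1024,)
```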

insert(transition: Transition) -> TransitionBuffer

Insert a transition into the buffer and return the updated buffer.

sample(key: PRNGKeyArray, with_replacement: bool = False) -> Transition

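Sample a batch of sample_batch_size transitions uniformly from the buffer. If with_replacement is True, the same transition may appear more than once in a batch.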
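Uniform sampling from a partially filled circular buffer can be sketched as follows. This is an illustrative assumption, not the library's implementation: sample_indices and gather are hypothetical helpers, empty slots are masked out by giving them zero probability in jax.random.choice, and with_replacement is mapped onto that function's replace flag. When sampling without replacement, the sketch assumes sample_batch_size does not exceed the number of stored transitions.

```python
import jax
import jax.numpy as jnp


def sample_indices(key, size, max_size: int, sample_batch_size: int,
                   with_replacement: bool = False) -> jax.Array:
    """Draw slot indices uniformly from the first `size` slots of a buffer.

    Hypothetical helper: `size` may be a traced scalar, while `max_size` and
    `sample_batch_size` must be static Python ints.
    """
    # Filled slots get equal probability; empty slots get zero probability.
    valid = (jnp.arange(max_size) < size).astype(jnp.float32)
    p = valid / jnp.maximum(size, 1)
    return jax.random.choice(
        key, max_size, shape=(sample_batch_size,), replace=with_replacement, p=p
    )


def gather(data, idx):
    """Index every leaf of a pytree of stored transitions along its first axis."""
    return jax.tree_util.tree_map(lambda x: x[idx], data)


key = jax.random.PRNGKey(0)
storage = {"reward": jnp.arange(8.0)}  # max_size = 8, only the first 5 slots filled
idx = sample_indices(key, jnp.array(5), max_size=8, sample_batch_size=4)
batch = gather(storage, idx)  # batch["reward"] has shape (4,)
```

Passing the static max_size together with a probability mask, rather than calling jax.random.choice(key, size, ...), keeps every shape static, which jax.jit requires when the fill level is a traced value.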