
Cross-Entropy Method

rex.cem.cem(loss: Loss, solver: CEMSolver, init_state: CEMState, transform: Transform, max_steps: int = 100, rng: jax.Array = None, verbose: bool = True) -> Tuple[CEMState, jax.typing.ArrayLike]

Run the Cross-Entropy Method (can be jit-compiled).

Parameters:

  • loss (Loss) –

    Loss function.

  • solver (CEMSolver) –

    CEM Solver.

  • init_state (CEMState) –

    Initial state of the CEM Solver.

  • transform (Transform) –

    Transform function (e.g. denormalization or extension).

  • max_steps (int, default: 100 ) –

    Maximum number of steps to run the CEM Solver.

  • rng (Array, default: None ) –

    JAX random number generator key (PRNG key).

  • verbose (bool, default: True ) –

    Whether to print progress.

Returns:

  • Tuple[CEMState, ArrayLike]

    The final state of the CEM Solver and the losses at each step.
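To make the loop concrete, here is a minimal, self-contained sketch of a CEM loop in JAX. It is not rex's implementation: the function cem_sketch and its arguments are hypothetical, it works on a flat parameter vector instead of a pytree, it skips the transform, and it assumes evolution_smoothing weights the previous mean and standard deviation. Like cem, it runs its steps inside jax.lax.scan and returns the final distribution together with one loss value per step (here, the best sample loss of that step).

```python
import jax
import jax.numpy as jnp


def cem_sketch(loss_fn, mean, stdev, u_min, u_max, rng, max_steps=100,
               num_samples=100, elite_portion=0.1, evolution_smoothing=0.1):
    """Hypothetical illustration of CEM on a flat vector (not rex.cem.cem)."""
    # Number of elites kept per iteration (static Python int).
    num_elites = max(1, int(num_samples * elite_portion))

    def step(carry, key):
        mean, stdev = carry
        # Sample candidates from a diagonal Gaussian and clip them to the bounds.
        noise = jax.random.normal(key, (num_samples,) + mean.shape)
        samples = jnp.clip(mean + stdev * noise, u_min, u_max)
        # Evaluate every candidate and keep the lowest-loss fraction as elites.
        losses = jax.vmap(loss_fn)(samples)
        elite_idx = jnp.argsort(losses)[:num_elites]
        elites = samples[elite_idx]
        # Smoothed update; ASSUMPTION: the factor weights the previous estimate.
        a = evolution_smoothing
        new_mean = a * mean + (1 - a) * elites.mean(axis=0)
        new_stdev = a * stdev + (1 - a) * elites.std(axis=0)
        # Emit the best loss of this step; scan stacks them into an array.
        return (new_mean, new_stdev), losses[elite_idx[0]]

    keys = jax.random.split(rng, max_steps)
    (mean, stdev), step_losses = jax.lax.scan(step, (mean, stdev), keys)
    return mean, stdev, step_losses


# Usage: minimize a quadratic whose optimum lies at 0.5 in every dimension.
loss_fn = lambda u: jnp.sum((u - 0.5) ** 2)
mean, stdev, losses = cem_sketch(
    loss_fn,
    mean=jnp.zeros(3), stdev=jnp.ones(3),
    u_min=-jnp.ones(3), u_max=jnp.ones(3),
    rng=jax.random.PRNGKey(0), max_steps=50,
)
```

Using jax.lax.scan keeps the whole loop traceable, so the sketch stays jit-compatible in the same spirit as cem.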


rex.cem.CEMSolver

Cross-Entropy Method (CEM) solver. See https://arxiv.org/pdf/1907.03613.pdf for details on CEM.

Attributes:

  • u_min (Dict[str, Params]) –

    (Normalized) Minimum values for the parameters (pytree).

  • u_max (Dict[str, Params]) –

    (Normalized) Maximum values for the parameters (pytree).

  • evolution_smoothing (Union[float, ArrayLike]) –

    Smoothing factor for updating the mean and standard deviation.

  • num_samples (int) –

    Number of samples per iteration.

  • elite_portion (float) –

    Fraction of the samples, those with the lowest loss, kept as elites when updating the mean and standard deviation.
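The roles of num_samples, elite_portion and evolution_smoothing within one iteration can be summarized with the textbook CEM update. This is a sketch, not taken from the rex source: it assumes samples are clipped to [u_min, u_max] and that the smoothing factor weights the previous estimate (the opposite convention, weighting the new estimate, is also common). Here N = num_samples, rho = elite_portion, alpha = evolution_smoothing, L is the loss, and all operations are elementwise per parameter.

```latex
\begin{aligned}
u_i &= \operatorname{clip}\big(\mu + \sigma \odot \epsilon_i,\ u_{\min},\ u_{\max}\big),
  \quad \epsilon_i \sim \mathcal{N}(0, I), \quad i = 1, \dots, N \\
K &= \max\big(1, \lfloor \rho N \rfloor\big), \quad
  \mathcal{E} = \text{indices of the } K \text{ samples with the lowest } L(u_i) \\
\hat{\mu} &= \frac{1}{K} \sum_{i \in \mathcal{E}} u_i, \qquad
  \hat{\sigma} = \sqrt{\frac{1}{K} \sum_{i \in \mathcal{E}} \big(u_i - \hat{\mu}\big)^{2}} \\
\mu &\leftarrow \alpha\, \mu + (1 - \alpha)\, \hat{\mu}, \qquad
  \sigma \leftarrow \alpha\, \sigma + (1 - \alpha)\, \hat{\sigma}
\end{aligned}
```

With the defaults of CEMSolver.init (N = 100, rho = 0.1), this gives K = 10 elites per iteration.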

init(u_min: Dict[str, Params], u_max: Dict[str, Params], num_samples: int = 100, evolution_smoothing: Union[float, jax.typing.ArrayLike] = 0.1, elite_portion: float = 0.1) -> CEMSolver classmethod

Initialize the Cross-Entropy Method (CEM) Solver.

Parameters:

  • u_min (Dict[str, Params]) –

    (Normalized) Minimum values for the parameters (pytree).

  • u_max (Dict[str, Params]) –

    (Normalized) Maximum values for the parameters (pytree).

  • num_samples (int, default: 100 ) –

    Number of samples per iteration.

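  • evolution_smoothing (Union[float, ArrayLike], default: 0.1 ) –

    Smoothing factor for updating the mean and standard deviation.

  • elite_portion (float, default: 0.1 ) –

    Fraction of the samples, those with the lowest loss, kept as elites when updating the mean and standard deviation.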

Returns:

  • CEMSolver (CEMSolver) –

    An instance of the CEMSolver class.

init_state(mean: Dict[str, Params], stdev: Dict[str, Params] = None) -> CEMState

Initialize the state of the CEM Solver.

Parameters:

  • mean (Dict[str, Params]) –

    (Normalized) Mean values for the parameters (pytree).

  • stdev (Dict[str, Params], default: None ) –

    (Normalized) Standard deviation values for the parameters (pytree).

Returns:

  • CEMState (CEMState) –

    The initialized state of the CEM Solver.
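Because mean, stdev, u_min and u_max are pytrees (Dict[str, Params]), sampling candidates has to happen leaf by leaf. The sketch below shows one way to do that with jax.tree_util. The helper sample_pytree and the parameter names gain and delay are hypothetical, and clipping each leaf to [u_min, u_max] is an assumption about how the bounds are used, not something the rex documentation states.

```python
import jax
import jax.numpy as jnp


def sample_pytree(rng, mean, stdev, u_min, u_max, num_samples):
    """Hypothetical helper (not part of rex): diagonal-Gaussian samples of a pytree.

    Each leaf gets its own PRNG key; samples are clipped leaf-wise to [u_min, u_max].
    """
    leaves, treedef = jax.tree_util.tree_flatten(mean)
    # One independent key per leaf, arranged in the same tree structure as `mean`.
    keys = jax.random.split(rng, len(leaves))
    keys = jax.tree_util.tree_unflatten(treedef, list(keys))

    def sample_leaf(key, mu, sigma, lo, hi):
        # Leading axis indexes the candidate; trailing axes match the leaf shape.
        noise = jax.random.normal(key, (num_samples,) + jnp.shape(mu))
        return jnp.clip(mu + sigma * noise, lo, hi)

    return jax.tree_util.tree_map(sample_leaf, keys, mean, stdev, u_min, u_max)


# Usage with a hypothetical two-entry parameter dict (normalized to [-1, 1]).
mean = {"gain": jnp.zeros(2), "delay": jnp.array(0.0)}
stdev = jax.tree_util.tree_map(jnp.ones_like, mean)
u_min = jax.tree_util.tree_map(lambda x: -jnp.ones_like(x), mean)
u_max = jax.tree_util.tree_map(jnp.ones_like, mean)
samples = sample_pytree(jax.random.PRNGKey(0), mean, stdev, u_min, u_max, 100)
```

Each leaf of samples gains a leading axis of length num_samples, so samples["gain"] has shape (100, 2) and samples["delay"] has shape (100,).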


rex.cem.CEMState

State of the CEM Solver.

Attributes:

  • mean (Dict[str, Params]) –

    (Normalized) Mean values for the parameters (pytree).

  • stdev (Dict[str, Params]) –

    (Normalized) Standard deviation values for the parameters (pytree).

  • bestsofar (Dict[str, Params]) –

    (Normalized) Best-so-far values for the parameters (pytree).

  • bestsofar_loss (Union[float, ArrayLike]) –

    Loss of the best-so-far values.
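The bestsofar and bestsofar_loss fields hold the lowest-loss candidate found so far. A jit-friendly way to maintain them is to select with jnp.where instead of a Python if, so the update traces without data-dependent control flow. This is a sketch on a flat vector: SketchState is a hypothetical stand-in for CEMState, and update_bestsofar is not part of rex.

```python
from typing import NamedTuple

import jax.numpy as jnp


class SketchState(NamedTuple):
    """Hypothetical stand-in for CEMState on a flat parameter vector."""
    mean: jnp.ndarray
    stdev: jnp.ndarray
    bestsofar: jnp.ndarray
    bestsofar_loss: jnp.ndarray


def update_bestsofar(state, samples, losses):
    """Replace the best-so-far candidate only if this batch contains a lower loss."""
    idx = jnp.argmin(losses)
    improved = losses[idx] < state.bestsofar_loss
    return state._replace(
        bestsofar=jnp.where(improved, samples[idx], state.bestsofar),
        bestsofar_loss=jnp.where(improved, losses[idx], state.bestsofar_loss),
    )


# Start with an infinite best-so-far loss so the first batch always wins.
state = SketchState(jnp.zeros(2), jnp.ones(2), jnp.zeros(2), jnp.array(jnp.inf))
samples = jnp.array([[0.3, 0.1], [0.0, 0.2], [1.0, 1.0]])
losses = jnp.array([0.5, 0.2, 2.0])
state = update_bestsofar(state, samples, losses)
# A worse batch leaves the best-so-far entry unchanged.
state_after_worse = update_bestsofar(state, samples + 1.0, losses + 1.0)
```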