Cross-Entropy Method
rex.cem.cem(loss: Loss, solver: CEMSolver, init_state: CEMState, transform: Transform, max_steps: int = 100, rng: jax.Array = None, verbose: bool = True) -> Tuple[CEMState, jax.typing.ArrayLike]
Run the Cross-Entropy Method (can be jit-compiled).
Parameters:
- loss (Loss) – Loss function.
- solver (CEMSolver) – CEM Solver.
- init_state (CEMState) – Initial state of the CEM Solver.
- transform (Transform) – Transform function (e.g., denormalization or extension).
- max_steps (int, default: 100) – Maximum number of steps to run the CEM Solver.
- rng (Array, default: None) – JAX random number generator key (PRNG key).
- verbose (bool, default: True) – Whether to print progress.
Returns:
- Tuple[CEMState, ArrayLike] – The final state of the CEM Solver and the losses at each step.
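The loop that `cem` runs can be pictured with the self-contained sketch below. It is not rex's implementation. It works on a flat vector instead of a `Dict[str, Params]` pytree and replaces `loss` and `transform` with a plain Python function. It also assumes two things: Gaussian sampling clipped to `[u_min, u_max]`, and a smoothing rule in which `evolution_smoothing` weights the previous estimate. The hypothetical `cem_sketch` returns the final mean/stdev together with the best loss of each step, loosely mirroring the `(state, losses)` tuple above. The loop is written with `jax.lax.scan`, which is one way to keep a CEM loop traceable by `jax.jit`.

```python
import jax
import jax.numpy as jnp


def cem_sketch(loss_fn, mean, stdev, u_min, u_max, rng, max_steps=100,
               num_samples=100, elite_portion=0.1, evolution_smoothing=0.1):
    """Minimal CEM over a flat parameter vector (hypothetical, not rex.cem.cem)."""
    # Assumption: round down, keeping at least one elite.
    num_elites = max(1, int(elite_portion * num_samples))
    a = evolution_smoothing

    def step(carry, key):
        mean, stdev = carry
        # Sample candidates around the current mean, clipped to [u_min, u_max].
        noise = jax.random.normal(key, (num_samples,) + mean.shape)
        samples = jnp.clip(mean + stdev * noise, u_min, u_max)
        losses = jax.vmap(loss_fn)(samples)
        # Keep the num_elites lowest-loss samples.
        elite_idx = jnp.argsort(losses)[:num_elites]
        elites = samples[elite_idx]
        # Assumed smoothing rule: blend elite statistics with the previous estimate.
        new_mean = a * mean + (1 - a) * elites.mean(axis=0)
        new_stdev = a * stdev + (1 - a) * elites.std(axis=0)
        # Emit the best loss found in this step.
        return (new_mean, new_stdev), losses[elite_idx[0]]

    # One PRNG key per step; lax.scan keeps the loop traceable.
    keys = jax.random.split(rng, max_steps)
    (mean, stdev), best_losses = jax.lax.scan(step, (mean, stdev), keys)
    return (mean, stdev), best_losses


def sphere(x):
    # Toy loss with its minimum at x = 0.3 in every dimension.
    return jnp.sum((x - 0.3) ** 2)


(final_mean, final_stdev), losses = cem_sketch(
    sphere, jnp.zeros(2), jnp.ones(2), -1.0, 1.0, jax.random.PRNGKey(0), max_steps=50)
```

With the toy `sphere` loss, `final_mean` should end up close to 0.3 in both dimensions while `final_stdev` shrinks toward zero.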
rex.cem.CEMSolver
Cross-Entropy Method (CEM) solver. See https://arxiv.org/pdf/1907.03613.pdf for details on CEM.
Attributes:
- u_min (Dict[str, Params]) – (Normalized) minimum values for the parameters (pytree).
- u_max (Dict[str, Params]) – (Normalized) maximum values for the parameters (pytree).
- evolution_smoothing (Union[float, ArrayLike]) – Smoothing factor for updating the mean and standard deviation.
- num_samples (int) – Number of samples per iteration.
- elite_portion (float) – The portion of the samples kept as elites, i.e., the lowest-loss samples used to update the mean and standard deviation.
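The role of `evolution_smoothing` can be sketched per leaf of the parameter pytree. This assumes the convention in which the factor weights the previous estimate, so with the default of 0.1 the fresh elite statistics get weight 0.9. The linked paper and the rex source are the authority on the exact rule. `smoothed_update` is a hypothetical helper, not part of rex.

```python
import jax
import jax.numpy as jnp


def smoothed_update(mean, stdev, elites, evolution_smoothing=0.1):
    """Refit a pytree mean/stdev from stacked elite samples (hypothetical sketch).

    `elites` has the same tree structure as `mean`, with an extra leading
    axis of size num_elites on every leaf.
    """
    a = evolution_smoothing
    # Per-leaf statistics of the elites along the sample axis.
    elite_mean = jax.tree_util.tree_map(lambda e: e.mean(axis=0), elites)
    elite_std = jax.tree_util.tree_map(lambda e: e.std(axis=0), elites)
    # Assumed rule: a * old + (1 - a) * new, applied leaf by leaf.
    new_mean = jax.tree_util.tree_map(lambda m, em: a * m + (1 - a) * em, mean, elite_mean)
    new_stdev = jax.tree_util.tree_map(lambda s, es: a * s + (1 - a) * es, stdev, elite_std)
    return new_mean, new_stdev


mean = {"gain": jnp.zeros(2)}
stdev = {"gain": jnp.ones(2)}
elites = {"gain": jnp.array([[1.0, 2.0], [3.0, 4.0]])}  # two elite samples
new_mean, new_stdev = smoothed_update(mean, stdev, elites)
# Elite mean is [2, 3] and elite std is [1, 1], so new_mean["gain"] is
# approximately [1.8, 2.7] and new_stdev["gain"] approximately [1.0, 1.0].
```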
init(u_min: Dict[str, Params], u_max: Dict[str, Params], num_samples: int = 100, evolution_smoothing: Union[float, jax.typing.ArrayLike] = 0.1, elite_portion: float = 0.1) -> CEMSolver
classmethod
Initialize the Cross-Entropy Method (CEM) Solver.
Parameters:
- u_min (Dict[str, Params]) – (Normalized) minimum values for the parameters (pytree).
- u_max (Dict[str, Params]) – (Normalized) maximum values for the parameters (pytree).
- num_samples (int, default: 100) – Number of samples per iteration.
- evolution_smoothing (Union[float, ArrayLike], default: 0.1) – Smoothing factor for updating the mean and standard deviation. See https://arxiv.org/pdf/1907.03613.pdf for details.
- elite_portion (float, default: 0.1) – Portion of the samples kept as elites. See https://arxiv.org/pdf/1907.03613.pdf for details.
Returns:
- CEMSolver – An instance of the CEMSolver class.
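With the defaults `num_samples=100` and `elite_portion=0.1`, one would expect 10 elites per iteration. The hypothetical `select_elites` below makes that arithmetic concrete for a pytree of stacked samples. Rounding down with a floor of one elite is an assumption, not documented rex behavior.

```python
import jax
import jax.numpy as jnp


def select_elites(samples, losses, elite_portion=0.1):
    """Keep the lowest-loss fraction of a batch of pytree samples (hypothetical)."""
    num_samples = losses.shape[0]
    # Assumption: round down, keeping at least one elite.
    num_elites = max(1, int(elite_portion * num_samples))
    idx = jnp.argsort(losses)[:num_elites]
    # Index every leaf along its leading (sample) axis.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], samples)


# 100 samples where larger x means lower loss, so the elites are x = 90..99.
elites = select_elites({"x": jnp.arange(100.0)}, -jnp.arange(100.0))
```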
init_state(mean: Dict[str, Params], stdev: Dict[str, Params] = None) -> CEMState
Initialize the state of the CEM Solver.
Parameters:
- mean (Dict[str, Params]) – (Normalized) mean values for the parameters (pytree).
- stdev (Dict[str, Params], default: None) – (Normalized) standard deviation values for the parameters (pytree).
Returns:
- CEMState – The initialized state of the CEM Solver.
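Because `mean`, `stdev`, `u_min`, and `u_max` are all pytrees with the same structure, candidates can be sampled leaf by leaf. The sketch below assumes independent Gaussian noise per leaf, clipped to the bounds. `sample_params` is hypothetical and says nothing about how rex fills in `stdev` when it is left as `None`.

```python
import jax
import jax.numpy as jnp


def sample_params(rng, mean, stdev, u_min, u_max, num_samples):
    """Draw clipped Gaussian samples for every leaf of a params pytree (hypothetical)."""
    # Give each leaf its own PRNG key, arranged in the same tree structure.
    leaves, treedef = jax.tree_util.tree_flatten(mean)
    keys = jax.random.split(rng, len(leaves))
    keys = jax.tree_util.tree_unflatten(treedef, list(keys))

    def draw(key, m, s, lo, hi):
        # Leading axis indexes the samples; trailing axes match the leaf.
        noise = jax.random.normal(key, (num_samples,) + m.shape)
        return jnp.clip(m + s * noise, lo, hi)

    return jax.tree_util.tree_map(draw, keys, mean, stdev, u_min, u_max)


mean = {"gain": jnp.zeros(3), "offset": jnp.array(0.0)}
stdev = {"gain": jnp.ones(3), "offset": jnp.array(1.0)}
u_min = {"gain": jnp.full(3, -0.5), "offset": jnp.array(-0.5)}
u_max = {"gain": jnp.full(3, 0.5), "offset": jnp.array(0.5)}
samples = sample_params(jax.random.PRNGKey(0), mean, stdev, u_min, u_max, num_samples=8)
```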
rex.cem.CEMState
State of the CEM Solver.
Attributes:
- mean (Dict[str, Params]) – (Normalized) mean values for the parameters (pytree).
- stdev (Dict[str, Params]) – (Normalized) standard deviation values for the parameters (pytree).
- bestsofar (Dict[str, Params]) – (Normalized) best-so-far values for the parameters (pytree).
- bestsofar_loss (Union[float, ArrayLike]) – Loss of the best-so-far values.
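The `bestsofar` and `bestsofar_loss` fields suggest a running minimum over the evaluated candidates. The sketch below shows one jit-friendly way to maintain such a minimum, using a hypothetical `SketchState` stand-in rather than the real `CEMState`. It uses `jnp.where` instead of a Python `if` so that the update can be traced.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SketchState(NamedTuple):
    """Hypothetical stand-in for CEMState; not rex's class."""
    mean: Any
    stdev: Any
    bestsofar: Any
    bestsofar_loss: jax.Array


def update_bestsofar(state, candidate, candidate_loss):
    """Replace the best-so-far params if the candidate has a lower loss."""
    better = candidate_loss < state.bestsofar_loss
    # Select per leaf without Python branching, so this works under jax.jit.
    bestsofar = jax.tree_util.tree_map(
        lambda new, old: jnp.where(better, new, old), candidate, state.bestsofar)
    bestsofar_loss = jnp.where(better, candidate_loss, state.bestsofar_loss)
    return state._replace(bestsofar=bestsofar, bestsofar_loss=bestsofar_loss)


state = SketchState(mean={"x": jnp.zeros(2)}, stdev={"x": jnp.ones(2)},
                    bestsofar={"x": jnp.zeros(2)}, bestsofar_loss=jnp.array(jnp.inf))
update = jax.jit(update_bestsofar)
s1 = update(state, {"x": jnp.array([1.0, 2.0])}, jnp.array(3.0))  # improves: accepted
s2 = update(s1, {"x": jnp.array([5.0, 5.0])}, jnp.array(4.0))     # worse: ignored
```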