Cross-Entropy Method
rex.cem.cem(loss: Loss, solver: CEMSolver, init_state: CEMState, transform: Transform, max_steps: int = 100, rng: jax.Array = None, verbose: bool = True) -> Tuple[CEMState, jax.typing.ArrayLike]
Run the Cross-Entropy Method (can be jit-compiled).
Parameters:
- loss (Loss) – Loss function to minimize.
- solver (CEMSolver) – CEM Solver.
- init_state (CEMState) – Initial state of the CEM Solver.
- transform (Transform) – Transform function (e.g., denormalization or extension).
- max_steps (int, default: 100) – Maximum number of steps to run the CEM Solver.
- rng (Array, default: None) – JAX random number generator key.
- verbose (bool, default: True) – Whether to print progress.
Returns:
- Tuple[CEMState, ArrayLike] – The final state of the CEM Solver and the losses at each step.
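The loop that `cem` runs can be pictured with the plain-JAX sketch below. It is a hypothetical illustration, not the `rex` implementation: `cem_sketch` is an illustrative name, it works on a single flat parameter vector rather than a pytree, it omits `transform` and `verbose`, it assumes `loss_fn` maps a batch of samples to one loss per sample, and it assumes `evolution_smoothing` is the weight on the previous mean and standard deviation. The per-step output here is the best-so-far loss after each step.

```python
import jax
import jax.numpy as jnp


def cem_sketch(loss_fn, mean, stdev, u_min, u_max, rng, max_steps=100,
               num_samples=100, evolution_smoothing=0.1, elite_portion=0.1):
    """Hypothetical CEM loop on a flat parameter vector (not the rex API)."""
    num_elites = max(1, int(num_samples * elite_portion))
    a = evolution_smoothing  # assumed: weight on the previous mean/stdev

    def step(carry, key):
        mean, stdev, best, best_loss = carry
        # Sample candidates around the current mean, clipped to [u_min, u_max].
        noise = jax.random.normal(key, (num_samples,) + mean.shape)
        samples = jnp.clip(mean + stdev * noise, u_min, u_max)
        losses = loss_fn(samples)  # assumed: (num_samples, dim) -> (num_samples,)
        # The lowest-loss samples form the elite set.
        elites = samples[jnp.argsort(losses)[:num_elites]]
        # Smoothed update of the sampling distribution.
        mean = a * mean + (1 - a) * elites.mean(axis=0)
        stdev = a * stdev + (1 - a) * elites.std(axis=0)
        # Keep the best sample seen so far.
        i = jnp.argmin(losses)
        improved = losses[i] < best_loss
        best = jnp.where(improved, samples[i], best)
        best_loss = jnp.where(improved, losses[i], best_loss)
        return (mean, stdev, best, best_loss), best_loss

    init = (mean, stdev, mean, jnp.asarray(jnp.inf, dtype=mean.dtype))
    keys = jax.random.split(rng, max_steps)
    # Returns (final carry, best-so-far loss after each step).
    return jax.lax.scan(step, init, keys)
```

Every shape inside `step` is fixed by `num_samples`, `num_elites`, and `max_steps`, so `jax.lax.scan` traces the step once; this is what lets such a loop be wrapped in `jax.jit` (with `loss_fn`, `max_steps`, `num_samples`, and `elite_portion` marked static).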
rex.cem.CEMSolver
See https://arxiv.org/pdf/1907.03613.pdf for details on CEM.
Attributes:
- u_min (Dict[str, Params]) – (Normalized) Minimum values for the parameters (pytree).
- u_max (Dict[str, Params]) – (Normalized) Maximum values for the parameters (pytree).
- evolution_smoothing (Union[float, ArrayLike]) – Smoothing factor for updating the mean and standard deviation.
- num_samples (int) – Number of samples per iteration.
- elite_portion (float) – Portion of the samples, those with the lowest loss, kept as the elite set.
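To show how these attributes interact on `Dict[str, Params]` pytrees, here is a hypothetical single CEM iteration. `sample_and_update` is an illustrative name, not part of `rex`; it assumes `loss_fn` takes one parameter dict and returns a scalar, that `u_min`/`u_max` share the structure of `mean`, and that `evolution_smoothing` weights the previous estimate.

```python
import jax
import jax.numpy as jnp


def sample_and_update(key, mean, stdev, u_min, u_max, loss_fn,
                      num_samples=100, elite_portion=0.1,
                      evolution_smoothing=0.1):
    """Hypothetical single CEM iteration on dict pytrees (not the rex API)."""
    leaves, treedef = jax.tree_util.tree_flatten(mean)
    # One PRNG key per leaf, arranged in the same pytree structure as `mean`.
    keys = jax.tree_util.tree_unflatten(
        treedef, list(jax.random.split(key, len(leaves))))
    # Draw num_samples candidates per leaf and clip each leaf to its bounds.
    samples = jax.tree_util.tree_map(
        lambda m, s, k, lo, hi: jnp.clip(
            m + s * jax.random.normal(k, (num_samples,) + m.shape), lo, hi),
        mean, stdev, keys, u_min, u_max)
    losses = jax.vmap(loss_fn)(samples)  # one scalar loss per candidate
    # The elite set: the lowest-loss fraction of the samples.
    num_elites = max(1, int(num_samples * elite_portion))
    elite_idx = jnp.argsort(losses)[:num_elites]
    elites = jax.tree_util.tree_map(lambda x: x[elite_idx], samples)
    a = evolution_smoothing  # assumed: weight on the previous estimate
    new_mean = jax.tree_util.tree_map(
        lambda m, e: a * m + (1 - a) * e.mean(axis=0), mean, elites)
    new_stdev = jax.tree_util.tree_map(
        lambda s, e: a * s + (1 - a) * e.std(axis=0), stdev, elites)
    return new_mean, new_stdev, losses
```

Sampling each leaf with its own key and batching along a new leading axis lets `jax.vmap` evaluate the loss on whole parameter dicts, so the same code handles any pytree layout.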
init(u_min: Dict[str, Params], u_max: Dict[str, Params], num_samples: int = 100, evolution_smoothing: Union[float, jax.typing.ArrayLike] = 0.1, elite_portion: float = 0.1) -> CEMSolver
classmethod
Initialize the Cross-Entropy Method (CEM) Solver.
Parameters:
- u_min (Dict[str, Params]) – (Normalized) Minimum values for the parameters (pytree).
- u_max (Dict[str, Params]) – (Normalized) Maximum values for the parameters (pytree).
- num_samples (int, default: 100) – Number of samples per iteration.
- evolution_smoothing (Union[float, ArrayLike], default: 0.1) – Smoothing factor for updating the mean and standard deviation. See https://arxiv.org/pdf/1907.03613.pdf for details.
- elite_portion (float, default: 0.1) – Portion of the samples kept as the elite set. See https://arxiv.org/pdf/1907.03613.pdf for details.
Returns:
- CEMSolver (CEMSolver) – An instance of the CEMSolver class.
init_state(mean: Dict[str, Params], stdev: Dict[str, Params] = None) -> CEMState
Initialize the state of the CEM Solver.
Parameters:
- mean (Dict[str, Params]) – (Normalized) Mean values for the parameters (pytree).
- stdev (Dict[str, Params], default: None) – (Normalized) Standard deviation values for the parameters (pytree).
Returns:
- CEMState (CEMState) – The initialized state of the CEM Solver.
rex.cem.CEMState
State of the CEM Solver.
Attributes:
- mean (Dict[str, Params]) – (Normalized) Mean values for the parameters (pytree).
- stdev (Dict[str, Params]) – (Normalized) Standard deviation values for the parameters (pytree).
- bestsofar (Dict[str, Params]) – (Normalized) Best-so-far values for the parameters (pytree).
- bestsofar_loss (Union[float, ArrayLike]) – Loss of the best-so-far values.
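Because the state fields are pytrees, best-so-far bookkeeping can be written with `jax.tree_util.tree_map`. The sketch below is hypothetical: `CEMStateSketch` and `update_bestsofar` are illustrative names, not part of `rex`, and the rule (replace `bestsofar` only when the batch minimum beats `bestsofar_loss`) is assumed from the field descriptions above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CEMStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented CEMState fields."""
    mean: dict
    stdev: dict
    bestsofar: dict
    bestsofar_loss: jax.Array


def update_bestsofar(state, samples, losses):
    """Replace bestsofar if any sample in the batch beats bestsofar_loss."""
    i = jnp.argmin(losses)
    improved = losses[i] < state.bestsofar_loss
    # Select the winning sample leaf by leaf; keep the old value otherwise.
    bestsofar = jax.tree_util.tree_map(
        lambda b, s: jnp.where(improved, s[i], b), state.bestsofar, samples)
    bestsofar_loss = jnp.where(improved, losses[i], state.bestsofar_loss)
    return state._replace(bestsofar=bestsofar, bestsofar_loss=bestsofar_loss)
```

A `NamedTuple` is already a JAX pytree, so a state like this can be passed straight through `jax.jit`; using `jnp.where` instead of a Python `if` keeps the update traceable.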