
Proximal Policy Optimization¤

rex.ppo.train(env: Union[BaseEnv, Environment], config: Config, rng: jax.Array) -> PPOResult ¤

Train the PPO model.

PPO implementation based on purejaxrl: https://github.com/luchris429/purejaxrl

Parameters:

  • env (Union[BaseEnv, Environment]) –

    The environment to train on.

  • config (Config) –

    Configuration for the PPO algorithm.

  • rng (Array) –

    Random number generator key.

Returns:

  • PPOResult ( PPOResult ) –

    The result of the training process.
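
A minimal usage sketch, assuming a user-supplied environment and keyword construction of Config (documented below); the hyperparameter values are placeholders, not library defaults.

```python
import jax
import rex.ppo as ppo

env = ...  # a BaseEnv/Environment instance (user-defined; not constructed here)
config = ppo.Config(  # keyword construction is an assumption; see Config below
    LR=3e-4,
    NUM_ENVS=64,
    NUM_STEPS=32,
    TOTAL_TIMESTEPS=1_000_000,
)
rng = jax.random.PRNGKey(0)           # random number generator key
result = ppo.train(env, config, rng)  # returns a PPOResult (see below)
policy = result.policy                # trained Policy, ready for inference
```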


rex.ppo.Config ¤

Bases: Base

Configuration for PPO.

Inherit from this class and override the EVAL_METRICS_JAX_CB and EVAL_METRICS_HOST_CB methods to customize the evaluation metrics and the host-side callback that processes them.

Attributes:

  • LR (float) –

    The learning rate.

  • NUM_ENVS (int) –

    The number of parallel environments.

  • NUM_STEPS (int) –

    The number of steps to run in each environment per update.

  • TOTAL_TIMESTEPS (int) –

    The total number of timesteps to run.

  • UPDATE_EPOCHS (int) –

    The number of epochs to run per update.

  • NUM_MINIBATCHES (int) –

    The number of minibatches to split the data into.

  • GAMMA (float) –

    The discount factor.

  • GAE_LAMBDA (float) –

    The Generalized Advantage Estimation (GAE) parameter.

  • CLIP_EPS (float) –

    The clipping parameter for the ratio in the policy loss.

  • ENT_COEF (float) –

    The coefficient of the entropy regularizer.

  • VF_COEF (float) –

    The value function coefficient.

  • MAX_GRAD_NORM (float) –

    The maximum gradient norm.

  • NUM_HIDDEN_LAYERS (int) –

    The number of hidden layers (same for actor and critic).

  • NUM_HIDDEN_UNITS (int) –

    The number of hidden units per layer (same for actor and critic).

  • KERNEL_INIT_TYPE (str) –

    The kernel initialization type (same for actor and critic).

  • HIDDEN_ACTIVATION (str) –

    The hidden activation function (same for actor and critic).

  • STATE_INDEPENDENT_STD (bool) –

    Whether to use state-independent standard deviation for the actor.

  • SQUASH (bool) –

    Whether to squash the action output of the actor.

  • ANNEAL_LR (bool) –

    Whether to anneal the learning rate.

  • NORMALIZE_ENV (bool) –

    Whether to normalize the environment (observations and rewards); actions are always normalized.

  • FIXED_INIT (bool) –

    Whether to use fixed initial states for each parallel environment.

  • OFFSET_STEP (bool) –

    Whether to offset the step counter for each parallel environment to break temporal correlations.

  • NUM_EVAL_ENVS (int) –

    The number of evaluation environments.

  • EVAL_FREQ (int) –

    The number of evaluations to run over the course of training.

  • VERBOSE (bool) –

    Whether to print verbose output.

  • DEBUG (bool) –

    Whether to print debug output per step.
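
To make the relationship between these fields concrete, the sketch below computes the batch sizes implied by a configuration. Keyword construction of Config is an assumption and the values shown are placeholders; the derived quantities follow the attribute descriptions above.

```python
import rex.ppo as ppo

config = ppo.Config(  # keyword construction is an assumption; values are placeholders
    LR=3e-4, NUM_ENVS=64, NUM_STEPS=32, TOTAL_TIMESTEPS=1_000_000,
    UPDATE_EPOCHS=4, NUM_MINIBATCHES=4,
    GAMMA=0.99, GAE_LAMBDA=0.95, CLIP_EPS=0.2,
    ENT_COEF=0.01, VF_COEF=0.5, MAX_GRAD_NORM=0.5,
)

# Derived quantities implied by the attribute descriptions above:
batch_size = config.NUM_ENVS * config.NUM_STEPS        # transitions collected per update (64 * 32 = 2048)
minibatch_size = batch_size // config.NUM_MINIBATCHES  # transitions per gradient step (512)
num_updates = config.TOTAL_TIMESTEPS // batch_size     # total number of updates (~488)
```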

EVAL_METRICS_JAX_CB(total_steps: Union[int, jax.Array], diagnostics: Diagnostics, eval_transitions: Transition = None) -> Dict ¤

Compute evaluation metrics for the PPO algorithm.

Parameters:

  • total_steps (Union[int, Array]) –

    The total number of steps run.

  • diagnostics (Diagnostics) –

    The diagnostics from the training process.

  • eval_transitions (Transition, default: None ) –

    The transitions from the evaluation process.

Returns:

  • Dict ( Dict ) –

    A dictionary containing the evaluation metrics.

EVAL_METRICS_HOST_CB(metrics: Dict) -> None ¤

Process the evaluation metrics for the PPO algorithm on the host.

Can be used for printing or logging the evaluation metrics on the host, as this callback is allowed to have side effects.

Parameters:

  • metrics (Dict) –

    The evaluation metrics.
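
A minimal sketch of the customization suggested above: inherit from Config and override the two callbacks. Only the callback signatures follow this page; the instance-method style with `self`, the call to `super()`, and the `reward` field on Transition are assumptions.

```python
from typing import Dict

import rex.ppo as ppo


class MyConfig(ppo.Config):
    def EVAL_METRICS_JAX_CB(self, total_steps, diagnostics, eval_transitions=None) -> Dict:
        # Runs under jit: use JAX ops only. Start from the default metrics.
        metrics = super().EVAL_METRICS_JAX_CB(total_steps, diagnostics, eval_transitions)
        if eval_transitions is not None:
            # `reward` is an assumed field name on Transition.
            metrics["eval/mean_reward"] = eval_transitions.reward.mean()
        return metrics

    def EVAL_METRICS_HOST_CB(self, metrics: Dict) -> None:
        # Runs on the host and may have side effects: print or log the metrics.
        print(metrics)
```

An instance of MyConfig would then be passed to rex.ppo.train in place of the base Config.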


rex.ppo.PPOResult ¤

Bases: Base

Represents the result of the PPO training process.

Attributes:

  • config (Config) –

    Configuration for the PPO algorithm.

  • runner_state (RunnerState) –

    The state of the runner after training.

  • metrics (Dict[str, Any]) –

    Dictionary containing various metrics collected during training.

obs_scaling: SquashState property ¤

Returns the observation scaling parameters.

act_scaling: SquashActionWrapper property ¤

Returns the action scaling parameters.

policy: Policy property ¤

Returns the policy model.
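
A small sketch of unpacking a PPOResult; `result` is assumed to come from rex.ppo.train above, and saving via pickle is an assumption rather than a documented rex API.

```python
import pickle

policy = result.policy            # Policy, ready for inference (see Policy below)
obs_scaling = result.obs_scaling  # observation scaling parameters
act_scaling = result.act_scaling  # action scaling parameters
print(result.metrics.keys())      # metrics collected during training

with open("ppo_policy.pkl", "wb") as f:
    pickle.dump(policy, f)        # assumed: the Policy object is picklable
```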


rex.ppo.Policy ¤

Bases: Base

Represents the policy model.

Attributes:

  • act_scaling (SquashState) –

    The action scaling parameters.

  • obs_scaling (NormalizeVec) –

    The observation scaling parameters.

  • model (Dict[str, Dict[str, Union[ArrayLike, Any]]]) –

    The model parameters.

  • hidden_activation (str) –

    The hidden activation function.

  • output_activation (str) –

    The output activation function.

  • state_independent_std (bool) –

    Whether the standard deviation of the actor is state-independent.

apply_actor(norm_obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array ¤

Apply the actor model to the normalized observation.

Parameters:

  • norm_obs (ArrayLike) –

    The normalized observation.

  • rng (Array, default: None ) –

    Random number generator key.

Returns:

  • Array

    The unscaled action.

get_action(obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array ¤

Get the action from the policy model.

Parameters:

  • obs (ArrayLike) –

    The observation.

  • rng (Array, default: None ) –

    Random number generator key.

Returns:

  • Array

    The action, scaled to the action space.
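
A minimal rollout sketch to show where get_action fits; the gym-style reset/step environment interface is an assumption used only for illustration, and only policy.get_action(obs, rng) is documented here.

```python
import jax

rng = jax.random.PRNGKey(0)
obs = env.reset()                             # hypothetical environment interface
for _ in range(1000):
    rng, rng_act = jax.random.split(rng)
    action = policy.get_action(obs, rng_act)  # action scaled to the action space
    obs = env.step(action)                    # hypothetical environment interface
```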


rex.ppo.RunnerState ¤

Bases: Base

Represents the state of the runner during training.

Attributes:

  • train_state (TrainState) –

    The state of the training process.

  • env_state (GraphState) –

    The state of the environment.

  • last_obs (ArrayLike) –

    The last observation.

  • rng (Array) –

    Random number generator key.