Proximal Policy Optimization
rex.ppo.train(env: Union[BaseEnv, Environment], config: Config, rng: jax.Array) -> PPOResult
Train the PPO model.
The implementation is based on purejaxrl: https://github.com/luchris429/purejaxrl
Parameters:
- env (Union[BaseEnv, Environment]) – The environment to train on.
- config (Config) – Configuration for the PPO algorithm.
- rng (Array) – Random number generator key.
Returns:
- PPOResult (PPOResult) – The result of the training process.
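A minimal usage sketch, assuming `env` is an already-constructed `BaseEnv`/`Environment` and `config` is a `rex.ppo.Config` instance (see below); both names are illustrative placeholders, not part of this API reference.

```python
import jax
import rex.ppo

# env: a BaseEnv/Environment built elsewhere (assumed to exist)
# config: a rex.ppo.Config instance (see the Config section below)
rng = jax.random.PRNGKey(0)
result = rex.ppo.train(env, config, rng)

# The returned PPOResult bundles the config, the final runner state,
# and the metrics collected during training.
print(result.metrics.keys())
```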
rex.ppo.Config
Bases: Base
Configuration for PPO.
Inherit from this class and override the EVAL_METRICS_JAX_CB and EVAL_METRICS_HOST_CB methods to customize the evaluation metrics and the host-side callback that processes them (a sketch follows the method descriptions below).
Attributes:
- LR (float) – The learning rate.
- NUM_ENVS (int) – The number of parallel environments.
- NUM_STEPS (int) – The number of steps to run in each environment per update.
- TOTAL_TIMESTEPS (int) – The total number of timesteps to run.
- UPDATE_EPOCHS (int) – The number of epochs to run per update.
- NUM_MINIBATCHES (int) – The number of minibatches to split the data into.
- GAMMA (float) – The discount factor.
- GAE_LAMBDA (float) – The Generalized Advantage Estimation (GAE) parameter.
- CLIP_EPS (float) – The clipping parameter for the ratio in the policy loss.
- ENT_COEF (float) – The coefficient of the entropy regularizer.
- VF_COEF (float) – The value function coefficient.
- MAX_GRAD_NORM (float) – The maximum gradient norm.
- NUM_HIDDEN_LAYERS (int) – The number of hidden layers (same for actor and critic).
- NUM_HIDDEN_UNITS (int) – The number of hidden units per layer (same for actor and critic).
- KERNEL_INIT_TYPE (str) – The kernel initialization type (same for actor and critic).
- HIDDEN_ACTIVATION (str) – The hidden activation function (same for actor and critic).
- STATE_INDEPENDENT_STD (bool) – Whether to use a state-independent standard deviation for the actor.
- SQUASH (bool) – Whether to squash the action output of the actor.
- ANNEAL_LR (bool) – Whether to anneal the learning rate.
- NORMALIZE_ENV (bool) – Whether to normalize the environment (observations and rewards); actions are always normalized.
- FIXED_INIT (bool) – Whether to use fixed initial states for each parallel environment.
- OFFSET_STEP (bool) – Whether to offset the step counter of each parallel environment to break temporal correlations.
- NUM_EVAL_ENVS (int) – The number of evaluation environments.
- EVAL_FREQ (int) – The number of evaluations to run over the course of training.
- VERBOSE (bool) – Whether to print verbose output.
- DEBUG (bool) – Whether to print debug output per step.
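A sketch of building a configuration from the attributes listed above; it assumes Config accepts its documented attributes as keyword arguments, and all hyperparameter values here are illustrative, not recommended defaults.

```python
import rex.ppo

# Illustrative values only; adjust to your environment and compute budget.
config = rex.ppo.Config(
    LR=3e-4,
    NUM_ENVS=64,
    NUM_STEPS=32,
    TOTAL_TIMESTEPS=1_000_000,
    UPDATE_EPOCHS=4,
    NUM_MINIBATCHES=4,
    GAMMA=0.99,
    GAE_LAMBDA=0.95,
    CLIP_EPS=0.2,
    ENT_COEF=0.01,
    VF_COEF=0.5,
    MAX_GRAD_NORM=0.5,
    NUM_HIDDEN_LAYERS=2,
    NUM_HIDDEN_UNITS=64,
    KERNEL_INIT_TYPE="xavier_uniform",  # illustrative value
    HIDDEN_ACTIVATION="tanh",           # illustrative value
    STATE_INDEPENDENT_STD=True,
    SQUASH=True,
    ANNEAL_LR=False,
    NORMALIZE_ENV=True,
    FIXED_INIT=True,
    OFFSET_STEP=False,
    NUM_EVAL_ENVS=20,
    EVAL_FREQ=20,
    VERBOSE=True,
    DEBUG=False,
)
```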
EVAL_METRICS_JAX_CB(total_steps: Union[int, jax.Array], diagnostics: Diagnostics, eval_transitions: Transition = None) -> Dict
Compute evaluation metrics for the PPO algorithm.
Parameters:
- total_steps (Union[int, Array]) – The total number of steps run.
- diagnostics (Diagnostics) – The diagnostics from the training process.
- eval_transitions (Transition, default: None) – The transitions from the evaluation process.
Returns:
- Dict (Dict) – A dictionary containing the evaluation metrics.
EVAL_METRICS_HOST_CB(metrics: Dict) -> None
Process the evaluation metrics for the PPO algorithm on the host.
Can be used for printing or logging the evaluation metrics on the host, as this callback is side-effectful.
Parameters:
- metrics (Dict) – The evaluation metrics.
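A sketch of the customization mentioned in the Config description above. It assumes the two callbacks can be overridden as ordinary instance methods and that calling the base-class implementation is valid; the added metric key is purely illustrative.

```python
from typing import Dict
import rex.ppo

class MyConfig(rex.ppo.Config):
    def EVAL_METRICS_JAX_CB(self, total_steps, diagnostics, eval_transitions=None) -> Dict:
        # Start from the base metrics (assumption: the base implementation can be reused).
        metrics = super().EVAL_METRICS_JAX_CB(total_steps, diagnostics, eval_transitions)
        metrics["custom/total_steps"] = total_steps  # illustrative extra entry
        return metrics

    def EVAL_METRICS_HOST_CB(self, metrics: Dict) -> None:
        # Side-effectful host callback: print (or log) the evaluation metrics.
        print(metrics)
```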
rex.ppo.PPOResult
Bases: Base
Represents the result of the PPO training process.
Attributes:
- config (Config) – Configuration for the PPO algorithm.
- runner_state (RunnerState) – The state of the runner after training.
- metrics (Dict[str, Any]) – Dictionary containing various metrics collected during training.
rex.ppo.Policy
Bases: Base
Represents the policy model.
Attributes:
- act_scaling (SquashState) – The action scaling parameters.
- obs_scaling (NormalizeVec) – The observation scaling parameters.
- model (Dict[str, Dict[str, Union[ArrayLike, Any]]]) – The model parameters.
- hidden_activation (str) – The hidden activation function.
- output_activation (str) – The output activation function.
- state_independent_std (bool) – Whether the standard deviation of the actor is state-independent.
apply_actor(norm_obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array
Apply the actor model to the normalized observation.
Parameters:
- norm_obs (ArrayLike) – The normalized observation.
- rng (Array, default: None) – Random number generator key.
Returns:
- Array – The unscaled action.
get_action(obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array
Get the action from the policy model.
Parameters:
- obs (ArrayLike) – The observation.
- rng (Array, default: None) – Random number generator key.
Returns:
- Array – The action, scaled to the action space.
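A minimal sketch of querying a trained policy. It assumes `policy` is a `rex.ppo.Policy` obtained after training and that the observation shape matches the environment; both are placeholders, not part of the documented API.

```python
import jax
import jax.numpy as jnp

# policy: a rex.ppo.Policy obtained after training (assumed to exist)
rng = jax.random.PRNGKey(1)
obs = jnp.zeros((8,))  # placeholder observation; shape depends on the environment

# get_action handles observation/action scaling and returns an action
# scaled to the action space; apply_actor instead expects an already
# normalized observation and returns the unscaled action.
action = policy.get_action(obs, rng=rng)
```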
rex.ppo.RunnerState
Bases: Base
Represents the state of the runner during training.
Attributes:
- train_state (TrainState) – The state of the training process.
- env_state (GraphState) – The state of the environment.
- last_obs (ArrayLike) – The last observation.
- rng (Array) – Random number generator key.