Proximal Policy Optimization

rex.ppo.train(env: Union[BaseEnv, Environment], config: Config, rng: jax.Array) -> PPOResult

Train the PPO model. The implementation is based on purejaxrl: https://github.com/luchris429/purejaxrl
Parameters:

- env (Union[BaseEnv, Environment]) – The environment to train on.
- config (Config) – Configuration for the PPO algorithm.
- rng (Array) – Random number generator key.
Returns:

- PPOResult – The result of the training process.
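
A minimal usage sketch. It assumes Config can be constructed by keyword from the attributes documented below and that a rex environment instance is already available; the environment constructor and the hyperparameter values are illustrative assumptions, not documented defaults.

```python
import jax
from rex.ppo import Config, train

# Hypothetical environment; substitute any rex BaseEnv / Environment instance.
env = make_my_env()  # placeholder, not a rex API

# Illustrative hyperparameters; field names follow the Config attributes documented below.
config = Config(
    LR=3e-4,
    NUM_ENVS=64,
    NUM_STEPS=128,
    TOTAL_TIMESTEPS=1_000_000,
    UPDATE_EPOCHS=4,
    NUM_MINIBATCHES=4,
    GAMMA=0.99,
    GAE_LAMBDA=0.95,
    CLIP_EPS=0.2,
    ENT_COEF=0.01,
    VF_COEF=0.5,
    MAX_GRAD_NORM=0.5,
)

rng = jax.random.PRNGKey(0)
result = train(env, config, rng)  # PPOResult with config, runner_state, and metrics
print(result.metrics.keys())      # metrics collected during training
```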
rex.ppo.Config

Bases: Base

Configuration for PPO.
Inherit from this class and override the EVAL_METRICS_JAX_CB and EVAL_METRICS_HOST_CB methods to customize the evaluation metrics and the host-side callback for the evaluation metrics.
Attributes:

- LR (float) – The learning rate.
- NUM_ENVS (int) – The number of parallel environments.
- NUM_STEPS (int) – The number of steps to run in each environment per update.
- TOTAL_TIMESTEPS (int) – The total number of timesteps to run.
- UPDATE_EPOCHS (int) – The number of epochs to run per update.
- NUM_MINIBATCHES (int) – The number of minibatches to split the data into.
- GAMMA (float) – The discount factor.
- GAE_LAMBDA (float) – The Generalized Advantage Estimation (GAE) parameter.
- CLIP_EPS (float) – The clipping parameter for the ratio in the policy loss.
- ENT_COEF (float) – The coefficient of the entropy regularizer.
- VF_COEF (float) – The value function coefficient.
- MAX_GRAD_NORM (float) – The maximum gradient norm.
- NUM_HIDDEN_LAYERS (int) – The number of hidden layers (same for actor and critic).
- NUM_HIDDEN_UNITS (int) – The number of hidden units per layer (same for actor and critic).
- KERNEL_INIT_TYPE (str) – The kernel initialization type (same for actor and critic).
- HIDDEN_ACTIVATION (str) – The hidden activation function (same for actor and critic).
- STATE_INDEPENDENT_STD (bool) – Whether to use a state-independent standard deviation for the actor.
- SQUASH (bool) – Whether to squash the action output of the actor.
- ANNEAL_LR (bool) – Whether to anneal the learning rate.
- NORMALIZE_ENV (bool) – Whether to normalize the environment (observations and rewards); actions are always normalized.
- FIXED_INIT (bool) – Whether to use fixed initial states for each parallel environment.
- OFFSET_STEP (bool) – Whether to offset the step counter for each parallel environment to break temporal correlations.
- NUM_EVAL_ENVS (int) – The number of evaluation environments.
- EVAL_FREQ (int) – The number of evaluations to run per training run.
- VERBOSE (bool) – Whether to print verbose output.
- DEBUG (bool) – Whether to print debug output per step.
EVAL_METRICS_JAX_CB(total_steps: Union[int, jax.Array], diagnostics: Diagnostics, eval_transitions: Transition = None) -> Dict

Compute evaluation metrics for the PPO algorithm.

Parameters:

- total_steps (Union[int, Array]) – The total number of steps run.
- diagnostics (Diagnostics) – The diagnostics from the training process.
- eval_transitions (Transition, default: None) – The transitions from the evaluation process.

Returns:

- Dict – A dictionary containing the evaluation metrics.
EVAL_METRICS_HOST_CB(metrics: Dict) -> None

Process the evaluation metrics for the PPO algorithm on the host.
This callback is side-effectful, so it can be used to print or log the evaluation metrics on the host.

Parameters:

- metrics (Dict) – The evaluation metrics.
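
A minimal subclassing sketch of the customization described above. It assumes these callbacks are ordinary instance methods that can be overridden and chained via super(); the `eval_transitions.reward` field and the metric key are illustrative assumptions.

```python
from typing import Dict

import jax.numpy as jnp
from rex.ppo import Config


class MyConfig(Config):
    def EVAL_METRICS_JAX_CB(self, total_steps, diagnostics, eval_transitions=None) -> Dict:
        # Start from the default metrics and add a custom entry.
        metrics = super().EVAL_METRICS_JAX_CB(total_steps, diagnostics, eval_transitions)
        # `eval_transitions.reward` is an assumed field name, used purely for illustration.
        metrics["eval/mean_reward"] = jnp.mean(eval_transitions.reward)
        return metrics

    def EVAL_METRICS_HOST_CB(self, metrics: Dict) -> None:
        # Runs on the host, so side effects such as printing or logging are allowed here.
        print(metrics)
```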
rex.ppo.PPOResult

Bases: Base

Represents the result of the PPO training process.

Attributes:

- config (Config) – Configuration for the PPO algorithm.
- runner_state (RunnerState) – The state of the runner after training.
- metrics (Dict[str, Any]) – Dictionary containing various metrics collected during training.
rex.ppo.Policy

Bases: Base

Represents the policy model.

Attributes:

- act_scaling (SquashState) – The action scaling parameters.
- obs_scaling (NormalizeVec) – The observation scaling parameters.
- model (Dict[str, Dict[str, Union[ArrayLike, Any]]]) – The model parameters.
- hidden_activation (str) – The hidden activation function.
- output_activation (str) – The output activation function.
- state_independent_std (bool) – Whether the standard deviation of the actor is state-independent.
apply_actor(norm_obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array

Apply the actor model to the normalized observation.

Parameters:

- norm_obs (ArrayLike) – The normalized observation.
- rng (Array, default: None) – Random number generator key.

Returns:

- Array – The unscaled action.
get_action(obs: jax.typing.ArrayLike, rng: jax.Array = None) -> jax.Array

Get the action from the policy model.

Parameters:

- obs (ArrayLike) – The observation.
- rng (Array, default: None) – Random number generator key.

Returns:

- Array – The action, scaled to the action space.
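
A brief usage sketch for get_action. How a trained Policy instance is obtained is not documented here and is left as an assumption; the observation shape is a placeholder.

```python
import jax
import jax.numpy as jnp

# `policy` is assumed to be an existing rex.ppo.Policy instance.
obs = jnp.zeros((4,))                     # placeholder observation; shape depends on the environment
rng = jax.random.PRNGKey(1)

action = policy.get_action(obs, rng=rng)  # returns an action scaled to the action space
```

Note that, per the signatures above, get_action takes the raw observation while apply_actor takes a normalized observation and returns an unscaled action; presumably get_action applies obs_scaling and act_scaling around apply_actor, though that is an inference rather than documented behavior.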
rex.ppo.RunnerState

Bases: Base

Represents the state of the runner during training.

Attributes:

- train_state (TrainState) – The state of the training process.
- env_state (GraphState) – The state of the environment.
- last_obs (ArrayLike) – The last observation.
- rng (Array) – Random number generator key.