# Source code for omnisafe.algorithms.on_policy.second_order.cpo

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the CPO algorithm."""

from __future__ import annotations

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.trpo import TRPO
from omnisafe.utils import distributed
from omnisafe.utils.math import conjugate_gradients
from omnisafe.utils.tools import (
    get_flat_gradients_from,
    get_flat_params_from,
    set_param_values_to_model,
)


@registry.register
class CPO(TRPO):
    """The Constrained Policy Optimization (CPO) algorithm.

    CPO is a derivative of TRPO.

    References:
        - Title: Constrained Policy Optimization
        - Authors: Joshua Achiam, David Held, Aviv Tamar, Pieter Abbeel.
        - URL: `CPO <https://arxiv.org/abs/1705.10528>`_
    """

    def _init_log(self) -> None:
        """Register the CPO-specific logging keys in addition to the parent's keys."""
        super()._init_log()
        self._logger.register_key('Misc/cost_gradient_norm')
        self._logger.register_key('Misc/A')
        self._logger.register_key('Misc/B')
        self._logger.register_key('Misc/q')
        self._logger.register_key('Misc/r')
        self._logger.register_key('Misc/s')
        self._logger.register_key('Misc/Lambda_star')
        self._logger.register_key('Misc/Nu_star')
        self._logger.register_key('Misc/OptimCase')

    # pylint: disable-next=too-many-arguments,too-many-locals
    def _cpo_search_step(
        self,
        step_direction: torch.Tensor,
        grads: torch.Tensor,
        p_dist: torch.distributions.Distribution,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
        loss_reward_before: torch.Tensor,
        loss_cost_before: torch.Tensor,
        total_steps: int = 15,
        decay: float = 0.8,
        violation_c: int = 0,
        optim_case: int = 0,
    ) -> tuple[torch.Tensor, int]:
        r"""Use line-search to find the step size that satisfies the constraint.

        CPO uses line-search to find the step size that satisfies the constraint. The constraint
        is defined as:

        .. math::

            J^C (\theta + \alpha \delta) - J^C (\theta) \leq \max \{ 0, c \} \\
            D_{KL} (\pi_{\theta} (\cdot|s) || \pi_{\theta + \alpha \delta} (\cdot|s)) \leq \delta_{KL}

        where :math:`\delta_{KL}` is the constraint of KL divergence, :math:`\alpha` is the step
        size, :math:`c` is the violation of constraint.

        In code, a candidate step is rejected when any of the following holds:

        - the (averaged) reward improvement or cost difference is not finite,
        - the (averaged) KL divergence is not finite,
        - ``optim_case > 1`` and the reward surrogate did not improve,
        - the cost surrogate grew by more than ``max(-violation_c, 0)``,
        - the KL divergence exceeds ``self._cfgs.algo_cfgs.target_kl``.

        After every rejection the step fraction is multiplied by ``decay``. The actor parameters
        are always restored to their original values before returning.

        Args:
            step_direction (torch.Tensor): The step direction.
            grads (torch.Tensor): The gradient of the policy.
            p_dist (torch.distributions.Distribution): The old policy distribution.
            obs (torch.Tensor): The observation.
            act (torch.Tensor): The action.
            logp (torch.Tensor): The log probability of the action.
            adv_r (torch.Tensor): The reward advantage.
            adv_c (torch.Tensor): The cost advantage.
            loss_reward_before (torch.Tensor): The reward loss of the policy before the update.
            loss_cost_before (torch.Tensor): The cost loss of the policy before the update.
            total_steps (int, optional): The total steps to search. Defaults to 15.
            decay (float, optional): The decay rate of the step size. Defaults to 0.8.
            violation_c (int, optional): The violation of constraint. Defaults to 0.
            optim_case (int, optional): The optimization case. Defaults to 0.

        Returns:
            A tuple of final step direction and the size of acceptance steps. When no step is
            accepted, the step direction is all zeros and the acceptance step is ``0``.
        """
        # fraction of the proposed step that is currently being tried
        step_frac = 1.0
        # get and flatten parameters from pi-net
        theta_old = get_flat_params_from(self._actor_critic.actor)
        # reward improvement predicted by the linear model, used for logging only
        expected_reward_improve = grads.dot(step_direction)

        kl = torch.zeros(1)
        for step in range(total_steps):
            # evaluate the candidate parameters by loading them into the actor
            new_theta = theta_old + step_frac * step_direction
            set_param_values_to_model(self._actor_critic.actor, new_theta)
            # the last acceptance steps to next step
            acceptance_step = step + 1

            with torch.no_grad():
                try:
                    # loss of policy reward from target/expected reward
                    loss_reward = self._loss_pi(obs=obs, act=act, logp=logp, adv=adv_r)
                except ValueError:
                    step_frac *= decay
                    continue
                # loss of cost of policy cost from real/expected reward
                loss_cost = self._loss_pi_cost(obs=obs, act=act, logp=logp, adv_c=adv_c)
                # compute KL distance between new and old policy
                q_dist = self._actor_critic.actor(obs)
                kl = torch.distributions.kl.kl_divergence(p_dist, q_dist).mean()

            # compute improvement of reward
            loss_reward_improve = loss_reward_before - loss_reward
            # compute difference of cost
            loss_cost_diff = loss_cost - loss_cost_before

            # average across MPI processes...
            kl = distributed.dist_avg(kl)
            loss_reward_improve = distributed.dist_avg(loss_reward_improve)
            loss_cost_diff = distributed.dist_avg(loss_cost_diff)
            self._logger.log(
                f'Expected Improvement: {expected_reward_improve} Actual: {loss_reward_improve}',
            )

            # Finiteness is checked on the averaged quantities so that every process takes the
            # same accept/reject decision. A non-finite value must reject the step explicitly:
            # comparisons against NaN are False and would otherwise reach the accept branch.
            losses_finite = bool(
                torch.isfinite(loss_reward_improve).all()
                and torch.isfinite(loss_cost_diff).all(),
            )
            kl_finite = bool(torch.isfinite(kl).all())
            if not losses_finite:
                self._logger.log('WARNING: loss_pi not finite')
            if not kl_finite:
                self._logger.log('WARNING: KL not finite')

            if not (losses_finite and kl_finite):
                # rejected: fall through to the step decay below
                pass
            elif optim_case > 1 and loss_reward_improve < 0:
                # reward improvement is only required for optim_case > 1
                self._logger.log('INFO: did not improve improve <0')
            # change of cost's range
            elif loss_cost_diff > max(-violation_c, 0):
                self._logger.log(f'INFO: no improve {loss_cost_diff} > {max(-violation_c, 0)}')
            # check KL-distance to avoid too far gap
            elif kl > self._cfgs.algo_cfgs.target_kl:
                self._logger.log(f'INFO: violated KL constraint {kl} at step {step + 1}.')
            else:
                # step only if surrogate is improved and we are
                # within the trust region
                self._logger.log(f'Accept step at i={step + 1}')
                break
            # shrink the step before the next attempt (also after a non-finite result)
            step_frac *= decay
        else:
            # if didn't find a step satisfy those conditions
            self._logger.log('INFO: no suitable step found...')
            step_direction = torch.zeros_like(step_direction)
            acceptance_step = 0

        self._logger.store(
            {
                'Train/KL': kl,
            },
        )

        # the search only probes candidates; the caller applies the returned step itself
        set_param_values_to_model(self._actor_critic.actor, theta_old)
        return step_frac * step_direction, acceptance_step

    def _loss_pi_cost(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the performance of cost on this moment.

        We compute the loss of cost of policy cost from real cost.

        .. math::

            L = \mathbb{E}_{\pi} \left[ \frac{\pi^{'} (a|s)}{\pi (a|s)} A^C (s, a) \right]

        where :math:`A^C (s, a)` is the cost advantage, :math:`\pi (a|s)` is the old policy, and
        :math:`\pi^{'} (a|s)` is the current policy.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            act (torch.Tensor): The ``action`` sampled from buffer.
            logp (torch.Tensor): The ``log probability`` of action sampled from buffer.
            adv_c (torch.Tensor): The ``cost_advantage`` sampled from buffer.

        Returns:
            The loss of the cost performance.
        """
        # the forward pass is run for its effect on the actor before ``log_prob`` is queried
        self._actor_critic.actor(obs)
        logp_ = self._actor_critic.actor.log_prob(act)
        ratio = torch.exp(logp_ - logp)
        return (ratio * adv_c).mean()

    # pylint: disable=invalid-name
    def _determine_case(
        self,
        b_grads: torch.Tensor,
        ep_costs: torch.Tensor,
        q: torch.Tensor,
        r: torch.Tensor,
        s: torch.Tensor,
    ) -> tuple[int, torch.Tensor, torch.Tensor]:
        """Determine the case of the trust region update.

        Args:
            b_grads (torch.Tensor): Gradient of the cost function.
            ep_costs (torch.Tensor): The episode cost minus the cost limit, as computed by the
                caller; negative when the current policy is within the limit.
            q (torch.Tensor): The scalar ``xHx`` computed by the caller from the reward gradient.
            r (torch.Tensor): The scalar ``grads.dot(p)`` computed by the caller.
            s (torch.Tensor): The scalar ``b_grads.dot(p)`` computed by the caller.

        Returns:
            optim_case: The case of the trust region update, an integer in ``0..4``.
            A: ``q - r**2 / s``, or zero in case 4.
            B: ``2 * target_kl - ep_costs**2 / s``, or zero in case 4.
        """
        if b_grads.dot(b_grads) <= 1e-6 and ep_costs < 0:
            # feasible step and cost grad is zero: use plain TRPO update...
            A = torch.zeros(1)
            B = torch.zeros(1)
            optim_case = 4
        else:
            assert torch.isfinite(r).all(), 'r is not finite'
            assert torch.isfinite(s).all(), 's is not finite'

            A = q - r**2 / (s + 1e-8)
            B = 2 * self._cfgs.algo_cfgs.target_kl - ep_costs**2 / (s + 1e-8)

            if ep_costs < 0 and B < 0:
                # point in trust region is feasible and safety boundary doesn't intersect
                # ==> entire trust region is feasible
                optim_case = 3
            elif ep_costs < 0 <= B:
                # point in trust region is feasible but safety boundary intersects
                # ==> only part of trust region is feasible
                optim_case = 2
            elif ep_costs >= 0 and B >= 0:
                # point in trust region is infeasible but safety boundary intersects
                # ==> part of trust region is feasible, recovery is possible
                optim_case = 1
                self._logger.log('Alert! Attempting feasible recovery!', 'yellow')
            else:
                # x = 0 infeasible, and safety half space is outside trust region
                # ==> whole trust region is infeasible, try to fail gracefully
                optim_case = 0
                self._logger.log('Alert! Attempting infeasible recovery!', 'red')
        return optim_case, A, B

    # pylint: disable=invalid-name, too-many-arguments, too-many-locals
    def _step_direction(
        self,
        optim_case: int,
        xHx: torch.Tensor,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        q: torch.Tensor,
        p: torch.Tensor,
        r: torch.Tensor,
        s: torch.Tensor,
        ep_costs: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Compute the proposed step and its dual variables for the given case.

        - Cases 3 and 4 use the plain TRPO step ``alpha * x``.
        - Cases 1 and 2 use the analytical solution with both dual variables.
        - Case 0 steps purely against the cost direction ``p``.

        Args:
            optim_case (int): The case returned by :meth:`_determine_case`.
            xHx (torch.Tensor): The scalar ``x.dot(H x)`` for the reward search direction.
            x (torch.Tensor): The conjugate-gradient solution for the reward gradient.
            A (torch.Tensor): The ``A`` term returned by :meth:`_determine_case`.
            B (torch.Tensor): The ``B`` term returned by :meth:`_determine_case`.
            q (torch.Tensor): The scalar ``xHx`` as passed by the caller.
            p (torch.Tensor): The conjugate-gradient solution for the cost gradient.
            r (torch.Tensor): The scalar ``grads.dot(p)`` computed by the caller.
            s (torch.Tensor): The scalar ``b_grads.dot(p)`` computed by the caller.
            ep_costs (torch.Tensor): The episode cost minus the cost limit.

        Returns:
            A tuple ``(step_direction, lambda_star, nu_star)``.
        """
        if optim_case in (3, 4):
            # under 3 and 4 cases directly use TRPO method
            alpha = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (xHx + 1e-8))
            nu_star = torch.zeros(1)
            lambda_star = 1 / (alpha + 1e-8)
            step_direction = alpha * x

        elif optim_case in (1, 2):

            def project(data: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
                """Project data to [low, high] interval."""
                return torch.clamp(data, low, high)

            # analytical Solution to LQCLP, employ lambda,nu to compute final solution of OLOLQC
            # λ=argmax(f_a(λ),f_b(λ)) = λa_star or λb_star
            # computing formula shown in appendix, lambda_a and lambda_b
            lambda_a = torch.sqrt(A / B)
            lambda_b = torch.sqrt(q / (2 * self._cfgs.algo_cfgs.target_kl))
            # λa_star = Proj(lambda_a ,0 ~ r/c)  λb_star=Proj(lambda_b,r/c~ +inf)
            # where projection(str,b,c)=max(b,min(str,c))
            # may be regarded as a projection from effective region towards safety region
            r_num = r.item()
            eps_cost = ep_costs + 1e-8
            if ep_costs < 0:
                lambda_a_star = project(lambda_a, torch.as_tensor(0.0), r_num / eps_cost)
                lambda_b_star = project(lambda_b, r_num / eps_cost, torch.as_tensor(torch.inf))
            else:
                # the two projection intervals swap when the current point is infeasible
                lambda_a_star = project(lambda_a, r_num / eps_cost, torch.as_tensor(torch.inf))
                lambda_b_star = project(lambda_b, torch.as_tensor(0.0), r_num / eps_cost)

            def f_a(lam: torch.Tensor) -> torch.Tensor:
                """Dual objective on the branch parameterised by ``A`` and ``B``."""
                return -0.5 * (A / (lam + 1e-8) + B * lam) - r * ep_costs / (s + 1e-8)

            def f_b(lam: torch.Tensor) -> torch.Tensor:
                """Dual objective on the branch parameterised by ``q`` and ``target_kl``."""
                return -0.5 * (q / (lam + 1e-8) + 2 * self._cfgs.algo_cfgs.target_kl * lam)

            lambda_star = (
                lambda_a_star if f_a(lambda_a_star) >= f_b(lambda_b_star) else lambda_b_star
            )

            # discard all negative values with torch.clamp(x, min=0)
            # nu_star = max(0, lambda_star * ep_costs - r) / s
            nu_star = torch.clamp(lambda_star * ep_costs - r, min=0) / (s + 1e-8)
            # final x_star as final direction played as policy's loss to backward and update
            step_direction = 1.0 / (lambda_star + 1e-8) * (x - nu_star * p)

        else:  # case == 0
            # purely decrease costs
            # without further check
            lambda_star = torch.zeros(1)
            nu_star = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (s + 1e-8))
            step_direction = -nu_star * p

        return step_direction, lambda_star, nu_star

    # pylint: disable=invalid-name,too-many-arguments,too-many-locals
    def _update_actor(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Update policy network.

        Constrained Policy Optimization updates policy network using the
        `conjugate gradient <https://en.wikipedia.org/wiki/Conjugate_gradient_method>`_ algorithm,
        following the steps:

        - Compute the gradient of the policy.
        - Compute the step direction.
        - Search for a step size that satisfies the constraint.
        - Update the policy network.

        Args:
            obs (torch.Tensor): The observation tensor.
            act (torch.Tensor): The action tensor.
            logp (torch.Tensor): The log probability of the action.
            adv_r (torch.Tensor): The reward advantage tensor.
            adv_c (torch.Tensor): The cost advantage tensor.
        """
        # sub-sampled observations stored for the Fisher-vector product (``self._fvp``)
        self._fvp_obs = obs[:: self._cfgs.algo_cfgs.fvp_sample_freq]
        theta_old = get_flat_params_from(self._actor_critic.actor)

        # reward gradient and its conjugate-gradient solution
        self._actor_critic.actor.zero_grad()
        loss_reward = self._loss_pi(obs, act, logp, adv_r)
        loss_reward_before = distributed.dist_avg(loss_reward)
        p_dist = self._actor_critic.actor(obs)

        loss_reward.backward()
        distributed.avg_grads(self._actor_critic.actor)

        # negated gradient of the reward loss
        grads = -get_flat_gradients_from(self._actor_critic.actor)
        x = conjugate_gradients(self._fvp, grads, self._cfgs.algo_cfgs.cg_iters)
        assert torch.isfinite(x).all(), 'x is not finite'
        xHx = x.dot(self._fvp(x))
        assert xHx.item() >= 0, 'xHx is negative'
        # plain TRPO step size, computed here for logging
        alpha = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (xHx + 1e-8))

        # cost gradient and its conjugate-gradient solution
        self._actor_critic.zero_grad()
        loss_cost = self._loss_pi_cost(obs, act, logp, adv_c)
        loss_cost_before = distributed.dist_avg(loss_cost)

        loss_cost.backward()
        distributed.avg_grads(self._actor_critic.actor)

        b_grads = get_flat_gradients_from(self._actor_critic.actor)
        # signed distance of the logged episode cost from the configured cost limit
        ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit

        p = conjugate_gradients(self._fvp, b_grads, self._cfgs.algo_cfgs.cg_iters)
        q = xHx
        r = grads.dot(p)
        s = b_grads.dot(p)

        optim_case, A, B = self._determine_case(
            b_grads=b_grads,
            ep_costs=ep_costs,
            q=q,
            r=r,
            s=s,
        )

        step_direction, lambda_star, nu_star = self._step_direction(
            optim_case=optim_case,
            xHx=xHx,
            x=x,
            A=A,
            B=B,
            q=q,
            p=p,
            r=r,
            s=s,
            ep_costs=ep_costs,
        )

        step_direction, accept_step = self._cpo_search_step(
            step_direction=step_direction,
            grads=grads,
            p_dist=p_dist,
            obs=obs,
            act=act,
            logp=logp,
            adv_r=adv_r,
            adv_c=adv_c,
            loss_reward_before=loss_reward_before,
            loss_cost_before=loss_cost_before,
            total_steps=20,
            violation_c=ep_costs,
            optim_case=optim_case,
        )

        # the line search restores the old parameters, so the accepted step is applied here
        theta_new = theta_old + step_direction
        set_param_values_to_model(self._actor_critic.actor, theta_new)

        with torch.no_grad():
            loss_reward = self._loss_pi(obs, act, logp, adv_r)
            loss_cost = self._loss_pi_cost(obs, act, logp, adv_c)
            loss = loss_reward + loss_cost

        self._logger.store(
            {
                'Loss/Loss_pi': loss.item(),
                'Misc/AcceptanceStep': accept_step,
                'Misc/Alpha': alpha.item(),
                'Misc/FinalStepNorm': step_direction.norm().mean().item(),
                'Misc/xHx': xHx.mean().item(),
                'Misc/H_inv_g': x.norm().item(),  # H^-1 g
                'Misc/gradient_norm': torch.norm(grads).mean().item(),
                'Misc/cost_gradient_norm': torch.norm(b_grads).mean().item(),
                'Misc/Lambda_star': lambda_star.item(),
                'Misc/Nu_star': nu_star.item(),
                'Misc/OptimCase': int(optim_case),
                'Misc/A': A.item(),
                'Misc/B': B.item(),
                'Misc/q': q.item(),
                'Misc/r': r.item(),
                'Misc/s': s.item(),
            },
        )