Source code for omnisafe.algorithms.on_policy.second_order.pcpo

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the PCPO algorithm."""

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.second_order.cpo import CPO
from omnisafe.utils import distributed
from omnisafe.utils.math import conjugate_gradients
from omnisafe.utils.tools import (
    get_flat_gradients_from,
    get_flat_params_from,
    set_param_values_to_model,
)


@registry.register
class PCPO(CPO):
    """The Projection-Based Constrained Policy Optimization (PCPO) algorithm.

    This class only overrides :meth:`_update_actor`; everything else is
    inherited from :class:`CPO`.

    References:
        - Title: Projection-Based Constrained Policy Optimization
        - Authors: Tsung-Yen Yang, Justinian Rosca, Karthik Narasimhan, Peter J. Ramadge.
        - URL: `PCPO <https://arxiv.org/abs/2010.03152>`_
    """

    # pylint: disable-next=too-many-locals,too-many-arguments
    def _update_actor(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Update policy network.

        PCPO updates policy network using the conjugate gradient algorithm,
        following the steps:

        - Compute the gradient of the policy.
        - Compute the step direction.
        - Search for a step size that satisfies the constraint.
          (Both KL divergence and cost limit).
        - Update the policy network.

        With ``x = CG(_fvp, g)`` for the negated reward-loss gradient ``g``,
        ``p = CG(_fvp, b)`` for the cost-loss gradient ``b``, ``q = x . _fvp(x)``,
        ``r = g . p``, ``s = b . p`` and ``c = EpCost - cost_limit``, the proposed
        step handed to the line search is::

            alpha * _fvp(x) - max(0, (alpha * r + c) / (s + eps)) * p

        where ``alpha = sqrt(2 * target_kl / (q + eps))``.

        Args:
            obs (torch.Tensor): The observation tensor.
            act (torch.Tensor): The action tensor.
            logp (torch.Tensor): The log probability of the action.
            adv_r (torch.Tensor): The reward advantage tensor.
            adv_c (torch.Tensor): The cost advantage tensor.

        Raises:
            AssertionError: If the conjugate-gradient solution ``x`` contains
                non-finite values, or if ``xHx`` is negative.
        """
        # pylint: disable=invalid-name
        # Single guard value shared by every division below, so that no term of
        # the step can blow up on its own when its denominator is zero.
        eps = 1e-8
        target_kl = self._cfgs.algo_cfgs.target_kl
        cg_iters = self._cfgs.algo_cfgs.cg_iters

        # Sub-sample of observations stored for later use by ``self._fvp``.
        self._fvp_obs = obs[:: self._cfgs.algo_cfgs.fvp_sample_freq]
        theta_old = get_flat_params_from(self._actor_critic.actor)

        # ---- reward objective: gradient and conjugate-gradient solve ----
        self._actor_critic.actor.zero_grad()
        loss_reward = self._loss_pi(obs, act, logp, adv_r)
        loss_reward_before = distributed.dist_avg(loss_reward)
        # Action distribution before the update; passed to the step search.
        p_dist = self._actor_critic.actor(obs)

        loss_reward.backward()
        distributed.avg_grads(self._actor_critic.actor)

        # Negated because ``loss_reward`` is a loss (to be minimised).
        grads = -get_flat_gradients_from(self._actor_critic.actor)
        x = conjugate_gradients(self._fvp, grads, cg_iters)
        assert torch.isfinite(x).all(), 'x is not finite'

        # Evaluate ``_fvp(x)`` once and reuse it for both ``xHx`` and the step
        # direction; the original evaluated it twice with the same argument.
        # NOTE(review): assumes ``_fvp`` returns the same value for the same
        # input; if it also records per-call diagnostics, they now see one
        # fewer call per update -- confirm that is acceptable.
        Hx = self._fvp(x)
        xHx = x.dot(Hx)
        # NOTE(review): despite its name, ``H_inv_g`` holds ``_fvp(x)`` rather
        # than ``x`` (the CG solution, which the logging below labels as
        # "H^-1 g"). Kept as in the original to avoid changing the optimisation
        # trajectory -- confirm against the PCPO paper which one is intended.
        H_inv_g = Hx
        assert xHx.item() >= 0, 'xHx is negative'
        alpha = torch.sqrt(2 * target_kl / (xHx + eps))

        # ---- cost objective: gradient and conjugate-gradient solve ----
        self._actor_critic.zero_grad()
        loss_cost = self._loss_pi_cost(obs, act, logp, adv_c)
        loss_cost_before = distributed.dist_avg(loss_cost)

        loss_cost.backward()
        distributed.avg_grads(self._actor_critic.actor)

        b_grads = get_flat_gradients_from(self._actor_critic.actor)
        # Positive when the logged episode cost exceeds the configured limit.
        ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit

        self._logger.log(f'c = {ep_costs}')
        self._logger.log(f'b^T b = {b_grads.dot(b_grads).item()}')

        p = conjugate_gradients(self._fvp, b_grads, cg_iters)
        q = xHx
        r = grads.dot(p)
        s = b_grads.dot(p)

        # Reward step minus a non-negative multiple of ``p``. ``alpha`` already
        # carries the epsilon guard on ``q``; the original recomputed the same
        # square root *without* the guard inside the projection term and divided
        # by a bare ``s``, which yields inf/nan when ``q`` or ``s`` is zero.
        projection_coef = torch.clamp_min(
            (alpha * r + ep_costs) / (s + eps),
            torch.tensor(0.0, device=self._device),
        )
        step_direction = alpha * H_inv_g - projection_coef * p  # pylint: disable=invalid-name

        # ---- search for an acceptable step and apply it ----
        step_direction, accept_step = self._cpo_search_step(
            step_direction=step_direction,
            grads=grads,
            p_dist=p_dist,
            obs=obs,
            act=act,
            logp=logp,
            adv_r=adv_r,
            adv_c=adv_c,
            loss_reward_before=loss_reward_before,
            loss_cost_before=loss_cost_before,
            total_steps=200,
            violation_c=ep_costs,
        )
        theta_new = theta_old + step_direction
        set_param_values_to_model(self._actor_critic.actor, theta_new)

        # Re-evaluate both surrogate losses at the new parameters, for logging only.
        with torch.no_grad():
            loss_reward = self._loss_pi(obs, act, logp, adv_r)
            loss_cost = self._loss_pi_cost(obs, act, logp, adv_c)
            loss = loss_reward + loss_cost

        self._logger.store(
            {
                'Loss/Loss_pi': loss.item(),
                'Misc/AcceptanceStep': accept_step,
                'Misc/Alpha': alpha.item(),
                'Misc/FinalStepNorm': step_direction.norm().mean().item(),
                'Misc/xHx': xHx.mean().item(),
                'Misc/H_inv_g': x.norm().item(),  # H^-1 g
                'Misc/gradient_norm': torch.norm(grads).mean().item(),
                'Misc/cost_gradient_norm': torch.norm(b_grads).mean().item(),
                # The entries below are written as fixed constants here.
                'Misc/Lambda_star': 1.0,
                'Misc/Nu_star': 1.0,
                'Misc/OptimCase': 1,
                'Misc/A': 1.0,
                'Misc/B': 1.0,
                'Misc/q': q.item(),
                'Misc/r': r.item(),
                'Misc/s': s.item(),
            },
        )