Source code for omnisafe.algorithms.off_policy.ddpg_lag
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrangian version of Deep Deterministic Policy Gradient algorithm."""

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.ddpg import DDPG
from omnisafe.common.lagrange import Lagrange


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class DDPGLag(DDPG):
    """The Lagrangian version of Deep Deterministic Policy Gradient (DDPG) algorithm.

    On top of :class:`DDPG`, this class keeps a Lagrange multiplier that weights
    the cost critic's output in the actor loss, and logs the multiplier under
    ``Metrics/LagrangeMultiplier``.

    References:
        - Title: Continuous control with deep reinforcement learning
        - Authors: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
            Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
        - URL: `DDPG <https://arxiv.org/abs/1509.02971>`_
    """

    def _init(self) -> None:
        """The initialization of the algorithm.

        Here we additionally initialize the Lagrange multiplier, built from the
        ``lagrange_cfgs`` section of the configuration.
        """
        super()._init()
        self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs)

    def _init_log(self) -> None:
        """Log the DDPGLag specific information.

        +----------------------------+--------------------------+
        | Things to log              | Description              |
        +============================+==========================+
        | Metrics/LagrangeMultiplier | The Lagrange multiplier. |
        +----------------------------+--------------------------+
        """
        super()._init_log()
        self._logger.register_key('Metrics/LagrangeMultiplier')

    def _update(self) -> None:
        """Run the parent :class:`DDPG` update, then update the Lagrange multiplier.

        The multiplier is only updated, via
        :meth:`Lagrange.update_lagrange_multiplier`, once the current epoch is
        past ``algo_cfgs.warmup_epochs``. Its current value is logged on every
        call, whether or not it was updated.
        """
        super()._update()
        # First element of the logged ``Metrics/EpCost`` statistics
        # (presumably the mean episode cost -- confirm against the logger).
        Jc = self._logger.get_stats('Metrics/EpCost')[0]
        if self._epoch > self._cfgs.algo_cfgs.warmup_epochs:
            self._lagrange.update_lagrange_multiplier(Jc)
        self._logger.store(
            {
                'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(),
            },
        )

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing ``pi/actor`` loss.

        The loss function in DDPGLag is defined as:

        .. math::

            L = \frac{-Q^V (s, \pi (s)) + \lambda Q^C (s, \pi (s))}{1 + \lambda}

        where :math:`Q^V` is the first output of the reward critic,
        :math:`Q^C` is the first output of the cost critic, :math:`\pi` is the
        policy network and :math:`\lambda` is the current Lagrange multiplier.
        The loss is averaged over the batch before the division by
        :math:`1 + \lambda`.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        # ``.item()`` yields a plain Python number, so the multiplier enters the
        # loss as a constant and receives no gradient from the actor update.
        multiplier = self._lagrange.lagrangian_multiplier.item()

        action = self._actor_critic.actor.predict(obs, deterministic=True)
        loss_r = -self._actor_critic.reward_critic(obs, action)[0]
        loss_c = multiplier * self._actor_critic.cost_critic(obs, action)[0]
        return (loss_r + loss_c).mean() / (1 + multiplier)

    def _log_when_not_update(self) -> None:
        """Log default value when not update.

        Stores the current Lagrange multiplier so that
        ``Metrics/LagrangeMultiplier`` still receives a value.
        """
        super()._log_when_not_update()
        self._logger.store(
            {
                'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(),
            },
        )