Source code for omnisafe.algorithms.off_policy.td3_lag

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrangian version of Twin Delayed DDPG algorithm."""


import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.common.lagrange import Lagrange


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class TD3Lag(TD3):
    """TD3 combined with a Lagrange multiplier on the cost critic.

    The class reuses the :class:`TD3` machinery and overrides the hooks below
    so that the actor loss is penalised by the cost critic, weighted by a
    Lagrange multiplier that is updated after every parent update.

    References:
        - Title: Addressing Function Approximation Error in Actor-Critic Methods
        - Authors: Scott Fujimoto, Herke van Hoof, David Meger.
        - URL: `TD3 <https://arxiv.org/abs/1802.09477>`_
    """

    def _init(self) -> None:
        """Run the parent initialization, then build the Lagrange multiplier.

        The multiplier object is constructed from ``self._cfgs.lagrange_cfgs``,
        which is unpacked as keyword arguments of :class:`Lagrange`.
        """
        super()._init()
        lagrange_kwargs = self._cfgs.lagrange_cfgs
        self._lagrange: Lagrange = Lagrange(**lagrange_kwargs)

    def _init_log(self) -> None:
        """Register the logging key that is specific to TD3Lag.

        +----------------------------+--------------------------+
        | Things to log              | Description              |
        +============================+==========================+
        | Metrics/LagrangeMultiplier | The Lagrange multiplier. |
        +----------------------------+--------------------------+
        """
        super()._init_log()
        self._logger.register_key('Metrics/LagrangeMultiplier')

    def _update(self) -> None:
        """Run the parent :class:`TD3` update, then refresh the multiplier.

        After ``super()._update()`` the first entry of the logger statistics
        for ``Metrics/EpCost`` is read (presumably the mean episode cost --
        confirm against the logger). The multiplier is only updated with it
        once ``self._epoch`` exceeds ``algo_cfgs.warmup_epochs``; its current
        value is stored in the logger on every call regardless.
        """
        super()._update()

        episode_cost = self._logger.get_stats('Metrics/EpCost')[0]
        past_warmup = self._epoch > self._cfgs.algo_cfgs.warmup_epochs
        if past_warmup:
            self._lagrange.update_lagrange_multiplier(episode_cost)

        multiplier_value = self._lagrange.lagrangian_multiplier.data.item()
        self._logger.store({'Metrics/LagrangeMultiplier': multiplier_value})

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the ``pi/actor`` loss.

        The value returned by this method is:

        .. math::

            L = \frac{1}{1 + \lambda} \,
                \mathrm{mean} \left[ -Q^R (s, \pi (s)) + \lambda Q^C (s, \pi (s)) \right]

        where :math:`Q^R` is the first element of the reward critic's output,
        :math:`Q^C` is the first element of the cost critic's output,
        :math:`\pi` is the deterministic prediction of the actor, and
        :math:`\lambda` is the current Lagrange multiplier. The multiplier is
        read with ``.item()``, so it enters the loss as a plain Python number.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        actor_critic = self._actor_critic
        # Scalar copy of the multiplier; used both as the cost weight and in
        # the normalising denominator.
        multiplier = self._lagrange.lagrangian_multiplier.item()

        action = actor_critic.actor.predict(obs, deterministic=True)
        q_reward = actor_critic.reward_critic(obs, action)[0]
        q_cost = actor_critic.cost_critic(obs, action)[0]

        reward_term = -q_reward
        cost_term = multiplier * q_cost
        penalised = reward_term + cost_term
        return penalised.mean() / (1 + multiplier)

    def _log_when_not_update(self) -> None:
        """Store default log values when no update is performed.

        Calls the parent hook and then records the current Lagrange multiplier
        under ``Metrics/LagrangeMultiplier``.
        """
        super()._log_when_not_update()
        multiplier_value = self._lagrange.lagrangian_multiplier.data.item()
        self._logger.store({'Metrics/LagrangeMultiplier': multiplier_value})