Source code for omnisafe.algorithms.off_policy.td3
# Copyright 2023 OmniSafe Team. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Implementation of the Twin Delayed DDPG algorithm."""importtorchfromtorchimportnnfromtorch.nn.utils.clip_gradimportclip_grad_norm_fromomnisafe.algorithmsimportregistryfromomnisafe.algorithms.off_policy.ddpgimportDDPGfromomnisafe.models.actor_critic.constraint_actor_q_criticimportConstraintActorQCritic
@registry.register
# pylint: disable-next=too-many-instance-attributes,too-few-public-methods
class TD3(DDPG):
    """Twin Delayed DDPG (TD3), implemented as a specialization of ``DDPG``.

    Two things are overridden here: model construction, so that the reward
    critic is built with two Q-heads, and the reward-critic update, which
    bootstraps from a noise-perturbed target action and the element-wise
    minimum of the two target Q-values.

    References:
        - Title: Addressing Function Approximation Error in Actor-Critic Methods
        - Authors: Scott Fujimoto, Herke van Hoof, David Meger.
        - URL: `TD3 <https://arxiv.org/abs/1802.09477>`_
    """

    def _init_model(self) -> None:
        """Build the actor-critic model with two reward critics.

        ``num_critics`` in the ``critic`` model configuration is overwritten
        with 2 before the model is created, whatever value was configured.
        """
        model_cfgs = self._cfgs.model_cfgs
        # ``_update_reward_critic`` unpacks exactly two Q-values per critic call.
        model_cfgs.critic['num_critics'] = 2

        actor_critic = ConstraintActorQCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=model_cfgs,
            epochs=self._epochs,
        )
        self._actor_critic = actor_critic.to(self._device)

    def _update_reward_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Run one optimization step on the twin reward critics.

        Steps performed:

        - Query the target actor for a deterministic action at ``next_obs``.
        - Perturb that action with Gaussian noise scaled by ``policy_noise``
          and clipped to ``[-policy_noise_clip, policy_noise_clip]``.
        - Clip the perturbed action to ``[-1.0, 1.0]``.
        - Evaluate both target reward critics and keep the element-wise minimum.
        - Regress both online critics onto the resulting bootstrapped target.
        - Log the loss and the mean of the first critic's prediction.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            reward (torch.Tensor): The ``reward`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        algo_cfgs = self._cfgs.algo_cfgs
        actor_critic = self._actor_critic

        # The regression target is built without tracking gradients, so no
        # gradient reaches the target networks through it.
        with torch.no_grad():
            target_action = actor_critic.target_actor.predict(next_obs, deterministic=True)
            noise_scale = algo_cfgs.policy_noise
            noise_limit = algo_cfgs.policy_noise_clip
            smoothing_noise = torch.clamp(
                torch.randn_like(target_action) * noise_scale,
                -noise_limit,
                noise_limit,
            )
            # NOTE(review): the hard-coded bounds presumably match a normalized
            # action space -- confirm against the environment wrapper.
            target_action = torch.clamp(target_action + smoothing_noise, -1.0, 1.0)

            target_q1, target_q2 = actor_critic.target_reward_critic(next_obs, target_action)
            # Taking the smaller of the two estimates is the "clipped double-Q" trick.
            min_target_q = torch.min(target_q1, target_q2)
            # ``(1 - done)`` zeroes the bootstrap term where ``done`` is 1.
            td_target = reward + algo_cfgs.gamma * (1 - done) * min_target_q

        current_q1, current_q2 = actor_critic.reward_critic(obs, action)
        q1_loss = nn.functional.mse_loss(current_q1, td_target)
        q2_loss = nn.functional.mse_loss(current_q2, td_target)
        loss = q1_loss + q2_loss

        if algo_cfgs.use_critic_norm:
            # Optional L2 penalty on every reward-critic parameter.
            for weight in actor_critic.reward_critic.parameters():
                loss += weight.pow(2).sum() * algo_cfgs.critic_norm_coeff

        optimizer = actor_critic.reward_critic_optimizer
        optimizer.zero_grad()
        loss.backward()

        # A falsy ``max_grad_norm`` (e.g. ``None`` or ``0``) disables clipping.
        if algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                actor_critic.reward_critic.parameters(),
                algo_cfgs.max_grad_norm,
            )
        optimizer.step()

        metrics = {
            'Loss/Loss_reward_critic': loss.mean().item(),
            'Value/reward_critic': current_q1.mean().item(),
        }
        self._logger.store(metrics)