Source code for omnisafe.algorithms.off_policy.ddpg

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Deep Deterministic Policy Gradient algorithm."""

from __future__ import annotations

import time
from typing import Any

import torch
from torch import nn
from torch.nn.utils.clip_grad import clip_grad_norm_

from omnisafe.adapter import OffPolicyAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic


@registry.register
# pylint: disable-next=too-many-instance-attributes,too-few-public-methods
class DDPG(BaseAlgo):
    """The Deep Deterministic Policy Gradient (DDPG) algorithm.

    References:
        - Title: Continuous control with deep reinforcement learning
        - Authors: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
            Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
        - URL: `DDPG <https://arxiv.org/abs/1509.02971>`_
    """

    _epoch: int

    def _init_env(self) -> None:
        """Initialize the environment.

        OmniSafe uses :class:`omnisafe.adapter.OffPolicyAdapter` to adapt the environment to this
        algorithm.

        User can customize the environment by inheriting this method.

        Examples:
            >>> def _init_env(self) -> None:
            ...     self._env = CustomAdapter()

        Raises:
            AssertionError: If the number of steps per epoch is not divisible by the number of
                environments.
            AssertionError: If the total number of steps is not divisible by the number of steps
                per epoch.
            AssertionError: If the per-environment number of steps per epoch is not divisible by
                ``update_cycle``.
        """
        self._env: OffPolicyAdapter = OffPolicyAdapter(
            self._env_id,
            self._cfgs.train_cfgs.vector_env_nums,
            self._seed,
            self._cfgs,
        )
        assert (
            self._cfgs.algo_cfgs.steps_per_epoch % self._cfgs.train_cfgs.vector_env_nums == 0
        ), 'The number of steps per epoch is not divisible by the number of environments.'

        assert (
            int(self._cfgs.train_cfgs.total_steps) % self._cfgs.algo_cfgs.steps_per_epoch == 0
        ), 'The total number of steps is not divisible by the number of steps per epoch.'
        self._epochs: int = int(
            self._cfgs.train_cfgs.total_steps // self._cfgs.algo_cfgs.steps_per_epoch,
        )
        self._epoch: int = 0
        # Steps per epoch taken by *each* of the vectorized environments.
        self._steps_per_epoch: int = (
            self._cfgs.algo_cfgs.steps_per_epoch // self._cfgs.train_cfgs.vector_env_nums
        )

        self._update_cycle: int = self._cfgs.algo_cfgs.update_cycle
        assert (
            self._steps_per_epoch % self._update_cycle == 0
        ), 'The number of steps per epoch is not divisible by the number of steps per sample.'
        # Number of rollout/update rounds per epoch; each round rolls out ``update_cycle`` steps.
        self._samples_per_epoch: int = self._steps_per_epoch // self._update_cycle
        self._update_count: int = 0

    def _init_model(self) -> None:
        """Initialize the model.

        OmniSafe uses
        :class:`omnisafe.models.actor_critic.constraint_actor_q_critic.ConstraintActorQCritic` as
        the default model.

        User can customize the model by inheriting this method.

        Examples:
            >>> def _init_model(self) -> None:
            ...     self._actor_critic = CustomActorQCritic()
        """
        # DDPG uses a single Q critic (unlike e.g. TD3's twin critics).
        self._cfgs.model_cfgs.critic['num_critics'] = 1
        self._actor_critic: ConstraintActorQCritic = ConstraintActorQCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            epochs=self._epochs,
        ).to(self._device)

    def _init(self) -> None:
        """The initialization of the algorithm.

        User can define the initialization of the algorithm by inheriting this method.

        Examples:
            >>> def _init(self) -> None:
            ...     super()._init()
            ...     self._buffer = CustomBuffer()
            ...     self._model = CustomModel()
        """
        self._buf: VectorOffPolicyBuffer = VectorOffPolicyBuffer(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            size=self._cfgs.algo_cfgs.size,
            batch_size=self._cfgs.algo_cfgs.batch_size,
            num_envs=self._cfgs.train_cfgs.vector_env_nums,
            penalty_coefficient=self._cfgs.algo_cfgs.get('penalty_coefficient', 0.0),
            device=self._device,
        )

    def _init_log(self) -> None:
        """Log info about epoch.

        +-------------------------+----------------------------------------------------------------------+
        | Things to log           | Description                                                          |
        +=========================+======================================================================+
        | Train/Epoch             | Current epoch.                                                       |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/EpCost          | Average cost of the epoch.                                           |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/EpRet           | Average return of the epoch.                                         |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/EpLen           | Average length of the epoch.                                         |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/TestEpCost      | Average cost of the evaluate epoch.                                  |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/TestEpRet       | Average return of the evaluate epoch.                                |
        +-------------------------+----------------------------------------------------------------------+
        | Metrics/TestEpLen       | Average length of the evaluate epoch.                                |
        +-------------------------+----------------------------------------------------------------------+
        | Value/reward_critic     | Average Q value of the reward critic.                                |
        +-------------------------+----------------------------------------------------------------------+
        | Value/cost_critic       | Average Q value of the cost critic.                                  |
        +-------------------------+----------------------------------------------------------------------+
        | Loss/Loss_pi            | Loss of the policy network.                                          |
        +-------------------------+----------------------------------------------------------------------+
        | Loss/Loss_reward_critic | Loss of the reward critic.                                           |
        +-------------------------+----------------------------------------------------------------------+
        | Loss/Loss_cost_critic   | Loss of the cost critic network.                                     |
        +-------------------------+----------------------------------------------------------------------+
        | Train/LR                | Learning rate of the policy network.                                 |
        +-------------------------+----------------------------------------------------------------------+
        | TotalEnvSteps           | Total steps of the experiment.                                       |
        +-------------------------+----------------------------------------------------------------------+
        | Time/Total              | Total time.                                                          |
        +-------------------------+----------------------------------------------------------------------+
        | Time/Rollout            | Rollout time.                                                        |
        +-------------------------+----------------------------------------------------------------------+
        | Time/Update             | Update time.                                                         |
        +-------------------------+----------------------------------------------------------------------+
        | Time/Evaluate           | Evaluate time.                                                       |
        +-------------------------+----------------------------------------------------------------------+
        | Time/Epoch              | Epoch time.                                                          |
        +-------------------------+----------------------------------------------------------------------+
        | Time/FPS                | Frames per second of the epoch.                                      |
        +-------------------------+----------------------------------------------------------------------+

        The ``Metrics/Test*`` keys are only registered when ``eval_episodes > 0``, and the cost
        critic keys (``Loss/Loss_cost_critic``, ``Value/cost_critic``) only when ``use_cost`` is
        enabled. The ``Value/*`` entries are the mean Q values of the sampled mini-batches, stored
        during the critic updates.
        """
        self._logger: Logger = Logger(
            output_dir=self._cfgs.logger_cfgs.log_dir,
            exp_name=self._cfgs.exp_name,
            seed=self._cfgs.seed,
            use_tensorboard=self._cfgs.logger_cfgs.use_tensorboard,
            use_wandb=self._cfgs.logger_cfgs.use_wandb,
            config=self._cfgs,
        )

        # objects checkpointed by ``torch_save``
        what_to_save: dict[str, Any] = {}
        what_to_save['pi'] = self._actor_critic.actor
        if self._cfgs.algo_cfgs.obs_normalize:
            obs_normalizer = self._env.save()['obs_normalizer']
            what_to_save['obs_normalizer'] = obs_normalizer

        self._logger.setup_torch_saver(what_to_save)
        self._logger.torch_save()

        self._logger.register_key(
            'Metrics/EpRet',
            window_length=self._cfgs.logger_cfgs.window_lens,
        )
        self._logger.register_key(
            'Metrics/EpCost',
            window_length=self._cfgs.logger_cfgs.window_lens,
        )
        self._logger.register_key(
            'Metrics/EpLen',
            window_length=self._cfgs.logger_cfgs.window_lens,
        )

        if self._cfgs.train_cfgs.eval_episodes > 0:
            self._logger.register_key(
                'Metrics/TestEpRet',
                window_length=self._cfgs.logger_cfgs.window_lens,
            )
            self._logger.register_key(
                'Metrics/TestEpCost',
                window_length=self._cfgs.logger_cfgs.window_lens,
            )
            self._logger.register_key(
                'Metrics/TestEpLen',
                window_length=self._cfgs.logger_cfgs.window_lens,
            )

        self._logger.register_key('Train/Epoch')
        self._logger.register_key('Train/LR')

        self._logger.register_key('TotalEnvSteps')

        # log information about actor
        self._logger.register_key('Loss/Loss_pi', delta=True)

        # log information about critic
        self._logger.register_key('Loss/Loss_reward_critic', delta=True)
        self._logger.register_key('Value/reward_critic')

        if self._cfgs.algo_cfgs.use_cost:
            # log information about cost critic
            self._logger.register_key('Loss/Loss_cost_critic', delta=True)
            self._logger.register_key('Value/cost_critic')

        self._logger.register_key('Time/Total')
        self._logger.register_key('Time/Rollout')
        self._logger.register_key('Time/Update')
        self._logger.register_key('Time/Evaluate')
        self._logger.register_key('Time/Epoch')
        self._logger.register_key('Time/FPS')
        # register environment specific keys
        for env_spec_key in self._env.env_spec_keys:
            self._logger.register_key(env_spec_key)

    def learn(self) -> tuple[float, float, float]:
        """This is the main function for algorithm update.

        It is divided into the following steps:

        - :meth:`rollout`: collect interactive data from environment.
        - :meth:`update`: perform actor/critic updates.
        - :meth:`log`: epoch/update information for visualization and terminal log print.

        Returns:
            ep_ret: average episode return in final epoch.
            ep_cost: average episode cost in final epoch.
            ep_len: average episode length in final epoch.
        """
        self._logger.log('INFO: Start training')
        start_time = time.time()
        step = 0
        # Environment steps collected by one ``rollout`` call, summed over all vector envs.
        steps_per_sample = self._update_cycle * self._cfgs.train_cfgs.vector_env_nums
        for epoch in range(self._epochs):
            self._epoch = epoch
            rollout_time = 0.0
            update_time = 0.0
            epoch_time = time.time()

            for sample_step in range(
                epoch * self._samples_per_epoch,
                (epoch + 1) * self._samples_per_epoch,
            ):
                # environment steps collected *before* this round's rollout
                step = sample_step * steps_per_sample

                rollout_start = time.time()
                # set noise for exploration
                if self._cfgs.algo_cfgs.use_exploration_noise:
                    self._actor_critic.actor.noise = self._cfgs.algo_cfgs.exploration_noise

                # collect data from environment; act randomly until learning starts
                self._env.rollout(
                    rollout_step=self._update_cycle,
                    agent=self._actor_critic,
                    buffer=self._buf,
                    logger=self._logger,
                    use_rand_action=(step <= self._cfgs.algo_cfgs.start_learning_steps),
                )
                rollout_time += time.time() - rollout_start

                # update parameters
                update_start = time.time()
                if step > self._cfgs.algo_cfgs.start_learning_steps:
                    self._update()
                # if we haven't updated the network, log 0 for the loss
                else:
                    self._log_when_not_update()
                update_time += time.time() - update_start

            eval_start = time.time()
            self._env.eval_policy(
                episode=self._cfgs.train_cfgs.eval_episodes,
                agent=self._actor_critic,
                logger=self._logger,
            )
            eval_time = time.time() - eval_start

            self._logger.store({'Time/Update': update_time})
            self._logger.store({'Time/Rollout': rollout_time})
            self._logger.store({'Time/Evaluate': eval_time})

            if (
                step > self._cfgs.algo_cfgs.start_learning_steps
                and self._cfgs.model_cfgs.linear_lr_decay
            ):
                self._actor_critic.actor_scheduler.step()

            self._logger.store(
                {
                    # ``step`` counts steps before the last rollout, so add the steps that
                    # rollout collected to report the exact total (equals ``step + 1`` when
                    # ``update_cycle`` and ``vector_env_nums`` are both 1).
                    'TotalEnvSteps': step + steps_per_sample,
                    'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time),
                    'Time/Total': (time.time() - start_time),
                    'Time/Epoch': (time.time() - epoch_time),
                    'Train/Epoch': epoch,
                    'Train/LR': self._actor_critic.actor_scheduler.get_last_lr()[0],
                },
            )

            self._logger.dump_tabular()

            # save model to disk
            if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0:
                self._logger.torch_save()

        ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        ep_len = self._logger.get_stats('Metrics/EpLen')[0]
        self._logger.close()
        self._env.close()

        return ep_ret, ep_cost, ep_len

    def _update(self) -> None:
        """Update actor, critic.

        -  Get the ``data`` from buffer

        .. note::

            +----------+----------------------------------------+
            | obs      | ``observation`` stored in buffer.      |
            +==========+========================================+
            | act      | ``action`` stored in buffer.           |
            +----------+----------------------------------------+
            | reward   | ``reward`` stored in buffer.           |
            +----------+----------------------------------------+
            | cost     | ``cost`` stored in buffer.             |
            +----------+----------------------------------------+
            | next_obs | ``next observation`` stored in buffer. |
            +----------+----------------------------------------+
            | done     | ``terminated`` stored in buffer.       |
            +----------+----------------------------------------+

        -  Update value net by :meth:`_update_reward_critic`.
        -  Update cost net by :meth:`_update_cost_critic`.
        -  Update policy net by :meth:`_update_actor`.

        The basic process of each update is as follows:

        #. Get the mini-batch data from buffer.
        #. Get the loss of network.
        #. Update the network by loss.
        #. Repeat steps 1-3 ``update_iters`` times.

        The actor update and the polyak (target network) update only run on every
        ``policy_delay``-th iteration.
        """
        for _ in range(self._cfgs.algo_cfgs.update_iters):
            data = self._buf.sample_batch()
            self._update_count += 1
            obs, act, reward, cost, done, next_obs = (
                data['obs'],
                data['act'],
                data['reward'],
                data['cost'],
                data['done'],
                data['next_obs'],
            )

            self._update_reward_critic(obs, act, reward, done, next_obs)
            if self._cfgs.algo_cfgs.use_cost:
                self._update_cost_critic(obs, act, cost, done, next_obs)

            # delayed actor and target-network updates
            if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0:
                self._update_actor(obs)
                self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

    def _update_reward_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update reward critic.

        - Get the TD loss of reward critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            reward (torch.Tensor): The ``reward`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        with torch.no_grad():
            # TD target from the target critic evaluated at the deterministic next action
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            next_q_value_r = self._actor_critic.target_reward_critic(next_obs, next_action)[0]
            target_q_value_r = reward + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_r
        q_value_r = self._actor_critic.reward_critic(obs, action)[0]
        loss = nn.functional.mse_loss(q_value_r, target_q_value_r)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # optional L2 penalty on the critic weights
            for param in self._actor_critic.reward_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        self._logger.store(
            {
                'Loss/Loss_reward_critic': loss.mean().item(),
                'Value/reward_critic': q_value_r.mean().item(),
            },
        )
        self._actor_critic.reward_critic_optimizer.zero_grad()
        loss.backward()

        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.reward_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.reward_critic_optimizer.step()

    def _update_cost_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        cost: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update cost critic.

        - Get the TD loss of cost critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            cost (torch.Tensor): The ``cost`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        with torch.no_grad():
            # TD target from the target cost critic evaluated at the deterministic next action
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            next_q_value_c = self._actor_critic.target_cost_critic(next_obs, next_action)[0]
            target_q_value_c = cost + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_c
        q_value_c = self._actor_critic.cost_critic(obs, action)[0]
        loss = nn.functional.mse_loss(q_value_c, target_q_value_c)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # optional L2 penalty on the critic weights
            for param in self._actor_critic.cost_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        # Log before the backward pass, matching ``_update_reward_critic``; the logged values
        # come from tensors already computed, so the order does not affect them.
        self._logger.store(
            {
                'Loss/Loss_cost_critic': loss.mean().item(),
                'Value/cost_critic': q_value_c.mean().item(),
            },
        )
        self._actor_critic.cost_critic_optimizer.zero_grad()
        loss.backward()

        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.cost_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.cost_critic_optimizer.step()

    def _update_actor(
        self,
        obs: torch.Tensor,
    ) -> None:
        """Update actor.

        - Get the loss of actor.
        - Update actor by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
        """
        loss = self._loss_pi(obs)
        self._actor_critic.actor_optimizer.zero_grad()
        loss.backward()
        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.actor_optimizer.step()
        self._logger.store(
            {
                'Loss/Loss_pi': loss.mean().item(),
            },
        )

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing ``pi/actor`` loss.

        The loss function in DDPG is defined as:

        .. math::

            L = -Q^V (s, \pi (s))

        where :math:`Q^V` is the reward critic network, and :math:`\pi` is the policy network.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        action = self._actor_critic.actor.predict(obs, deterministic=True)
        return -self._actor_critic.reward_critic(obs, action)[0].mean()

    def _log_when_not_update(self) -> None:
        """Log default value when not update."""
        self._logger.store(
            {
                'Loss/Loss_reward_critic': 0.0,
                'Loss/Loss_pi': 0.0,
                'Value/reward_critic': 0.0,
            },
        )
        if self._cfgs.algo_cfgs.use_cost:
            self._logger.store(
                {
                    'Loss/Loss_cost_critic': 0.0,
                    'Value/cost_critic': 0.0,
                },
            )