# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Deep Deterministic Policy Gradient algorithm."""
from __future__ import annotations
import time
from typing import Any
import torch
from torch import nn
from torch.nn.utils.clip_grad import clip_grad_norm_
from omnisafe.adapter import OffPolicyAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.common.buffer import VectorOffPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
@registry.register
# pylint: disable-next=too-many-instance-attributes,too-few-public-methods
class DDPG(BaseAlgo):
    """The Deep Deterministic Policy Gradient (DDPG) algorithm.

    References:
        - Title: Continuous control with deep reinforcement learning
        - Authors: Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess,
            Tom Erez, Yuval Tassa, David Silver, Daan Wierstra.
        - URL: `DDPG <https://arxiv.org/abs/1509.02971>`_
    """

    _epoch: int

    def _init_env(self) -> None:
        """Initialize the environment.

        OmniSafe uses :class:`omnisafe.adapter.OffPolicyAdapter` to adapt the environment to this
        algorithm.

        User can customize the environment by inheriting this method.

        Examples:
            >>> def _init_env(self) -> None:
            ...     self._env = CustomAdapter()

        Raises:
            AssertionError: If the number of steps per epoch is not divisible by the number of
                environments.
            AssertionError: If the total number of steps is not divisible by the number of steps per
                epoch.
            AssertionError: If the number of per-environment steps in an epoch is not divisible by
                ``update_cycle``.
        """
        self._env: OffPolicyAdapter = OffPolicyAdapter(
            self._env_id,
            self._cfgs.train_cfgs.vector_env_nums,
            self._seed,
            self._cfgs,
        )
        assert (
            self._cfgs.algo_cfgs.steps_per_epoch % self._cfgs.train_cfgs.vector_env_nums == 0
        ), 'The number of steps per epoch is not divisible by the number of environments.'

        assert (
            int(self._cfgs.train_cfgs.total_steps) % self._cfgs.algo_cfgs.steps_per_epoch == 0
        ), 'The total number of steps is not divisible by the number of steps per epoch.'
        self._epochs: int = int(
            self._cfgs.train_cfgs.total_steps // self._cfgs.algo_cfgs.steps_per_epoch,
        )
        self._epoch: int = 0
        # steps taken by EACH environment copy within one epoch
        self._steps_per_epoch: int = (
            self._cfgs.algo_cfgs.steps_per_epoch // self._cfgs.train_cfgs.vector_env_nums
        )

        # number of per-environment steps collected between two calls to the update routine
        self._update_cycle: int = self._cfgs.algo_cfgs.update_cycle
        assert (
            self._steps_per_epoch % self._update_cycle == 0
        ), 'The number of steps per epoch is not divisible by the number of steps per sample.'
        self._samples_per_epoch: int = self._steps_per_epoch // self._update_cycle
        # counts sampled mini-batches; used by ``_update`` to schedule delayed policy updates
        self._update_count: int = 0

    def _init_model(self) -> None:
        """Initialize the model.

        OmniSafe uses
        :class:`omnisafe.models.actor_critic.constraint_actor_q_critic.ConstraintActorQCritic`
        as the default model. The ``num_critics`` entry of the critic configuration is overwritten
        with ``1`` before the model is built.

        User can customize the model by inheriting this method.

        Examples:
            >>> def _init_model(self) -> None:
            ...     self._actor_critic = CustomActorQCritic()
        """
        # NOTE: this mutates the shared config object in place.
        self._cfgs.model_cfgs.critic['num_critics'] = 1
        self._actor_critic: ConstraintActorQCritic = ConstraintActorQCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            epochs=self._epochs,
        ).to(self._device)

    def _init(self) -> None:
        """The initialization of the algorithm.

        Creates the :class:`VectorOffPolicyBuffer` replay buffer from the environment spaces and
        the ``algo_cfgs`` / ``train_cfgs`` settings. ``penalty_coefficient`` falls back to ``0.0``
        when it is absent from ``algo_cfgs``.

        User can define the initialization of the algorithm by inheriting this method.

        Examples:
            >>> def _init(self) -> None:
            ...     super()._init()
            ...     self._buffer = CustomBuffer()
            ...     self._model = CustomModel()
        """
        self._buf: VectorOffPolicyBuffer = VectorOffPolicyBuffer(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            size=self._cfgs.algo_cfgs.size,
            batch_size=self._cfgs.algo_cfgs.batch_size,
            num_envs=self._cfgs.train_cfgs.vector_env_nums,
            penalty_coefficient=self._cfgs.algo_cfgs.get('penalty_coefficient', 0.0),
            device=self._device,
        )

    def _init_log(self) -> None:
        """Create the logger, set up model saving and register every logged key.

        Keys registered by this method:

        - ``Metrics/EpRet``, ``Metrics/EpCost``, ``Metrics/EpLen``: training episode statistics,
          registered with a window of ``logger_cfgs.window_lens``.
        - ``Metrics/TestEpRet``, ``Metrics/TestEpCost``, ``Metrics/TestEpLen``: evaluation episode
          statistics; only registered when ``train_cfgs.eval_episodes > 0``.
        - ``Train/Epoch``: current epoch.
        - ``Train/LR``: last learning rate reported by the actor scheduler.
        - ``TotalEnvSteps``: environment step counter stored by :meth:`learn`.
        - ``Loss/Loss_pi``: loss of the policy network.
        - ``Loss/Loss_reward_critic``: loss of the reward critic.
        - ``Value/reward_critic``: mean Q-value of the reward critic on the sampled batch.
        - ``Loss/Loss_cost_critic``, ``Value/cost_critic``: the cost critic counterparts; only
          registered when ``algo_cfgs.use_cost`` is set.
        - ``Time/Total``, ``Time/Rollout``, ``Time/Update``, ``Time/Evaluate``, ``Time/Epoch``:
          wall-clock timings stored by :meth:`learn`.
        - ``Time/FPS``: ``steps_per_epoch`` divided by the epoch wall-clock time.
        - every key listed in ``self._env.env_spec_keys``.
        """
        self._logger: Logger = Logger(
            output_dir=self._cfgs.logger_cfgs.log_dir,
            exp_name=self._cfgs.exp_name,
            seed=self._cfgs.seed,
            use_tensorboard=self._cfgs.logger_cfgs.use_tensorboard,
            use_wandb=self._cfgs.logger_cfgs.use_wandb,
            config=self._cfgs,
        )

        # objects handed to the logger's torch saver
        what_to_save: dict[str, Any] = {'pi': self._actor_critic.actor}
        if self._cfgs.algo_cfgs.obs_normalize:
            what_to_save['obs_normalizer'] = self._env.save()['obs_normalizer']
        self._logger.setup_torch_saver(what_to_save)
        # save once right away, before any training has happened
        self._logger.torch_save()

        window = self._cfgs.logger_cfgs.window_lens
        for metric_key in ('Metrics/EpRet', 'Metrics/EpCost', 'Metrics/EpLen'):
            self._logger.register_key(metric_key, window_length=window)

        if self._cfgs.train_cfgs.eval_episodes > 0:
            for metric_key in ('Metrics/TestEpRet', 'Metrics/TestEpCost', 'Metrics/TestEpLen'):
                self._logger.register_key(metric_key, window_length=window)

        self._logger.register_key('Train/Epoch')
        self._logger.register_key('Train/LR')
        self._logger.register_key('TotalEnvSteps')

        # log information about actor
        self._logger.register_key('Loss/Loss_pi', delta=True)

        # log information about critic
        self._logger.register_key('Loss/Loss_reward_critic', delta=True)
        self._logger.register_key('Value/reward_critic')

        if self._cfgs.algo_cfgs.use_cost:
            # log information about cost critic
            self._logger.register_key('Loss/Loss_cost_critic', delta=True)
            self._logger.register_key('Value/cost_critic')

        for time_key in (
            'Time/Total',
            'Time/Rollout',
            'Time/Update',
            'Time/Evaluate',
            'Time/Epoch',
            'Time/FPS',
        ):
            self._logger.register_key(time_key)

        # register environment specific keys (on the same logger object used everywhere else)
        for env_spec_key in self._env.env_spec_keys:
            self._logger.register_key(env_spec_key)

    def learn(self) -> tuple[float, float, float]:
        """This is main function for algorithm update.

        Each epoch is divided into the following steps:

        - rollout: collect ``update_cycle`` steps of interaction data from the environment.
        - update: call :meth:`_update` once the step counter exceeds ``start_learning_steps``,
          otherwise log placeholder values via :meth:`_log_when_not_update`.
        - evaluate: run ``eval_policy`` on the environment adapter.
        - log: store timings and epoch information, dump the table and periodically save the
          model.

        Returns:
            ep_ret: first logger statistic of ``Metrics/EpRet`` after the final epoch.
            ep_cost: first logger statistic of ``Metrics/EpCost`` after the final epoch.
            ep_len: first logger statistic of ``Metrics/EpLen`` after the final epoch.
        """
        self._logger.log('INFO: Start training')
        start_time = time.time()
        step = 0
        for epoch in range(self._epochs):
            self._epoch = epoch
            rollout_time = 0.0
            update_time = 0.0
            epoch_time = time.time()

            # ``sample_step`` is a global (cross-epoch) index of rollout/update cycles
            for sample_step in range(
                epoch * self._samples_per_epoch,
                (epoch + 1) * self._samples_per_epoch,
            ):
                # environment steps consumed before this cycle, summed over all env copies
                step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums

                rollout_start = time.time()
                # set noise for exploration
                if self._cfgs.algo_cfgs.use_exploration_noise:
                    self._actor_critic.actor.noise = self._cfgs.algo_cfgs.exploration_noise

                # collect data from environment
                self._env.rollout(
                    rollout_step=self._update_cycle,
                    agent=self._actor_critic,
                    buffer=self._buf,
                    logger=self._logger,
                    use_rand_action=(step <= self._cfgs.algo_cfgs.start_learning_steps),
                )
                rollout_time += time.time() - rollout_start

                # update parameters
                update_start = time.time()
                if step > self._cfgs.algo_cfgs.start_learning_steps:
                    self._update()
                else:
                    # if we haven't updated the network, log 0 for the loss
                    self._log_when_not_update()
                update_time += time.time() - update_start

            eval_start = time.time()
            self._env.eval_policy(
                episode=self._cfgs.train_cfgs.eval_episodes,
                agent=self._actor_critic,
                logger=self._logger,
            )
            eval_time = time.time() - eval_start

            self._logger.store({'Time/Update': update_time})
            self._logger.store({'Time/Rollout': rollout_time})
            self._logger.store({'Time/Evaluate': eval_time})

            # only decay the learning rate once learning has actually started
            if (
                step > self._cfgs.algo_cfgs.start_learning_steps
                and self._cfgs.model_cfgs.linear_lr_decay
            ):
                self._actor_critic.actor_scheduler.step()

            self._logger.store(
                {
                    'TotalEnvSteps': step + 1,
                    'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time),
                    'Time/Total': (time.time() - start_time),
                    'Time/Epoch': (time.time() - epoch_time),
                    'Train/Epoch': epoch,
                    'Train/LR': self._actor_critic.actor_scheduler.get_last_lr()[0],
                },
            )

            self._logger.dump_tabular()

            # save model to disk
            if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0:
                self._logger.torch_save()

        ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        ep_len = self._logger.get_stats('Metrics/EpLen')[0]
        self._logger.close()
        self._env.close()

        return ep_ret, ep_cost, ep_len

    def _update(self) -> None:
        """Update actor, critic.

        Runs ``update_iters`` iterations. Each iteration:

        #. Samples a mini-batch from the buffer and reads the following entries:

           - ``obs``: ``observation`` stored in buffer.
           - ``act``: ``action`` stored in buffer.
           - ``reward``: ``reward`` stored in buffer.
           - ``cost``: ``cost`` stored in buffer.
           - ``next_obs``: ``next observation`` stored in buffer.
           - ``done``: ``terminated`` stored in buffer.

        #. Updates the reward critic by :meth:`_update_reward_critic`.
        #. Updates the cost critic by :meth:`_update_cost_critic` when ``use_cost`` is set.
        #. Every ``policy_delay`` iterations, updates the policy by :meth:`_update_actor` and then
           calls ``polyak_update`` on the actor-critic.
        """
        for _ in range(self._cfgs.algo_cfgs.update_iters):
            data = self._buf.sample_batch()
            self._update_count += 1
            obs, act, reward, cost, done, next_obs = (
                data['obs'],
                data['act'],
                data['reward'],
                data['cost'],
                data['done'],
                data['next_obs'],
            )

            self._update_reward_critic(obs, act, reward, done, next_obs)
            if self._cfgs.algo_cfgs.use_cost:
                self._update_cost_critic(obs, act, cost, done, next_obs)

            # delayed policy update; the polyak update runs on the same schedule
            if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0:
                self._update_actor(obs)
                self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

    def _update_reward_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update reward critic.

        - Get the TD loss of reward critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            reward (torch.Tensor): The ``reward`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        # the bootstrap target is built without tracking gradients
        with torch.no_grad():
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            next_q_value_r = self._actor_critic.target_reward_critic(next_obs, next_action)[0]
            target_q_value_r = reward + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_r
        q_value_r = self._actor_critic.reward_critic(obs, action)[0]
        loss = nn.functional.mse_loss(q_value_r, target_q_value_r)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # squared-parameter penalty on the critic weights
            for param in self._actor_critic.reward_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        self._actor_critic.reward_critic_optimizer.zero_grad()
        loss.backward()

        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.reward_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.reward_critic_optimizer.step()

        self._logger.store(
            {
                'Loss/Loss_reward_critic': loss.mean().item(),
                'Value/reward_critic': q_value_r.mean().item(),
            },
        )

    def _update_cost_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        cost: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update cost critic.

        - Get the TD loss of cost critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            cost (torch.Tensor): The ``cost`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        # the bootstrap target is built without tracking gradients
        with torch.no_grad():
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            next_q_value_c = self._actor_critic.target_cost_critic(next_obs, next_action)[0]
            target_q_value_c = cost + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_c
        q_value_c = self._actor_critic.cost_critic(obs, action)[0]
        loss = nn.functional.mse_loss(q_value_c, target_q_value_c)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # squared-parameter penalty on the critic weights
            for param in self._actor_critic.cost_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        self._actor_critic.cost_critic_optimizer.zero_grad()
        loss.backward()

        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.cost_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.cost_critic_optimizer.step()

        self._logger.store(
            {
                'Loss/Loss_cost_critic': loss.mean().item(),
                'Value/cost_critic': q_value_c.mean().item(),
            },
        )

    def _update_actor(
        self,
        obs: torch.Tensor,
    ) -> None:
        """Update actor.

        - Get the loss of actor.
        - Update actor by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
        """
        loss = self._loss_pi(obs)
        self._actor_critic.actor_optimizer.zero_grad()
        loss.backward()
        if self._cfgs.algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.actor_optimizer.step()
        self._logger.store(
            {
                'Loss/Loss_pi': loss.mean().item(),
            },
        )

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing ``pi/actor`` loss.

        The loss function in DDPG is defined as:

        .. math::

            L = -Q^V (s, \pi (s))

        where :math:`Q^V` is the reward critic network, and :math:`\pi` is the policy network.
        The value returned is the mean of this quantity over the batch.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        action = self._actor_critic.actor.predict(obs, deterministic=True)
        return -self._actor_critic.reward_critic(obs, action)[0].mean()

    def _log_when_not_update(self) -> None:
        """Log default value when not update.

        Stores ``0.0`` for every loss/value key that an update would have written, including the
        cost critic keys when ``use_cost`` is set.
        """
        self._logger.store(
            {
                'Loss/Loss_reward_critic': 0.0,
                'Loss/Loss_pi': 0.0,
                'Value/reward_critic': 0.0,
            },
        )
        if self._cfgs.algo_cfgs.use_cost:
            self._logger.store(
                {
                    'Loss/Loss_cost_critic': 0.0,
                    'Value/cost_critic': 0.0,
                },
            )