# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Learning Off-Policy with Online Planning algorithm."""
from __future__ import annotations
from typing import Any
import torch
from gymnasium.spaces import Box
from torch import nn, optim
from torch.nn.utils.clip_grad import clip_grad_norm_
from omnisafe.algorithms import registry
from omnisafe.algorithms.model_based.base.ensemble import EnsembleDynamicsModel
from omnisafe.algorithms.model_based.base.pets import PETS
from omnisafe.algorithms.model_based.planner.arc import ARCPlanner
from omnisafe.common.buffer import OffPolicyBuffer
from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic
from omnisafe.typing import OmnisafeSpace
@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class LOOP(PETS):
    """The Learning Off-Policy with Online Planning (LOOP) algorithm.

    The class combines an ensemble dynamics model, a planner and an
    entropy-regularised actor-critic that is trained off-policy from a
    dedicated replay buffer.

    References:
        - Title: Learning Off-Policy with Online Planning
        - Authors: Harshit Sikchi, Wenxuan Zhou, David Held.
        - URL: `LOOP <https://arxiv.org/abs/2008.10066>`_
    """

    # Declared for type checkers only; nothing in this class assigns them.
    _log_alpha: torch.Tensor
    _alpha_optimizer: optim.Optimizer
    _target_entropy: float

    def _init_model(self) -> None:
        """Initialize the dynamics model, the actor-critic and the planner.

        LOOP uses following models:

        - dynamics model: to predict the next state and the cost.
        - actor_critic: to predict the action and the value.
        - planner: to generate the action.

        Raises:
            NotImplementedError: If the environment action space is not a
                :class:`gymnasium.spaces.Box`.
        """
        # Prefer the coordinate observation space when the environment exposes one.
        self._dynamics_state_space: OmnisafeSpace = (
            self._env.coordinate_observation_space
            if self._env.coordinate_observation_space is not None
            else self._env.observation_space
        )
        assert self._dynamics_state_space is not None and isinstance(
            self._dynamics_state_space.shape,
            tuple,
        )
        assert self._env.action_space is not None and isinstance(
            self._env.action_space.shape,
            tuple,
        )
        if not isinstance(self._env.action_space, Box):
            raise NotImplementedError('LOOP only supports Box action spaces.')
        self._action_space = self._env.action_space

        self._actor_critic: ConstraintActorQCritic = ConstraintActorQCritic(
            obs_space=self._dynamics_state_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            epochs=self._epochs,
        ).to(self._device)
        self._use_actor_critic: bool = True
        # Number of critic updates performed so far; drives the delayed actor update.
        self._update_count: int = 0

        self._dynamics: EnsembleDynamicsModel = EnsembleDynamicsModel(
            model_cfgs=self._cfgs.dynamics_cfgs,
            device=self._device,
            state_shape=self._dynamics_state_space.shape,
            action_shape=self._env.action_space.shape,
            actor_critic=self._actor_critic,
            rew_func=None,
            cost_func=None,
            terminal_func=None,
        )
        self._update_dynamics_cycle = int(self._cfgs.algo_cfgs.update_dynamics_cycle)

        self._planner: ARCPlanner = ARCPlanner(
            dynamics=self._dynamics,
            planner_cfgs=self._cfgs.planner_cfgs,
            gamma=float(self._cfgs.algo_cfgs.gamma),
            cost_gamma=float(self._cfgs.algo_cfgs.cost_gamma),
            dynamics_state_shape=self._dynamics_state_space.shape,
            action_shape=self._action_space.shape,
            action_max=1.0,
            action_min=-1.0,
            device=self._device,
            actor_critic=self._actor_critic,
        )

    def _init(self) -> None:
        """Initialize the temperature settings and the policy replay buffer.

        Runs the parent initialization first, then reads ``alpha`` and
        ``alpha_gamma`` from the algorithm configuration and creates the
        off-policy buffer that :meth:`_update_policy` samples from.
        """
        super()._init()
        self._alpha = self._cfgs.algo_cfgs.alpha
        self._alpha_gamma = self._cfgs.algo_cfgs.alpha_gamma
        self._policy_buf = OffPolicyBuffer(
            obs_space=self._dynamics_state_space,
            act_space=self._env.action_space,
            size=self._cfgs.train_cfgs.total_steps,
            batch_size=self._cfgs.algo_cfgs.policy_batch_size,
            device=self._device,
        )

    def _alpha_discount(self) -> None:
        """Decay ``alpha`` in place by multiplying it with ``alpha_gamma``."""
        self._alpha *= self._alpha_gamma

    def _init_log(self) -> None:
        """Initialize logger.

        +-------------------------+---------------------------------------------------------------+
        | Things to log           | Description                                                   |
        +=========================+===============================================================+
        | Value/alpha             | Current value of the entropy temperature alpha.               |
        +-------------------------+---------------------------------------------------------------+
        | Value/reward_critic     | Mean Q-value of the first reward critic on the sampled batch. |
        +-------------------------+---------------------------------------------------------------+
        | Value/cost_critic       | Mean Q-value of the cost critic on the sampled batch.         |
        +-------------------------+---------------------------------------------------------------+
        | Loss/Loss_cost_critic   | Loss of the cost critic network.                              |
        +-------------------------+---------------------------------------------------------------+
        | Loss/Loss_reward_critic | Loss of the reward critic network.                            |
        +-------------------------+---------------------------------------------------------------+
        | Loss/Loss_pi            | Loss of the policy network.                                   |
        +-------------------------+---------------------------------------------------------------+

        The two cost critic keys are only registered when ``use_cost`` is enabled.
        """
        super()._init_log()
        self._logger.register_key('Value/alpha')
        # log information about actor
        self._logger.register_key('Loss/Loss_pi', delta=True)
        # log information about critic
        self._logger.register_key('Loss/Loss_reward_critic', delta=True)
        self._logger.register_key('Value/reward_critic')
        if self._cfgs.algo_cfgs.use_cost:
            # log information about cost critic
            self._logger.register_key('Loss/Loss_cost_critic', delta=True)
            self._logger.register_key('Value/cost_critic')

    def _save_model(self) -> None:
        """Save the dynamics ensemble, the actor-critic and, if used, the observation normalizer."""
        # set up model saving
        what_to_save: dict[str, Any] = {
            'dynamics': self._dynamics.ensemble_model,
            'actor_critic': self._actor_critic,
        }
        if self._cfgs.algo_cfgs.obs_normalize:
            obs_normalizer = self._env.save()['obs_normalizer']
            what_to_save['obs_normalizer'] = obs_normalizer
        self._logger.setup_torch_saver(what_to_save)
        self._logger.torch_save()

    def _select_action(  # pylint: disable=unused-argument
        self,
        current_step: int,
        state: torch.Tensor,
    ) -> torch.Tensor:
        """Select action.

        Before ``start_learning_steps`` a random action is sampled from the
        environment action space; afterwards the planner chooses the action
        and its info dictionary is stored in the logger.

        Args:
            current_step (int): The current step.
            state (torch.Tensor): The current state.

        Returns:
            The selected action, with shape ``[1, *action_shape]``.
        """
        if current_step < self._cfgs.algo_cfgs.start_learning_steps:
            action = torch.tensor(self._env.action_space.sample()).to(self._device).unsqueeze(0)
        else:
            action, info = self._planner.output_action(state)
            self._logger.store(**info)
        assert action.shape == torch.Size(
            [1, *self._action_space.shape],
        ), 'action shape should be [batch_size, action_dim]'
        return action

    def _update_policy(self, current_step: int) -> None:
        """Update policy.

        Nothing happens before ``start_learning_steps``. Afterwards the
        following is repeated ``update_policy_iters`` times:

        #. Sample a mini-batch from the policy buffer.
        #. Update the reward critic by :meth:`_update_reward_critic`.
        #. Update the cost critic by :meth:`_update_cost_critic` (only with ``use_cost``).
        #. Every ``policy_delay`` updates, update the actor by :meth:`_update_actor`
           with the critics frozen, then apply the polyak update.

        The sampled batch provides:

        +----------+----------------------------------------+
        | Key      | Content                                |
        +==========+========================================+
        | obs      | Observation stored in the buffer.      |
        +----------+----------------------------------------+
        | act      | Action stored in the buffer.           |
        +----------+----------------------------------------+
        | reward   | Reward stored in the buffer.           |
        +----------+----------------------------------------+
        | cost     | Cost stored in the buffer.             |
        +----------+----------------------------------------+
        | next_obs | Next observation stored in the buffer. |
        +----------+----------------------------------------+
        | done     | Done flag stored in the buffer.        |
        +----------+----------------------------------------+

        Args:
            current_step (int): The current step.
        """
        # NOTE(review): block nesting was reconstructed from a listing whose indentation
        # was lost. Confirm that the polyak update belongs to the delayed-actor branch and
        # that the alpha discount runs once per call after learning has started.
        if current_step >= self._cfgs.algo_cfgs.start_learning_steps:
            for _step in range(self._cfgs.algo_cfgs.update_policy_iters):
                self._update_count += 1
                data = self._policy_buf.sample_batch()
                obs, act, reward, cost, done, next_obs = (
                    data['obs'],
                    data['act'],
                    data['reward'],
                    data['cost'],
                    data['done'],
                    data['next_obs'],
                )
                self._update_reward_critic(obs, act, reward, done, next_obs)
                if self._cfgs.algo_cfgs.use_cost:
                    self._update_cost_critic(obs, act, cost, done, next_obs)

                if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0:
                    # freeze Q-network so you don't waste computational effort
                    # computing gradients for it during the policy learning step
                    for param in self._actor_critic.reward_critic.parameters():
                        param.requires_grad = False
                    if self._cfgs.algo_cfgs.use_cost:
                        for param in self._actor_critic.cost_critic.parameters():
                            param.requires_grad = False

                    self._update_actor(obs)

                    # unfreeze Q-network so it can be optimized at the next critic step.
                    for param in self._actor_critic.reward_critic.parameters():
                        param.requires_grad = True
                    if self._cfgs.algo_cfgs.use_cost:
                        for param in self._actor_critic.cost_critic.parameters():
                            param.requires_grad = True

                    self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)
            if self._cfgs.algo_cfgs.alpha_discount:
                self._alpha_discount()

    def _store_real_data(  # pylint: disable=too-many-arguments,unused-argument
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        cost: torch.Tensor,
        terminated: torch.Tensor,
        truncated: torch.Tensor,
        next_state: torch.Tensor,
        info: dict[str, Any],
    ) -> None:  # pylint: disable=too-many-arguments
        """Store real data in buffer.

        The dynamics buffer only receives transitions that neither end the
        episode nor meet the goal. The policy buffer receives the same
        transitions and, when ``policy_store_done`` is set, also the
        episode-ending ones.

        Args:
            state (torch.Tensor): The state from the environment.
            action (torch.Tensor): The action from the agent.
            reward (torch.Tensor): The reward signal from the environment.
            cost (torch.Tensor): The cost signal from the environment.
            terminated (torch.Tensor): The terminated signal from the environment.
            truncated (torch.Tensor): The truncated signal from the environment.
            next_state (torch.Tensor): The next state from the environment.
            info (dict[str, Any]): The information from the environment.
        """
        # ``or`` needs each operand to be convertible to a single bool.
        done = terminated or truncated
        goal_met = info.get('goal_met', False)
        if not done and not goal_met:
            # when goal_met == true:
            # current goal position is not related to the last goal position,
            # this huge transition will confuse the dynamics model.
            self._dynamics_buf.store(
                obs=state,
                act=action,
                reward=reward,
                cost=cost,
                next_obs=next_state,
                done=done,
            )
        if (done and self._cfgs.algo_cfgs.policy_store_done) or (not done and not goal_met):
            self._policy_buf.store(
                obs=state,
                act=action,
                reward=reward,
                cost=cost,
                next_obs=next_state,
                done=done,
            )

    def _update_reward_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update reward critic using Soft Actor-Critic.

        - Get the TD loss of reward critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            reward (torch.Tensor): The ``reward`` sampled from buffer.
            done (torch.Tensor): The ``done`` flag sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        self._actor_critic.reward_critic_optimizer.zero_grad()
        # The bootstrap target is computed without tracking gradients.
        with torch.no_grad():
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=False)
            next_logp = self._actor_critic.actor.log_prob(next_action)
            next_q1_value_r, next_q2_value_r = self._actor_critic.target_reward_critic(
                next_obs,
                next_action,
            )
            # Minimum of the two target critics minus the entropy term.
            next_q_value_r = torch.min(next_q1_value_r, next_q2_value_r) - next_logp * self._alpha
            target_q_value_r = reward + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_r

        q1_value_r, q2_value_r = self._actor_critic.reward_critic(obs, action)
        loss = nn.functional.mse_loss(q1_value_r, target_q_value_r) + nn.functional.mse_loss(
            q2_value_r,
            target_q_value_r,
        )

        if self._cfgs.algo_cfgs.use_critic_norm:
            # Squared-weight penalty on every critic parameter.
            for param in self._actor_critic.reward_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        loss.backward()

        if self._cfgs.algo_cfgs.use_grad_norm:
            clip_grad_norm_(
                self._actor_critic.reward_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.reward_critic_optimizer.step()
        self._logger.store(
            **{
                'Loss/Loss_reward_critic': loss.mean().item(),
                'Value/reward_critic': q1_value_r.mean().item(),
            },
        )

    def _update_cost_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        cost: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update cost critic with a one-step TD target.

        - Get the TD loss of cost critic.
        - Update critic network by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            cost (torch.Tensor): The ``cost`` sampled from buffer.
            done (torch.Tensor): The ``done`` flag sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        with torch.no_grad():
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            next_q_value_c = self._actor_critic.target_cost_critic(next_obs, next_action)[0]
            # NOTE(review): the target discounts with ``gamma`` while the planner is
            # given ``cost_gamma`` - confirm this is intended.
            target_q_value_c = cost + self._cfgs.algo_cfgs.gamma * (1 - done) * next_q_value_c

        q_value_c = self._actor_critic.cost_critic(obs, action)[0]
        loss = nn.functional.mse_loss(q_value_c, target_q_value_c)

        if self._cfgs.algo_cfgs.use_critic_norm:
            # Squared-weight penalty on every critic parameter.
            for param in self._actor_critic.cost_critic.parameters():
                loss += param.pow(2).sum() * self._cfgs.algo_cfgs.critic_norm_coeff

        self._actor_critic.cost_critic_optimizer.zero_grad()
        loss.backward()

        if self._cfgs.algo_cfgs.use_grad_norm:
            clip_grad_norm_(
                self._actor_critic.cost_critic.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.cost_critic_optimizer.step()

        self._logger.store(
            **{
                'Loss/Loss_cost_critic': loss.mean().item(),
                'Value/cost_critic': q_value_c.mean().item(),
            },
        )

    def _update_actor(
        self,
        obs: torch.Tensor,
    ) -> None:
        """Update actor using Soft Actor-Critic algorithm.

        - Get the loss of actor.
        - Update actor by loss.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
        """
        self._actor_critic.actor_optimizer.zero_grad()
        loss = self._loss_pi(obs)
        loss.backward()
        if self._cfgs.algo_cfgs.use_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        self._actor_critic.actor_optimizer.step()
        self._logger.store(
            **{
                'Loss/Loss_pi': loss.mean().item(),
            },
        )
        self._logger.store(
            **{
                'Value/alpha': self._alpha,
            },
        )

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing ``pi/actor`` loss.

        The loss function in SAC is defined as:

        .. math::

            L = -Q^V (s, \pi (s)) + \alpha \log \pi (s)

        where :math:`Q^V` is the min value of two reward critic networks, and :math:`\pi` is the
        policy network, and :math:`\alpha` is the temperature parameter.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor, averaged over the batch.
        """
        action = self._actor_critic.actor.predict(
            obs,
            deterministic=self._cfgs.algo_cfgs.loss_pi_deterministic,
        )
        log_prob = self._actor_critic.actor.log_prob(action)
        q1_value_r, q2_value_r = self._actor_critic.reward_critic(obs, action)
        return (self._alpha * log_prob - torch.min(q1_value_r, q2_value_r)).mean()