# Source code for omnisafe.common.buffer.onpolicy_buffer

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of OnPolicyBuffer."""

from __future__ import annotations

import torch

from omnisafe.common.buffer.base import BaseBuffer
from omnisafe.typing import DEVICE_CPU, AdvatageEstimator, OmnisafeSpace
from omnisafe.utils import distributed
from omnisafe.utils.math import discount_cumsum


class OnPolicyBuffer(BaseBuffer):  # pylint: disable=too-many-instance-attributes
    """A buffer for storing trajectories experienced by an agent interacting with the environment.

    Besides, the buffer also provides the functionality of calculating the advantages of
    state-action pairs, ranging from ``GAE``, ``GAE-RTG``, ``V-trace`` to ``Plain`` method.

    .. warning::
        The buffer only supports Box spaces.

    Compared to the base buffer, the on-policy buffer stores extra data:

    +----------------+---------+---------------+----------------------------------------+
    | Name           | Shape   | Dtype         | Description                            |
    +================+=========+===============+========================================+
    | discounted_ret | (size,) | torch.float32 | The discounted sum of return.          |
    +----------------+---------+---------------+----------------------------------------+
    | value_r        | (size,) | torch.float32 | The value estimated by reward critic.  |
    +----------------+---------+---------------+----------------------------------------+
    | value_c        | (size,) | torch.float32 | The value estimated by cost critic.    |
    +----------------+---------+---------------+----------------------------------------+
    | adv_r          | (size,) | torch.float32 | The advantage of the reward.           |
    +----------------+---------+---------------+----------------------------------------+
    | adv_c          | (size,) | torch.float32 | The advantage of the cost.             |
    +----------------+---------+---------------+----------------------------------------+
    | target_value_r | (size,) | torch.float32 | The target value of the reward critic. |
    +----------------+---------+---------------+----------------------------------------+
    | target_value_c | (size,) | torch.float32 | The target value of the cost critic.   |
    +----------------+---------+---------------+----------------------------------------+
    | logp           | (size,) | torch.float32 | The log probability of the action.     |
    +----------------+---------+---------------+----------------------------------------+

    Args:
        obs_space (OmnisafeSpace): The observation space.
        act_space (OmnisafeSpace): The action space.
        size (int): The size of the buffer.
        gamma (float): The discount factor.
        lam (float): The lambda factor for calculating the advantages of the reward.
        lam_c (float): The lambda factor for calculating the advantages of the cost.
        advantage_estimator (AdvatageEstimator): The advantage estimator, one of ``'gae'``,
            ``'gae-rtg'``, ``'vtrace'`` or ``'plain'``.
        penalty_coefficient (float, optional): The penalty coefficient; the cost scaled by it
            is subtracted from the reward before the reward advantage is computed.
            Defaults to 0.
        standardized_adv_r (bool, optional): Whether to standardize the advantages of the
            reward. Defaults to False.
        standardized_adv_c (bool, optional): Whether to center the advantages of the cost.
            Defaults to False.
        device (torch.device, optional): The device to store the data. Defaults to
            ``torch.device('cpu')``.

    Attributes:
        ptr (int): The pointer of the buffer, i.e. the next slot to write.
        path_start_idx (int): The start index of the current path.
        max_size (int): The maximum size of the buffer.
        data (dict): The data stored in the buffer; created by the base class and extended
            here with the fields listed in the table above.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        size: int,
        gamma: float,
        lam: float,
        lam_c: float,
        advantage_estimator: AdvatageEstimator,
        penalty_coefficient: float = 0,
        standardized_adv_r: bool = False,
        standardized_adv_c: bool = False,
        device: torch.device = DEVICE_CPU,
    ) -> None:
        """Initialize an instance of :class:`OnPolicyBuffer`."""
        super().__init__(obs_space, act_space, size, device)

        self._standardized_adv_r: bool = standardized_adv_r
        self._standardized_adv_c: bool = standardized_adv_c

        # One independent float32 tensor per extra field; the insertion order is kept as-is.
        for key in (
            'adv_r',
            'discounted_ret',
            'value_r',
            'target_value_r',
            'adv_c',
            'value_c',
            'target_value_c',
            'logp',
        ):
            self.data[key] = torch.zeros((size,), dtype=torch.float32, device=device)

        self._gamma: float = gamma
        self._lam: float = lam
        self._lam_c: float = lam_c
        self._penalty_coefficient: float = penalty_coefficient
        self._advantage_estimator: AdvatageEstimator = advantage_estimator
        self.ptr: int = 0
        self.path_start_idx: int = 0
        self.max_size: int = size

        assert self._penalty_coefficient >= 0, 'penalty_coefficient must be non-negative!'
        assert self._advantage_estimator in ['gae', 'gae-rtg', 'vtrace', 'plain']

    @property
    def standardized_adv_r(self) -> bool:
        """Whether to standardize the advantages of the reward."""
        return self._standardized_adv_r

    @property
    def standardized_adv_c(self) -> bool:
        """Whether to center the advantages of the cost."""
        return self._standardized_adv_c

    def store(self, **data: torch.Tensor) -> None:
        """Store one step of data into the buffer at the current pointer.

        .. warning::
            The total size of the data must be less than the buffer size.

        Args:
            data (torch.Tensor): The data to store, keyed by the name of the buffer field.

        Raises:
            AssertionError: If the buffer is already full.
        """
        assert self.ptr < self.max_size, 'No more space in the buffer!'
        for key, value in data.items():
            self.data[key][self.ptr] = value
        self.ptr += 1

    def finish_path(
        self,
        last_value_r: torch.Tensor | None = None,
        last_value_c: torch.Tensor | None = None,
    ) -> None:
        """Finish the current path and calculate the advantages of state-action pairs.

        On-policy algorithms need to calculate the advantages of state-action pairs
        after the path is finished. This function calculates the advantages of
        state-action pairs and stores them in the buffer, following the steps:

        .. hint::
            #. Calculate the discounted return.
            #. Calculate the advantages of the reward.
            #. Calculate the advantages of the cost.

        Args:
            last_value_r (torch.Tensor, optional): The reward value of the last state of the
                current path, used as bootstrap. A 0-dim scalar or a one-element tensor is
                accepted. Defaults to torch.zeros(1).
            last_value_c (torch.Tensor, optional): The cost value of the last state of the
                current path, used as bootstrap. A 0-dim scalar or a one-element tensor is
                accepted. Defaults to torch.zeros(1).
        """
        if last_value_r is None:
            last_value_r = torch.zeros(1, device=self._device)
        if last_value_c is None:
            last_value_c = torch.zeros(1, device=self._device)

        path_slice = slice(self.path_start_idx, self.ptr)

        # ``torch.cat`` rejects 0-dim tensors, so flatten the bootstrap values first;
        # one-element 1-D inputs pass through unchanged.
        last_value_r = last_value_r.to(self._device).reshape(-1)
        last_value_c = last_value_c.to(self._device).reshape(-1)

        # Each sequence carries the bootstrap value as its final entry.
        rewards = torch.cat([self.data['reward'][path_slice], last_value_r])
        values_r = torch.cat([self.data['value_r'][path_slice], last_value_r])
        costs = torch.cat([self.data['cost'][path_slice], last_value_c])
        values_c = torch.cat([self.data['value_c'][path_slice], last_value_c])

        # The discounted return is computed from the raw rewards, before the penalty.
        discounted_ret = discount_cumsum(rewards, self._gamma)[:-1]
        self.data['discounted_ret'][path_slice] = discounted_ret

        # ``rewards`` is a fresh tensor from ``torch.cat``, so the in-place update
        # does not touch the stored rewards.
        rewards -= self._penalty_coefficient * costs

        adv_r, target_value_r = self._calculate_adv_and_value_targets(
            values_r,
            rewards,
            lam=self._lam,
        )
        adv_c, target_value_c = self._calculate_adv_and_value_targets(
            values_c,
            costs,
            lam=self._lam_c,
        )

        self.data['adv_r'][path_slice] = adv_r
        self.data['target_value_r'][path_slice] = target_value_r
        self.data['adv_c'][path_slice] = adv_c
        self.data['target_value_c'][path_slice] = target_value_c

        self.path_start_idx = self.ptr

    def get(self) -> dict[str, torch.Tensor]:
        """Get the data in the buffer and reset the write pointers.

        .. hint::
            We provide a trick to standardize the advantages of state-action pairs. We
            calculate the mean and standard deviation of the advantages of state-action
            pairs and then standardize the advantages of state-action pairs. You can turn
            on this trick by setting the ``standardized_adv_r`` to ``True``. For the
            advantages of the cost, setting ``standardized_adv_c`` to ``True`` only
            subtracts the mean.

        Returns:
            The data stored and calculated in the buffer.
        """
        self.ptr, self.path_start_idx = 0, 0

        data = {
            'obs': self.data['obs'],
            'act': self.data['act'],
            'target_value_r': self.data['target_value_r'],
            'adv_r': self.data['adv_r'],
            'logp': self.data['logp'],
            'discounted_ret': self.data['discounted_ret'],
            'adv_c': self.data['adv_c'],
            'target_value_c': self.data['target_value_c'],
        }

        adv_mean, adv_std, *_ = distributed.dist_statistics_scalar(data['adv_r'])
        cadv_mean, *_ = distributed.dist_statistics_scalar(data['adv_c'])

        if self._standardized_adv_r:
            # The small epsilon guards against a zero standard deviation.
            data['adv_r'] = (data['adv_r'] - adv_mean) / (adv_std + 1e-8)
        if self._standardized_adv_c:
            # The cost advantage is centered only, not rescaled.
            data['adv_c'] = data['adv_c'] - cadv_mean

        return data

    def _calculate_adv_and_value_targets(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        lam: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Compute the estimated advantage and the value targets.

        Four estimators are supported, selected by ``advantage_estimator``:

        - GAE (Generalized Advantage Estimation)

          GAE is a variance reduction method for the actor-critic algorithm.
          It is proposed in the paper `High-Dimensional Continuous Control Using Generalized
          Advantage Estimation <https://arxiv.org/abs/1506.02438>`_.

          GAE calculates the advantage using the following formula:

          .. math::

              A_t = \sum_{k=0}^{n-1} (\lambda \gamma)^k \delta_{t+k}

          where :math:`\delta_{t+k} = r_{t+k} + \gamma*V(s_{t+k+1}) - V(s_{t+k})`.
          When :math:`\lambda =1`, GAE reduces to the Monte Carlo method,
          which is unbiased but has high variance.
          When :math:`\lambda =0`, GAE reduces to the one-step TD residual (TD(0)),
          which is biased but has low variance.
          The value target is the advantage plus the value estimate.

        - GAE-RTG

          The advantage is the same as GAE, while the value target is the discounted
          rewards-to-go.

        - V-trace

          V-trace is a variance reduction method for the actor-critic algorithm.
          It is proposed in the paper
          `IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
          Architectures <https://arxiv.org/abs/1802.01561>`_.

          V-trace calculates the advantage using the following formula:

          .. math::

              A_t = \sum_{k=0}^{n-1} (\lambda \gamma)^k \delta_{t+k} +
              (\lambda \gamma)^n \rho_{t+n} (1 - d_{t+n}) (V(x_{t+n}) - b_{t+n})

          where :math:`\delta_{t+k} = r_{t+k} + \gamma*V(s_{t+k+1}) - V(s_{t+k})`,
          :math:`\rho_{t+k} = \frac{\pi(a_{t+k}|s_{t+k})}{b_{t+k}}`,
          :math:`b_{t+k}` is the behavior policy,
          and :math:`d_{t+k}` is the done flag.
          In this buffer the stored ``logp`` is used for both the target and the behavior
          policy, and ``lam`` is not used by this estimator.

        - Plain

          Plain method is the original actor-critic algorithm.
          It is unbiased but has high variance.

        Args:
            values (torch.Tensor): The value of states, with the bootstrap value appended.
            rewards (torch.Tensor): The reward of states, with the bootstrap value appended.
            lam (float): The lambda parameter in GAE formula.

        Returns:
            adv (torch.Tensor): The estimated advantage.
            target_value (torch.Tensor): The target value for the value function.

        Raises:
            NotImplementedError: If the advantage estimator is not supported.
        """
        # pylint: disable=line-too-long
        if self._advantage_estimator == 'gae':
            # GAE formula: A_t = \sum_{k=0}^{n-1} (lam*gamma)^k delta_{t+k}
            deltas = rewards[:-1] + self._gamma * values[1:] - values[:-1]
            adv = discount_cumsum(deltas, self._gamma * lam)
            target_value = adv + values[:-1]

        elif self._advantage_estimator == 'gae-rtg':
            # GAE formula: A_t = \sum_{k=0}^{n-1} (lam*gamma)^k delta_{t+k}
            deltas = rewards[:-1] + self._gamma * values[1:] - values[:-1]
            adv = discount_cumsum(deltas, self._gamma * lam)
            # compute rewards-to-go, to be targets for the value function update
            target_value = discount_cumsum(rewards, self._gamma)[:-1]

        elif self._advantage_estimator == 'vtrace':
            # v_s = V(x_s) + \sum^{T-1}_{t=s} \gamma^{t-s}
            #                * \prod_{i=s}^{t-1} c_i
            #                * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
            path_slice = slice(self.path_start_idx, self.ptr)
            action_probs = self.data['logp'][path_slice].exp()
            target_value, adv, _ = self._calculate_v_trace(
                policy_action_probs=action_probs,
                values=values,
                rewards=rewards,
                behavior_action_probs=action_probs,
                gamma=self._gamma,
                rho_bar=1.0,
                c_bar=1.0,
            )

        elif self._advantage_estimator == 'plain':
            # A(x, u) = Q(x, u) - V(x) = r(x, u) + gamma V(x+1) - V(x)
            adv = rewards[:-1] + self._gamma * values[1:] - values[:-1]
            target_value = discount_cumsum(rewards, self._gamma)[:-1]

        else:
            raise NotImplementedError

        return adv, target_value

    @staticmethod
    # pylint: disable-next=too-many-arguments,too-many-locals
    def _calculate_v_trace(
        policy_action_probs: torch.Tensor,
        values: torch.Tensor,  # including bootstrap
        rewards: torch.Tensor,  # including bootstrap
        behavior_action_probs: torch.Tensor,
        gamma: float = 0.99,
        rho_bar: float = 1.0,
        c_bar: float = 1.0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Calculate V-trace targets and the corresponding policy advantages.

        The targets are computed backwards through the sequence with the recursion

        .. math::

            v_s = V(x_s) + \bar{\rho}_s \delta_s
            + \gamma \bar{c}_s (v_{s+1} - V(x_{s+1}))

        where :math:`\delta_s = r_s + \gamma V(x_{s+1}) - V(x_s)`,
        :math:`\bar{\rho}_s = \min(\bar{\rho}, \rho_s)`,
        :math:`\bar{c}_s = \min(\bar{c}, \rho_s)` and :math:`\rho_s` is the ratio of
        the policy probability to the behavior probability of the taken action.

        For more details, please refer to the paper:
        `Espeholt et al. 2018, IMPALA <https://arxiv.org/abs/1802.01561>`_.

        Args:
            policy_action_probs (torch.Tensor): Action probabilities of the policy.
            values (torch.Tensor): The value of states, with the bootstrap value appended.
            rewards (torch.Tensor): The reward of states, with the bootstrap value appended.
            behavior_action_probs (torch.Tensor): Action probabilities of the behavior policy.
            gamma (float, optional): The discount factor. Defaults to 0.99.
            rho_bar (float, optional): The maximum value of importance weights. Defaults to 1.0.
            c_bar (float, optional): The maximum value of clipped importance weights.
                Defaults to 1.0.

        Returns:
            v_s (torch.Tensor): The V-trace targets, one per step (bootstrap excluded).
            policy_advantage (torch.Tensor): The advantages computed from the V-trace targets.
            clip_rhos (torch.Tensor): The importance weights clipped at ``rho_bar``.

        Raises:
            AssertionError: If the input tensors are scalars.
            AssertionError: If c_bar is greater than rho_bar.
        """
        assert values.ndim == 1, 'Please provide arrays instead of scalars'
        assert rewards.ndim == 1, 'Please provide arrays instead of scalars'
        assert policy_action_probs.ndim == 1, 'Please provide arrays instead of scalars'
        assert behavior_action_probs.ndim == 1, 'Please provide arrays instead of scalars'
        assert c_bar <= rho_bar, 'c_bar should be less than or equal to rho_bar'

        sequence_length = policy_action_probs.shape[0]

        # Importance weights, truncated from above at rho_bar and c_bar respectively.
        rhos = torch.div(policy_action_probs, behavior_action_probs)
        clip_rhos = rhos.clamp(max=rho_bar)
        clip_cs = rhos.clamp(max=c_bar)

        v_s = values[:-1].clone()  # copy all values except bootstrap value
        last_v_s = values[-1]  # bootstrap from last state

        # calculate v_s, walking backwards so that v_{s+1} is available as ``last_v_s``
        for index in reversed(range(sequence_length)):
            delta = clip_rhos[index] * (rewards[index] + gamma * values[index + 1] - values[index])
            v_s[index] += delta + gamma * clip_cs[index] * (last_v_s - values[index + 1])
            last_v_s = v_s[index]  # accumulate current v_s for next iteration

        # calculate q_targets
        v_s_plus_1 = torch.cat((v_s[1:], values[-1:]))
        policy_advantage = clip_rhos * (rewards[:-1] + gamma * v_s_plus_1 - values[:-1])

        return v_s, policy_advantage, clip_rhos