# Copyright 2023 OmniSafe Team. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Implementation of PID Lagrange."""from__future__importannotationsimportabcfromcollectionsimportdeque# pylint: disable-next=too-few-public-methods,too-many-instance-attributes
class PIDLagrangian(abc.ABC):  # noqa: B024
    """PID version of Lagrangian.

    Similar to the :class:`Lagrange` module, this module implements the PID version of the
    lagrangian method.

    .. note::
        The PID-Lagrange is more general than the Lagrange, and can be used in any policy gradient
        algorithm. Because PID-Lagrange uses a PID controller to set the lagrangian multiplier, it
        is more stable than the naive Lagrange.

    Args:
        pid_kp (float): The proportional gain of the PID controller.
        pid_ki (float): The integral gain of the PID controller.
        pid_kd (float): The derivative gain of the PID controller.
        pid_d_delay (int): The delay, in number of :meth:`pid_update` calls, of the reference
            point used by the derivative term. Must be at least 1.
        pid_delta_p_ema_alpha (float): The exponential moving average alpha of the constraint
            violation used by the proportional term.
        pid_delta_d_ema_alpha (float): The exponential moving average alpha of the episode cost
            used by the derivative term.
        sum_norm (bool): Whether to use the sum norm. Within this class its only effect is to
            disable the ``penalty_max`` cap.
        diff_norm (bool): Whether to use the diff norm. When set, both the integral term and the
            lagrangian multiplier are clipped to ``[0, 1]`` (and ``penalty_max`` is not applied).
        penalty_max (int): The maximum penalty. Applied only when neither ``sum_norm`` nor
            ``diff_norm`` is set.
        lagrangian_multiplier_init (float): The initial value of the integral term of the PID
            controller. Note that :attr:`lagrangian_multiplier` reads ``0.0`` until the first
            call to :meth:`pid_update`.
        cost_limit (float): The cost limit.

    Raises:
        ValueError: If ``pid_d_delay`` is smaller than 1.

    References:
        - Title: Responsive Safety in Reinforcement Learning by PID Lagrangian Methods
        - Authors: Adam Stooke, Joshua Achiam, Pieter Abbeel.
        - URL: `PID Lagrange <https://arxiv.org/abs/2007.03964>`_
    """

    # pylint: disable-next=too-many-arguments
    def __init__(
        self,
        pid_kp: float,
        pid_ki: float,
        pid_kd: float,
        pid_d_delay: int,
        pid_delta_p_ema_alpha: float,
        pid_delta_d_ema_alpha: float,
        sum_norm: bool,
        diff_norm: bool,
        penalty_max: int,
        lagrangian_multiplier_init: float,
        cost_limit: float,
    ) -> None:
        """Initialize an instance of :class:`PIDLagrangian`."""
        # With a zero-length buffer the seed below would be dropped immediately and
        # ``pid_update`` would later fail with an opaque IndexError; reject it up front.
        if pid_d_delay < 1:
            raise ValueError(f'pid_d_delay must be at least 1, got {pid_d_delay}.')
        self._pid_kp: float = pid_kp
        self._pid_ki: float = pid_ki
        self._pid_kd: float = pid_kd
        self._pid_d_delay = pid_d_delay
        self._pid_delta_p_ema_alpha: float = pid_delta_p_ema_alpha
        self._pid_delta_d_ema_alpha: float = pid_delta_d_ema_alpha
        self._penalty_max: int = penalty_max
        self._sum_norm: bool = sum_norm
        self._diff_norm: bool = diff_norm
        # Integral term; the "initial multiplier" seeds the integrator, not the output.
        self._pid_i: float = lagrangian_multiplier_init
        # Bounded history of smoothed costs. Its oldest entry is the reference point of the
        # derivative term; the 0.0 seed serves as that reference until enough history exists.
        self._cost_ds: deque[float] = deque(maxlen=self._pid_d_delay)
        self._cost_ds.append(0.0)
        self._delta_p: float = 0.0  # EMA of the constraint violation (proportional term)
        self._cost_d: float = 0.0  # EMA of the episode cost (derivative term)
        self._cost_limit: float = cost_limit
        self._cost_penalty: float = 0.0  # current lagrangian multiplier (controller output)

    @property
    def lagrangian_multiplier(self) -> float:
        """The lagrangian multiplier.

        This is the output of the most recent :meth:`pid_update`, or ``0.0`` if
        :meth:`pid_update` has not been called yet.
        """
        return self._cost_penalty

    def pid_update(self, ep_cost_avg: float) -> None:
        r"""Update the PID controller.

        Each call performs one step of the controller and recomputes the lagrangian multiplier.
        With :math:`J_t` the average episode cost passed in, :math:`d` the cost limit and
        :math:`e_t = J_t - d` the constraint violation:

        .. math::

            I_t &= \max(0, I_{t-1} + K_i e_t) \\
            P_t &= \alpha_p P_{t-1} + (1 - \alpha_p) e_t \\
            D_t &= \alpha_d D_{t-1} + (1 - \alpha_d) J_t \\
            \lambda_t &= \max\left(0,\ K_p P_t + I_t
                + K_d \max\left(0, D_t - D_{\max(t - \tau, 0)}\right)\right)

        where :math:`K_p`, :math:`K_i`, :math:`K_d` are the PID gains, :math:`\alpha_p` and
        :math:`\alpha_d` the EMA coefficients, :math:`\tau` is ``pid_d_delay``, :math:`I_0` is
        ``lagrangian_multiplier_init`` and :math:`P_0 = D_0 = 0`.

        If ``diff_norm`` is set, :math:`I_t` and :math:`\lambda_t` are additionally clipped to
        :math:`[0, 1]`. If neither ``diff_norm`` nor ``sum_norm`` is set, :math:`\lambda_t` is
        capped at ``penalty_max``.

        Args:
            ep_cost_avg (float): The average episode cost :math:`J_t`.
        """
        error = float(ep_cost_avg - self._cost_limit)

        # Integral term: accumulated violation, clamped at zero so a long run of satisfied
        # constraints cannot push it negative.
        self._pid_i = max(0.0, self._pid_i + error * self._pid_ki)
        if self._diff_norm:
            self._pid_i = max(0.0, min(1.0, self._pid_i))

        # Proportional term: exponential moving average of the violation.
        alpha_p = self._pid_delta_p_ema_alpha
        self._delta_p = self._delta_p * alpha_p + (1 - alpha_p) * error

        # Derivative term: rise of the smoothed cost relative to the oldest buffered value.
        # Only increases in cost raise the multiplier; decreases are clipped to zero.
        alpha_d = self._pid_delta_d_ema_alpha
        self._cost_d = self._cost_d * alpha_d + (1 - alpha_d) * float(ep_cost_avg)
        pid_d = max(0.0, self._cost_d - self._cost_ds[0])

        pid_o = self._pid_kp * self._delta_p + self._pid_i + self._pid_kd * pid_d
        self._cost_penalty = max(0.0, pid_o)
        if self._diff_norm:
            self._cost_penalty = min(1.0, self._cost_penalty)
        if not (self._diff_norm or self._sum_norm):
            self._cost_penalty = min(self._cost_penalty, self._penalty_max)

        # Record the smoothed cost; the bounded deque evicts entries older than the delay.
        self._cost_ds.append(self._cost_d)