Source code for omnisafe.common.simmer_agent

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Simmer agents."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import deque

import torch

from omnisafe.utils.config import Config


[docs]class BaseSimmerAgent(ABC): """Base class for controlling safety budget of Simmer adapter. Args: cfgs (Config): The configurations for the agent. budget_bound (torch.Tensor): The bound of the safety budget. obs_space (tuple[int,], optional): The observation space. Defaults to (0,). action_space (tuple[int, int], optional): The action space. Defaults to (-1, 1). history_len (int, optional): The length of the history. Defaults to 100. """ def __init__( self, cfgs: Config, budget_bound: torch.Tensor, obs_space: tuple[int,] = (0,), action_space: tuple[int, int] = (-1, 1), history_len: int = 100, ) -> None: """Initialize an instance of :class:`BaseSimmerAgent`.""" assert obs_space is not None, 'Please specify the state space for the Simmer agent' assert history_len > 0, 'History length should be positive' self._history_len: int = history_len self._obs_space: tuple[int,] = obs_space self._action_space: tuple[int, int] = action_space self._budget_bound: torch.Tensor = budget_bound self._cfgs: Config = cfgs # history self._error_history: deque[torch.Tensor] = deque([], maxlen=self._history_len) self._reward_history: deque[torch.Tensor] = deque([], maxlen=self._history_len) self._state_history: deque[torch.Tensor] = deque([], maxlen=self._history_len) self._action_history: deque[torch.Tensor] = deque([], maxlen=self._history_len) self._observation_history: deque[torch.Tensor] = deque([], maxlen=self._history_len)
[docs] @abstractmethod def get_greedy_action( self, safety_budget: torch.Tensor, observation: torch.Tensor, ) -> torch.Tensor: """Get the greedy action. Args: safety_budget (torch.Tensor): The current safety budget. observation (torch.Tensor): The current observation. Raises: NotImplementedError: The method is not implemented. """ raise NotImplementedError
[docs] @abstractmethod def act( self, safety_budget: torch.Tensor, observation: torch.Tensor, ) -> torch.Tensor: """Get the action. Args: safety_budget (torch.Tensor): The current safety budget. observation (torch.Tensor): The current observation. Raises: NotImplementedError: The method is not implemented. """ raise NotImplementedError
# pylint: disable-next=too-many-instance-attributes
[docs]class SimmerPIDAgent(BaseSimmerAgent): """Simmer PID agent. Args: cfgs (Config): The configurations for the agent. budget_bound (torch.Tensor): The bound of the safety budget. obs_space (tuple[int,], optional): The observation space. Defaults to (0,). action_space (tuple[int, int], optional): The action space. Defaults to (-1, 1). history_len (int, optional): The length of the history. Defaults to 100. """ def __init__( self, cfgs: Config, budget_bound: torch.Tensor, obs_space: tuple[int,] = (0,), action_space: tuple[int, int] = (-1, 1), history_len: int = 100, ) -> None: """Initialize an instance of :class:`SimmerPIDAgent`.""" super().__init__( cfgs=cfgs, budget_bound=budget_bound, obs_space=obs_space, action_space=action_space, history_len=history_len, ) self._sum_history: torch.Tensor = torch.zeros(1) self._prev_action: torch.Tensor = torch.zeros(1) self._prev_error: torch.Tensor = torch.zeros(1) self._prev_raw_action: torch.Tensor = torch.zeros(1) self._integral_history: deque[torch.Tensor] = deque([], maxlen=10)
[docs] def get_greedy_action( self, safety_budget: torch.Tensor, observation: torch.Tensor, ) -> torch.Tensor: """Get the greedy action. Args: safety_budget (torch.Tensor): The current safety budget. observation (torch.Tensor): The current observation. Returns: The greedy action. """ # compute the error current_error = safety_budget - observation # blur the error blured_error = ( self._cfgs.polyak * self._prev_error + (1 - self._cfgs.polyak) * current_error ) # log the history self._error_history.append(blured_error) # compute the integral self._integral_history.append(blured_error) self._sum_history = torch.as_tensor(sum(self._integral_history)) # proportional part p_part = self._cfgs.kp * blured_error # integral part i_part = self._cfgs.ki * self._sum_history # derivative part d_part = self._cfgs.kd * (self._prev_action - self._prev_raw_action) # get the raw action raw_action = p_part + i_part + d_part # clip the action action = torch.clamp(raw_action, min=self._action_space[0], max=self._action_space[1]) # get the next safety budget eps = 1e-6 next_safety_budget = torch.clamp( safety_budget + action, eps * torch.ones_like(safety_budget), self._budget_bound, ) # update the true action after clipping action = next_safety_budget - safety_budget # update the history self._prev_action, self._prev_raw_action, self._prev_error = ( action, raw_action, blured_error, ) return next_safety_budget
[docs] def act( self, safety_budget: torch.Tensor, observation: torch.Tensor, ) -> torch.Tensor: """Get the action. Args: safety_budget (torch.Tensor): The current safety budget. observation (torch.Tensor): The current observation. Returns: The selected action. """ return self.get_greedy_action(safety_budget, observation)