Source code for omnisafe.envs.wrapper

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for the environment."""

from __future__ import annotations

from typing import Any

import numpy as np
import torch
from gymnasium import spaces

from omnisafe.common import Normalizer
from omnisafe.envs.core import CMDP, Wrapper


class TimeLimit(Wrapper):
    """Truncate every episode after a fixed number of steps.

    The wrapper keeps an internal step counter that is cleared on
    :meth:`reset` and advanced on :meth:`step`.

    .. warning::
        Only a single (non-vectorized) environment is supported; this is
        asserted in the constructor.

    Examples:
        >>> env = TimeLimit(env, time_limit=100, device=torch.device('cpu'))

    Args:
        env (CMDP): The environment to wrap.
        time_limit (int): The maximum number of steps per episode.
        device (torch.device): The torch device to use.

    Attributes:
        _time_limit (int): The maximum number of steps per episode.
        _time (int): Number of steps taken since the last reset.
    """

    def __init__(self, env: CMDP, time_limit: int, device: torch.device) -> None:
        """Initialize an instance of :class:`TimeLimit`."""
        super().__init__(env=env, device=device)
        assert self.num_envs == 1, 'TimeLimit only supports single environment'
        self._time_limit: int = time_limit
        self._time: int = 0

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset the wrapped environment and the step counter.

        Args:
            seed (int, optional): The random seed. Defaults to None.
            options (dict[str, Any], optional): The options for the environment.
                Defaults to None.

        Returns:
            observation: The initial observation of the space.
            info: Some information logged by the environment.
        """
        # The counter is cleared before delegating, so it is zero even if the
        # wrapped reset raises.
        self._time = 0
        return super().reset(seed=seed, options=options)

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Step the wrapped environment and advance the step counter.

        .. note::
            The ``truncated`` value returned by the wrapped environment is
            discarded; the returned flag only reflects whether the step counter
            has reached the time limit.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the time limit has been reached.
            info: Some information logged by the environment.
        """
        obs, reward, cost, terminated, _, info = super().step(action)
        self._time = self._time + 1
        limit_reached = self._time >= self._time_limit
        truncated = torch.tensor(limit_reached, dtype=torch.bool, device=self._device)
        return obs, reward, cost, terminated, truncated, info
class AutoReset(Wrapper):
    """Reset the environment automatically once an episode is over.

    Examples:
        >>> env = AutoReset(env, device=torch.device('cpu'))

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The torch device to use.
    """

    def __init__(self, env: CMDP, device: torch.device) -> None:
        """Initialize an instance of :class:`AutoReset`."""
        super().__init__(env=env, device=device)
        assert self.num_envs == 1, 'AutoReset only supports single environment'

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Step the environment, resetting it when the episode finishes.

        .. note::
            When ``terminated`` or ``truncated`` is set, :meth:`reset` is
            called. The returned ``obs`` and ``info`` then come from the reset,
            while the last observation and info of the finished episode are
            stored under ``info['final_observation']`` and
            ``info['final_info']``. The reward, cost and flags still belong to
            the finished episode's last step.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        obs, reward, cost, terminated, truncated, info = super().step(action)

        # Episode still running: pass the transition through untouched.
        if not (terminated or truncated):
            return obs, reward, cost, terminated, truncated, info

        last_obs, last_info = obs, info
        obs, info = self.reset()
        # The reset info must not already use the keys reserved for the
        # finished episode's data.
        assert (
            'final_observation' not in info
        ), 'info dict cannot contain key "final_observation" '
        assert 'final_info' not in info, 'info dict cannot contain key "final_info" '
        info['final_observation'] = last_obs
        info['final_info'] = last_info
        return obs, reward, cost, terminated, truncated, info
class ObsNormalize(Wrapper):
    """Normalize the observation.

    Examples:
        >>> env = ObsNormalize(env, device=torch.device('cpu'))
        >>> norm = Normalizer(env.observation_space.shape)  # load saved normalizer
        >>> env = ObsNormalize(env, device=torch.device('cpu'), norm=norm)

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The torch device to use.
        norm (Normalizer or None, optional): The normalizer to use. When ``None``,
            a new ``Normalizer`` with ``clip=5`` is created for the observation
            shape. Defaults to None.
    """

    def __init__(self, env: CMDP, device: torch.device, norm: Normalizer | None = None) -> None:
        """Initialize an instance of :class:`ObsNormalize`."""
        super().__init__(env=env, device=device)
        assert isinstance(self.observation_space, spaces.Box), 'Observation space must be Box'
        self._obs_normalizer: Normalizer

        if norm is not None:
            self._obs_normalizer = norm.to(self._device)
        else:
            self._obs_normalizer = Normalizer(self.observation_space.shape, clip=5).to(self._device)

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Run one timestep of the environment's dynamics using the agent actions.

        .. note::
            The observation and ``info['final_observation']`` (when present) are
            normalized. The un-normalized values are kept in
            ``info['original_obs']`` and ``info['original_final_observation']``.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        obs, reward, cost, terminated, truncated, info = super().step(action)

        if 'final_observation' in info:
            # With several environments only the entries selected by
            # ``info['_final_observation']`` are normalized; otherwise all of it.
            final_obs_slice = info['_final_observation'] if self.num_envs > 1 else slice(None)
            final_obs = info['final_observation'].to(self._device)
            # Snapshot the raw values BEFORE the in-place write below. Without
            # the clone both keys would reference the same tensor (``.to`` may
            # return the tensor itself) and the "original" entry would end up
            # holding normalized data.
            info['original_final_observation'] = final_obs.clone()
            final_obs[final_obs_slice] = self._obs_normalizer.normalize(
                final_obs[final_obs_slice],
            )
            info['final_observation'] = final_obs

        info['original_obs'] = obs
        obs = self._obs_normalizer.normalize(obs)
        return obs, reward, cost, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset the environment and return a normalized initial observation.

        The un-normalized observation is stored in ``info['original_obs']``.

        Args:
            seed (int, optional): The random seed. Defaults to None.
            options (dict[str, Any], optional): The options for the environment.
                Defaults to None.

        Returns:
            observation: The initial observation of the space.
            info: Some information logged by the environment.
        """
        obs, info = super().reset(seed=seed, options=options)
        info['original_obs'] = obs
        obs = self._obs_normalizer.normalize(obs)
        return obs, info

    def save(self) -> dict[str, torch.nn.Module]:
        """Save the observation normalizer.

        .. note::
            The normalizer is added, under the key ``'obs_normalizer'``, to the
            components returned by the parent ``save``. When evaluating a saved
            model, the normalizer should be loaded again.

        Returns:
            The saved components, including the observation normalizer.
        """
        saved = super().save()
        saved['obs_normalizer'] = self._obs_normalizer
        return saved
class RewardNormalize(Wrapper):
    """Normalize the reward returned by the wrapped environment.

    Examples:
        >>> env = RewardNormalize(env, device=torch.device('cpu'))
        >>> norm = Normalizer(())  # load saved normalizer
        >>> env = RewardNormalize(env, device=torch.device('cpu'), norm=norm)

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The torch device to use.
        norm (Normalizer or None, optional): The normalizer to use. When ``None``,
            a new ``Normalizer((), clip=5)`` is created. Defaults to None.
    """

    def __init__(self, env: CMDP, device: torch.device, norm: Normalizer | None = None) -> None:
        """Initialize an instance of :class:`RewardNormalize`."""
        super().__init__(env=env, device=device)
        normalizer = Normalizer((), clip=5) if norm is None else norm
        self._reward_normalizer: Normalizer = normalizer.to(self._device)

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Step the environment and normalize the resulting reward.

        .. note::
            The returned reward is normalized. The reward exactly as produced by
            the wrapped environment is stored in ``info['original_reward']``.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The normalized reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        obs, raw_reward, cost, terminated, truncated, info = super().step(action)
        info['original_reward'] = raw_reward
        normalized_reward = self._reward_normalizer.normalize(raw_reward)
        return obs, normalized_reward, cost, terminated, truncated, info

    def save(self) -> dict[str, torch.nn.Module]:
        """Save the reward normalizer.

        .. note::
            The normalizer is added, under the key ``'reward_normalizer'``, to
            the components returned by the parent ``save``.

        Returns:
            The saved components, including the reward normalizer.
        """
        components = super().save()
        components['reward_normalizer'] = self._reward_normalizer
        return components
class CostNormalize(Wrapper):
    """Normalize the cost returned by the wrapped environment.

    Examples:
        >>> env = CostNormalize(env, device=torch.device('cpu'))
        >>> norm = Normalizer(())  # load saved normalizer
        >>> env = CostNormalize(env, device=torch.device('cpu'), norm=norm)

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The torch device to use.
        norm (Normalizer or None, optional): The normalizer to use. When ``None``,
            a new ``Normalizer((), clip=5)`` is created. Defaults to None.
    """

    def __init__(self, env: CMDP, device: torch.device, norm: Normalizer | None = None) -> None:
        """Initialize an instance of :class:`CostNormalize`."""
        super().__init__(env=env, device=device)
        normalizer = Normalizer((), clip=5) if norm is None else norm
        self._cost_normalizer: Normalizer = normalizer.to(self._device)

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Step the environment and normalize the resulting cost.

        .. note::
            The returned cost is normalized. The cost exactly as produced by the
            wrapped environment is stored in ``info['original_cost']``.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The normalized cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        obs, reward, raw_cost, terminated, truncated, info = super().step(action)
        info['original_cost'] = raw_cost
        normalized_cost = self._cost_normalizer.normalize(raw_cost)
        return obs, reward, normalized_cost, terminated, truncated, info

    def save(self) -> dict[str, torch.nn.Module]:
        """Save the cost normalizer.

        .. note::
            The normalizer is added, under the key ``'cost_normalizer'``, to the
            components returned by the parent ``save``.

        Returns:
            The saved components, including the cost normalizer.
        """
        components = super().save()
        components['cost_normalizer'] = self._cost_normalizer
        return components
class ActionScale(Wrapper):
    """Expose a rescaled action range and map actions back before stepping.

    The constructor records the wrapped environment's original action bounds
    and builds a new ``Box`` action space, with the same shape and dtype, whose
    bounds are ``low`` and ``high`` everywhere.

    Examples:
        >>> env = ActionScale(env, device=torch.device('cpu'), low=-1, high=1)

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The device to use.
        low (int or float): The lower bound of the scaled action space.
        high (int or float): The upper bound of the scaled action space.
    """

    def __init__(
        self,
        env: CMDP,
        device: torch.device,
        low: float,
        high: float,
    ) -> None:
        """Initialize an instance of :class:`ActionScale`."""
        super().__init__(env=env, device=device)
        assert isinstance(self.action_space, spaces.Box), 'Action space must be Box'

        # Bounds of the wrapped environment, kept as float32 tensors on the
        # target device for the affine map in ``step``.
        self._old_min_action: torch.Tensor = torch.tensor(
            self.action_space.low,
            dtype=torch.float32,
            device=self._device,
        )
        self._old_max_action: torch.Tensor = torch.tensor(
            self.action_space.high,
            dtype=torch.float32,
            device=self._device,
        )

        # Constant arrays holding the new bounds, one entry per action element.
        new_lower = np.zeros(self.action_space.shape, dtype=self.action_space.dtype) + low
        new_upper = np.zeros(self.action_space.shape, dtype=self.action_space.dtype) + high

        self._action_space: spaces.Box = spaces.Box(
            low=new_lower,
            high=new_upper,
            shape=self.action_space.shape,
            dtype=self.action_space.dtype,  # type: ignore[arg-type]
        )

        self._min_action: torch.Tensor = torch.tensor(
            new_lower,
            dtype=torch.float32,
            device=self._device,
        )
        self._max_action: torch.Tensor = torch.tensor(
            new_upper,
            dtype=torch.float32,
            device=self._device,
        )

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Map the action to the original range and step the environment.

        .. note::
            The action is mapped linearly from ``[low, high]`` onto the original
            bounds of the wrapped environment. No clipping is applied here.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        original_span = self._old_max_action - self._old_min_action
        scaled_span = self._max_action - self._min_action
        offset = action - self._min_action
        # Multiply first, then divide, matching the original evaluation order.
        rescaled = self._old_min_action + original_span * offset / scaled_span
        return super().step(rescaled)
class ActionRepeat(Wrapper):
    """Repeat each action a given number of times.

    Example:
        >>> env = ActionRepeat(env, times=3, device=torch.device('cpu'))

    Args:
        env (CMDP): The environment to wrap.
        times (int): The number of times to repeat the action. Must be at least 1.
        device (torch.device): The device to use.

    Raises:
        ValueError: If ``times`` is smaller than 1.
    """

    def __init__(
        self,
        env: CMDP,
        times: int,
        device: torch.device,
    ) -> None:
        """Initialize the wrapper.

        Args:
            env: The environment to wrap.
            times: The number of times to repeat the action.
            device: The device to use.

        Raises:
            ValueError: If ``times`` is smaller than 1.
        """
        # With fewer than one repetition ``step`` would never call the wrapped
        # environment and would have nothing to return, so fail early with a
        # clear message.
        if times < 1:
            raise ValueError(f'ActionRepeat requires times >= 1, got {times}')
        super().__init__(env=env, device=device)
        self._times = times
        self._device = device

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Run up to ``self._times`` timesteps of the environment with one action.

        Rewards and costs are summed over the executed steps. Repetition stops
        early as soon as a step reports ``terminated``, ``truncated`` or a truthy
        ``info['goal_met']``. The observation, flags and info come from the last
        executed step, and the number of executed steps is stored in
        ``info['num_step']``.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The summed reward over the executed steps.
            cost: The summed cost over the executed steps.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        # Scalar accumulators, updated in place below.
        rewards = torch.tensor(0.0, device=self._device)
        costs = torch.tensor(0.0, device=self._device)

        for num_step in range(1, self._times + 1):
            obs, reward, cost, terminated, truncated, info = super().step(action)
            rewards += reward
            costs += cost
            goal_met = info.get('goal_met', False)
            if terminated or truncated or goal_met:
                break

        info['num_step'] = num_step
        return obs, rewards, costs, terminated, truncated, info
class Unsqueeze(Wrapper):
    """Add a leading dimension to the outputs of a single environment.

    The observation, reward, cost, terminated and truncated tensors, as well as
    every tensor value stored in ``info``, get a new dimension at position 0.
    Incoming actions are squeezed along dimension 0 before being forwarded.

    Examples:
        >>> env = Unsqueeze(env, device=torch.device('cpu'))

    Args:
        env (CMDP): The environment to wrap.
        device (torch.device): The torch device to use.
    """

    def __init__(self, env: CMDP, device: torch.device) -> None:
        """Initialize an instance of :class:`Unsqueeze`."""
        super().__init__(env=env, device=device)
        assert self.num_envs == 1, 'Unsqueeze only works with single environment'
        assert isinstance(self.observation_space, spaces.Box), 'Observation space must be Box'

    @staticmethod
    def _unsqueeze_tensors(info: dict[str, Any]) -> None:
        """Unsqueeze, in place, every tensor value of ``info`` along dim 0."""
        # Collect the keys first, then overwrite; non-tensor values are left alone.
        tensor_keys = [key for key, value in info.items() if isinstance(value, torch.Tensor)]
        for key in tensor_keys:
            info[key] = info[key].unsqueeze(0)

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, Any],
    ]:
        """Run one timestep of the environment's dynamics using the agent actions.

        .. note::
            The action is squeezed along dimension 0 before stepping, and all
            returned tensors (including tensor values in ``info``) are unsqueezed
            along dimension 0.

        Args:
            action (torch.Tensor): The action from the agent or random.

        Returns:
            observation: The agent's observation of the current environment.
            reward: The amount of reward returned after previous action.
            cost: The amount of cost returned after previous action.
            terminated: Whether the episode has ended.
            truncated: Whether the episode has been truncated due to a time limit.
            info: Some information logged by the environment.
        """
        obs, reward, cost, terminated, truncated, info = super().step(action.squeeze(0))

        obs = obs.unsqueeze(0)
        reward = reward.unsqueeze(0)
        cost = cost.unsqueeze(0)
        terminated = terminated.unsqueeze(0)
        truncated = truncated.unsqueeze(0)
        self._unsqueeze_tensors(info)

        return obs, reward, cost, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset the environment and return a new, unsqueezed observation.

        .. note::
            The observation and every tensor value in ``info`` are unsqueezed
            along dimension 0.

        Args:
            seed (int, optional): The random seed. Defaults to None.
            options (dict[str, Any], optional): The options for the environment.
                Defaults to None.

        Returns:
            observation: The initial observation of the space.
            info: Some information logged by the environment.
        """
        raw_obs, info = super().reset(seed=seed, options=options)
        batched_obs = raw_obs.unsqueeze(0)
        self._unsqueeze_tensors(info)
        return batched_obs, info