Source code for omnisafe.adapter.modelbased_adapter

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model-based Adapter for OmniSafe."""


from __future__ import annotations

import time
from typing import Any, Callable

import numpy as np
import torch
from gymnasium.spaces import Box

from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.logger import Logger
from omnisafe.envs.core import CMDP, make, support_envs
from omnisafe.envs.wrapper import (
    ActionRepeat,
    ActionScale,
    AutoReset,
    CostNormalize,
    ObsNormalize,
    RewardNormalize,
    TimeLimit,
    Unsqueeze,
)
from omnisafe.utils.config import Config
from omnisafe.utils.tools import get_device


[docs]class ModelBasedAdapter( OnlineAdapter, ): # pylint: disable=too-many-instance-attributes,super-init-not-called """Model Based Adapter for OmniSafe. :class:`ModelBasedAdapter` is used to adapt the environment to the model-based training. It trains a world model to provide data for algorithms training. Args: env_id (str): The environment id. num_envs (int): The number of environments. seed (int): The random seed. cfgs (Config): The configuration. Keyword Args: render_mode (str, optional): The render mode ranges from 'human' to 'rgb_array' and 'rgb_array_list'. Defaults to 'rgb_array'. camera_name (str, optional): The camera name. camera_id (int, optional): The camera id. width (int, optional): The width of the rendered image. Defaults to 256. height (int, optional): The height of the rendered image. Defaults to 256. Attributes: coordinate_observation_space (OmnisafeSpace): The coordinate observation space. lidar_observation_space (OmnisafeSpace): The lidar observation space. task (str): The task. eg. The task of SafetyPointGoal-v0 is 'goal' """ coordinate_observation_space: Box | None lidar_observation_space: Box | None task: str | None _ep_ret: torch.Tensor _ep_cost: torch.Tensor _ep_len: torch.Tensor _current_obs: torch.Tensor def __init__( # pylint: disable=too-many-arguments self, env_id: str, num_envs: int, seed: int, cfgs: Config, **env_kwargs: Any, ) -> None: """Initialize the model-based adapter.""" assert env_id in support_envs(), f'Env {env_id} is not supported.' self._env_id: str = env_id self._device: torch.device = get_device(cfgs.train_cfgs.device) self._env: CMDP = make( env_id, num_envs=num_envs, device=cfgs.train_cfgs.device, **env_kwargs, ) # wrap the environment, use the action repeat in model-based setting. self._wrapper( obs_normalize=cfgs.algo_cfgs.obs_normalize, reward_normalize=cfgs.algo_cfgs.reward_normalize, cost_normalize=cfgs.algo_cfgs.cost_normalize, action_repeat=cfgs.algo_cfgs.action_repeat, ) self._env.set_seed(seed) self._cfgs: Config = cfgs if hasattr(self._env, 'coordinate_observation_space') and hasattr( self._env, 'lidar_observation_space', ): self.coordinate_observation_space = self._env.coordinate_observation_space self.lidar_observation_space = self._env.lidar_observation_space else: self.coordinate_observation_space = None self.lidar_observation_space = None if hasattr(self._env, 'task'): self.task = self._env.task else: self.task = None self._current_obs, _ = self.reset() self._max_ep_len: int = 1000 self._reset_log() self._last_dynamics_update: int = 0 self._last_policy_update: int = 0 self._last_eval: int = 0 self._first_log: bool = False
[docs] def get_cost_from_obs_tensor(self, obs: torch.Tensor) -> torch.Tensor: """Get cost from tensor observation. Args: obs (torch.Tensor): The tensor version of observation. """ return ( self._env.get_cost_from_obs_tensor(obs) if hasattr(self._env, 'get_cost_from_obs_tensor') else torch.zeros(1) )
[docs] def get_lidar_from_coordinate(self, obs: np.ndarray) -> torch.Tensor | None: """Get lidar from numpy coordinate. Args: obs (np.ndarray): The observation. """ return ( self._env.get_lidar_from_coordinate(obs) if hasattr(self._env, 'get_lidar_from_coordinate') else None )
[docs] def render(self, *args: str, **kwargs: Any) -> Any: """Render the environment. Args: args (str): The arguments. Keyword Args: render_mode (str, optional): The render mode, ranging from ``human``, ``rgb_array``, ``rgb_array_list``. Defaults to ``rgb_array``. camera_name (str, optional): The camera name. camera_id (int, optional): The camera id. width (int, optional): The width of the rendered image. Defaults to 256. height (int, optional): The height of the rendered image. Defaults to 256. """ return self._env.render(*args, **kwargs)
[docs] def _wrapper( self, obs_normalize: bool = True, reward_normalize: bool = True, cost_normalize: bool = True, action_repeat: int = 1, ) -> None: """Wrapper the environment. .. hint:: OmniSafe supports the following wrappers: +-----------------+--------------------------------------------------------+ | Wrapper | Description | +=================+========================================================+ | TimeLimit | Limit the time steps of the environment. | +-----------------+--------------------------------------------------------+ | AutoReset | Reset the environment when the episode is done. | +-----------------+--------------------------------------------------------+ | ObsNormalize | Normalize the observation. | +-----------------+--------------------------------------------------------+ | RewardNormalize | Normalize the reward. | +-----------------+--------------------------------------------------------+ | CostNormalize | Normalize the cost. | +-----------------+--------------------------------------------------------+ | ActionScale | Scale the action. | +-----------------+--------------------------------------------------------+ | ActionRepeat | Repeat the action. | +-----------------+--------------------------------------------------------+ | Unsqueeze | Unsqueeze the step result for single environment case. | +-----------------+--------------------------------------------------------+ Args: obs_normalize (bool): Whether to normalize the observation. reward_normalize (bool): Whether to normalize the reward. cost_normalize (bool): Whether to normalize the cost. action_repeat (int): The action repeat times. """ if self._env.need_time_limit_wrapper: self._env = TimeLimit(self._env, device=self._device, time_limit=1000) if self._env.need_auto_reset_wrapper: self._env = AutoReset(self._env, device=self._device) if obs_normalize: self._env = ObsNormalize(self._env, device=self._device) if reward_normalize: self._env = RewardNormalize(self._env, device=self._device) if cost_normalize: self._env = CostNormalize(self._env, device=self._device) self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0) self._env = ActionRepeat(self._env, times=action_repeat, device=self._device) if self._env.num_envs == 1: self._env = Unsqueeze(self._env, device=self._device)
[docs] def rollout( # pylint: disable=too-many-arguments,too-many-locals self, current_step: int, rollout_step: int, use_actor_critic: bool, act_func: Callable[[int, torch.Tensor], torch.Tensor], store_data_func: Callable[ [ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any], ], None, ], update_dynamics_func: Callable[[], None], logger: Logger, use_eval: bool, eval_func: Callable[[int, bool], None], algo_reset_func: Callable[[], None], update_actor_func: Callable[[int], None], ) -> int: """Roll out the environment and store the data in the buffer. Args: current_step (int): Current training step. rollout_step (int): Number of steps to roll out. use_actor_critic (bool): Whether to use actor-critic. act_func (Callable[[int, torch.Tensor], torch.Tensor]): Function to get action. store_data_func (Callable[[torch.Tensor, ..., dict[str, Any], ], None,]): Function to store data. update_dynamics_func (Callable[[], None]): Function to update dynamics. logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. use_eval (bool): Whether to use evaluation. eval_func (Callable[[int, bool], None]): Function to evaluate the agent. algo_reset_func (Callable[[], None]): Function to reset the algorithm. update_actor_func (Callable[[int], None]): Function to update the actor. """ epoch_start_time = time.time() update_actor_critic_time = 0.0 update_dynamics_time = 0.0 if use_eval: eval_time = 0.0 epoch_steps = 0 while epoch_steps < rollout_step and current_step < self._cfgs.train_cfgs.total_steps: action = act_func(current_step, self._current_obs) next_state, reward, cost, terminated, truncated, info = self.step(action) epoch_steps += info['num_step'] current_step += info['num_step'] self._log_value(reward=reward, cost=cost, info=info) store_data_func( self._current_obs, action, reward, cost, terminated, truncated, next_state, info, ) self._current_obs = next_state if terminated or truncated: self._log_metrics(logger) self._reset_log() self._current_obs, _ = self.reset() if algo_reset_func is not None: algo_reset_func() if ( current_step % self._cfgs.algo_cfgs.update_dynamics_cycle < self._cfgs.algo_cfgs.action_repeat and current_step - self._last_dynamics_update >= self._cfgs.algo_cfgs.update_dynamics_cycle ): update_dynamics_start = time.time() update_dynamics_func() self._last_dynamics_update = current_step update_dynamics_time += time.time() - update_dynamics_start if ( use_actor_critic and current_step % self._cfgs.algo_cfgs.update_policy_cycle < self._cfgs.algo_cfgs.action_repeat and current_step - self._last_policy_update >= self._cfgs.algo_cfgs.update_policy_cycle ): update_actor_critic_start = time.time() update_actor_func(current_step) self._last_policy_update = current_step update_actor_critic_time += time.time() - update_actor_critic_start if ( use_eval and current_step % self._cfgs.evaluation_cfgs.eval_cycle < self._cfgs.algo_cfgs.action_repeat and current_step - self._last_eval >= self._cfgs.evaluation_cfgs.eval_cycle ): eval_start = time.time() eval_func(current_step, True) self._last_eval = current_step eval_time += time.time() - eval_start # pylint: disable=undefined-variable if not self._first_log or current_step >= self._cfgs.train_cfgs.total_steps: self._log_metrics(logger) epoch_time = time.time() - epoch_start_time logger.store(**{'Time/Epoch': epoch_time}) logger.store(**{'Time/UpdateDynamics': update_dynamics_time}) rollout_time = epoch_time - update_dynamics_time if use_eval: logger.store(**{'Time/Eval': eval_time}) rollout_time -= eval_time if use_actor_critic: logger.store(**{'Time/UpdateActorCritic': update_actor_critic_time}) rollout_time -= update_actor_critic_time logger.store(**{'Time/Rollout': rollout_time}) return current_step
[docs] def _log_value( self, reward: torch.Tensor, cost: torch.Tensor, info: dict[str, Any], ) -> None: """Log value. .. note:: OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost will be stored in ``info['original_reward']`` and ``info['original_cost']``. Args: reward (torch.Tensor): The immediate step reward. cost (torch.Tensor): The immediate step cost. info (dict[str, Any]): Some information logged by the environment. """ self._ep_ret += info.get('original_reward', reward).cpu() self._ep_cost += info.get('original_cost', cost).cpu() self._ep_len += info.get('num_step', 1)
[docs] def _log_metrics(self, logger: Logger) -> None: """Log metrics. Args: logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``. """ self._first_log = True logger.store( { 'Metrics/EpRet': self._ep_ret, 'Metrics/EpCost': self._ep_cost, 'Metrics/EpLen': self._ep_len, }, )
[docs] def _reset_log(self) -> None: """Reset log.""" self._ep_ret = torch.zeros(1) self._ep_cost = torch.zeros(1) self._ep_len = torch.zeros(1)