# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model-based Adapter for OmniSafe."""
from __future__ import annotations
import time
from typing import Any, Callable
import numpy as np
import torch
from gymnasium.spaces import Box
from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.common.logger import Logger
from omnisafe.envs.core import CMDP, make, support_envs
from omnisafe.envs.wrapper import (
ActionRepeat,
ActionScale,
AutoReset,
CostNormalize,
ObsNormalize,
RewardNormalize,
TimeLimit,
Unsqueeze,
)
from omnisafe.utils.config import Config
from omnisafe.utils.tools import get_device
class ModelBasedAdapter(
    OnlineAdapter,
):  # pylint: disable=too-many-instance-attributes,super-init-not-called
    """Model Based Adapter for OmniSafe.

    :class:`ModelBasedAdapter` adapts the environment to model-based training. It builds and
    wraps the environment, steps it during :meth:`rollout`, and invokes the caller-supplied
    callbacks that store transitions, update the dynamics model, update the actor-critic and
    evaluate the agent.

    Args:
        env_id (str): The environment id.
        num_envs (int): The number of environments.
        seed (int): The random seed.
        cfgs (Config): The configuration.

    Keyword Args:
        render_mode (str, optional): The render mode, one of ``'human'``, ``'rgb_array'`` or
            ``'rgb_array_list'``. Defaults to ``'rgb_array'``.
        camera_name (str, optional): The camera name.
        camera_id (int, optional): The camera id.
        width (int, optional): The width of the rendered image. Defaults to 256.
        height (int, optional): The height of the rendered image. Defaults to 256.

    Attributes:
        coordinate_observation_space (Box or None): The coordinate observation space, or
            ``None`` when the environment does not expose both coordinate and lidar spaces.
        lidar_observation_space (Box or None): The lidar observation space, or ``None`` when
            the environment does not expose both coordinate and lidar spaces.
        task (str or None): The task, e.g. the task of SafetyPointGoal-v0 is ``'goal'``.
            ``None`` when the environment has no ``task`` attribute.
    """

    coordinate_observation_space: Box | None
    lidar_observation_space: Box | None
    task: str | None

    # Per-episode accumulators, (re)created by ``_reset_log``.
    _ep_ret: torch.Tensor
    _ep_cost: torch.Tensor
    _ep_len: torch.Tensor
    # Latest observation returned by the environment; input of the next ``act_func`` call.
    _current_obs: torch.Tensor

    def __init__(  # pylint: disable=too-many-arguments
        self,
        env_id: str,
        num_envs: int,
        seed: int,
        cfgs: Config,
        **env_kwargs: Any,
    ) -> None:
        """Initialize the model-based adapter.

        Args:
            env_id (str): The environment id; must be one of ``support_envs()``.
            num_envs (int): The number of environments.
            seed (int): The random seed passed to the environment.
            cfgs (Config): The configuration.

        Keyword Args:
            **env_kwargs (Any): Extra keyword arguments forwarded to ``make``.
        """
        assert env_id in support_envs(), f'Env {env_id} is not supported.'

        self._env_id: str = env_id
        self._device: torch.device = get_device(cfgs.train_cfgs.device)
        self._env: CMDP = make(
            env_id,
            num_envs=num_envs,
            device=cfgs.train_cfgs.device,
            **env_kwargs,
        )

        # Wrap the environment; the model-based setting always applies action repeat.
        self._wrapper(
            obs_normalize=cfgs.algo_cfgs.obs_normalize,
            reward_normalize=cfgs.algo_cfgs.reward_normalize,
            cost_normalize=cfgs.algo_cfgs.cost_normalize,
            action_repeat=cfgs.algo_cfgs.action_repeat,
        )
        self._env.set_seed(seed)
        self._cfgs: Config = cfgs

        # The two observation spaces are only exposed together: if either one is missing
        # on the environment, both attributes are set to None.
        if hasattr(self._env, 'coordinate_observation_space') and hasattr(
            self._env,
            'lidar_observation_space',
        ):
            self.coordinate_observation_space = self._env.coordinate_observation_space
            self.lidar_observation_space = self._env.lidar_observation_space
        else:
            self.coordinate_observation_space = None
            self.lidar_observation_space = None

        self.task = getattr(self._env, 'task', None)

        self._current_obs, _ = self.reset()
        self._max_ep_len: int = 1000
        self._reset_log()

        # Step counters at which each periodic job last ran (see ``rollout``).
        self._last_dynamics_update: int = 0
        self._last_policy_update: int = 0
        self._last_eval: int = 0
        # Becomes True once ``_log_metrics`` has stored episode metrics at least once.
        self._first_log: bool = False

    def get_cost_from_obs_tensor(self, obs: torch.Tensor) -> torch.Tensor:
        """Get cost from tensor observation.

        Args:
            obs (torch.Tensor): The tensor version of observation.

        Returns:
            The result of the environment's ``get_cost_from_obs_tensor`` when the environment
            provides it, otherwise ``torch.zeros(1)``.
        """
        if hasattr(self._env, 'get_cost_from_obs_tensor'):
            return self._env.get_cost_from_obs_tensor(obs)
        return torch.zeros(1)

    def get_lidar_from_coordinate(self, obs: np.ndarray) -> torch.Tensor | None:
        """Get lidar from numpy coordinate.

        Args:
            obs (np.ndarray): The observation.

        Returns:
            The result of the environment's ``get_lidar_from_coordinate`` when the environment
            provides it, otherwise ``None``.
        """
        if hasattr(self._env, 'get_lidar_from_coordinate'):
            return self._env.get_lidar_from_coordinate(obs)
        return None

    def render(self, *args: str, **kwargs: Any) -> Any:
        """Render the environment.

        All arguments are forwarded unchanged to the wrapped environment's ``render``.

        Args:
            args (str): The arguments.

        Keyword Args:
            render_mode (str, optional): The render mode, one of ``human``, ``rgb_array`` or
                ``rgb_array_list``. Defaults to ``rgb_array``.
            camera_name (str, optional): The camera name.
            camera_id (int, optional): The camera id.
            width (int, optional): The width of the rendered image. Defaults to 256.
            height (int, optional): The height of the rendered image. Defaults to 256.

        Returns:
            Whatever the wrapped environment's ``render`` returns.
        """
        return self._env.render(*args, **kwargs)

    def _wrapper(
        self,
        obs_normalize: bool = True,
        reward_normalize: bool = True,
        cost_normalize: bool = True,
        action_repeat: int = 1,
    ) -> None:
        """Wrap the environment.

        Wrappers are applied in the order of the table below, so the last applicable one is
        the outermost layer of ``self._env``.

        .. hint::
            OmniSafe supports the following wrappers:

        +-----------------+--------------------------------------------------------+
        | Wrapper         | Description                                            |
        +=================+========================================================+
        | TimeLimit       | Limit the time steps of the environment.               |
        +-----------------+--------------------------------------------------------+
        | AutoReset       | Reset the environment when the episode is done.        |
        +-----------------+--------------------------------------------------------+
        | ObsNormalize    | Normalize the observation.                             |
        +-----------------+--------------------------------------------------------+
        | RewardNormalize | Normalize the reward.                                  |
        +-----------------+--------------------------------------------------------+
        | CostNormalize   | Normalize the cost.                                    |
        +-----------------+--------------------------------------------------------+
        | ActionScale     | Scale the action.                                      |
        +-----------------+--------------------------------------------------------+
        | ActionRepeat    | Repeat the action.                                     |
        +-----------------+--------------------------------------------------------+
        | Unsqueeze       | Unsqueeze the step result for single environment case. |
        +-----------------+--------------------------------------------------------+

        Args:
            obs_normalize (bool): Whether to normalize the observation.
            reward_normalize (bool): Whether to normalize the reward.
            cost_normalize (bool): Whether to normalize the cost.
            action_repeat (int): The action repeat times.
        """
        # Optional wrappers requested by the environment itself.
        if self._env.need_time_limit_wrapper:
            self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
        if self._env.need_auto_reset_wrapper:
            self._env = AutoReset(self._env, device=self._device)

        # Optional wrappers requested by the algorithm configuration.
        if obs_normalize:
            self._env = ObsNormalize(self._env, device=self._device)
        if reward_normalize:
            self._env = RewardNormalize(self._env, device=self._device)
        if cost_normalize:
            self._env = CostNormalize(self._env, device=self._device)

        # Always applied: actions are scaled with bounds [-1, 1] and repeated.
        self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
        self._env = ActionRepeat(self._env, times=action_repeat, device=self._device)

        if self._env.num_envs == 1:
            self._env = Unsqueeze(self._env, device=self._device)

    def _is_update_due(self, current_step: int, cycle: int, last_update: int) -> bool:
        """Tell whether a periodic job with period ``cycle`` should run at ``current_step``.

        The job is due when ``current_step`` lies fewer than ``action_repeat`` steps past a
        multiple of ``cycle`` and at least ``cycle`` steps have passed since ``last_update``.
        The remainder is compared with ``action_repeat`` instead of zero, presumably because
        the step counter advances by ``info['num_step']`` per call and can jump over exact
        multiples of ``cycle``.

        Args:
            current_step (int): Current training step.
            cycle (int): Period of the job, in training steps.
            last_update (int): Training step at which the job last ran.

        Returns:
            ``True`` when the job should run now, ``False`` otherwise.
        """
        return (
            current_step % cycle < self._cfgs.algo_cfgs.action_repeat
            and current_step - last_update >= cycle
        )

    def rollout(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        current_step: int,
        rollout_step: int,
        use_actor_critic: bool,
        act_func: Callable[[int, torch.Tensor], torch.Tensor],
        store_data_func: Callable[
            [
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                dict[str, Any],
            ],
            None,
        ],
        update_dynamics_func: Callable[[], None],
        logger: Logger,
        use_eval: bool,
        eval_func: Callable[[int, bool], None],
        algo_reset_func: Callable[[], None] | None,
        update_actor_func: Callable[[int], None],
    ) -> int:
        """Roll out the environment and store the data in the buffer.

        The loop runs until ``rollout_step`` steps were collected in this call or
        ``current_step`` reaches ``cfgs.train_cfgs.total_steps``. After each environment
        step the transition is stored, and the dynamics update, actor-critic update and
        evaluation callbacks are invoked whenever their cycle is due. Wall-clock timings
        are stored in ``logger`` under the ``Time/*`` keys.

        Args:
            current_step (int): Current training step.
            rollout_step (int): Number of steps to roll out.
            use_actor_critic (bool): Whether to use actor-critic.
            act_func (Callable[[int, torch.Tensor], torch.Tensor]): Function to get action.
            store_data_func (Callable[[torch.Tensor, ..., dict[str, Any], ], None,]): Function
                to store data. Called with ``(obs, action, reward, cost, terminated,
                truncated, next_obs, info)``.
            update_dynamics_func (Callable[[], None]): Function to update dynamics.
            logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
            use_eval (bool): Whether to use evaluation.
            eval_func (Callable[[int, bool], None]): Function to evaluate the agent.
            algo_reset_func (Callable[[], None] or None): Function to reset the algorithm at
                the end of an episode; skipped when ``None``.
            update_actor_func (Callable[[int], None]): Function to update the actor.

        Returns:
            The training step reached at the end of this rollout.
        """
        epoch_start_time = time.time()
        update_actor_critic_time = 0.0
        update_dynamics_time = 0.0
        # Always bound so the timing bookkeeping never touches an undefined name; it is only
        # accumulated and logged when ``use_eval`` is true.
        eval_time = 0.0
        epoch_steps = 0

        while epoch_steps < rollout_step and current_step < self._cfgs.train_cfgs.total_steps:
            action = act_func(current_step, self._current_obs)
            # ``step`` and ``reset`` are not defined in this class; they are inherited.
            next_state, reward, cost, terminated, truncated, info = self.step(action)

            # Same lookup as in ``_log_value``: count one step if the key is absent.
            num_step = info.get('num_step', 1)
            epoch_steps += num_step
            current_step += num_step

            self._log_value(reward=reward, cost=cost, info=info)
            store_data_func(
                self._current_obs,
                action,
                reward,
                cost,
                terminated,
                truncated,
                next_state,
                info,
            )
            self._current_obs = next_state

            if terminated or truncated:
                # Episode finished: flush its metrics and start a fresh one.
                self._log_metrics(logger)
                self._reset_log()
                self._current_obs, _ = self.reset()
                if algo_reset_func is not None:
                    algo_reset_func()

            if self._is_update_due(
                current_step,
                self._cfgs.algo_cfgs.update_dynamics_cycle,
                self._last_dynamics_update,
            ):
                update_dynamics_start = time.time()
                update_dynamics_func()
                self._last_dynamics_update = current_step
                update_dynamics_time += time.time() - update_dynamics_start

            # The flag is tested first so the cycle config is only read when it is needed.
            if use_actor_critic and self._is_update_due(
                current_step,
                self._cfgs.algo_cfgs.update_policy_cycle,
                self._last_policy_update,
            ):
                update_actor_critic_start = time.time()
                update_actor_func(current_step)
                self._last_policy_update = current_step
                update_actor_critic_time += time.time() - update_actor_critic_start

            if use_eval and self._is_update_due(
                current_step,
                self._cfgs.evaluation_cfgs.eval_cycle,
                self._last_eval,
            ):
                eval_start = time.time()
                eval_func(current_step, True)
                self._last_eval = current_step
                eval_time += time.time() - eval_start

        # Store metrics at least once, and once more when training reaches its last step.
        if not self._first_log or current_step >= self._cfgs.train_cfgs.total_steps:
            self._log_metrics(logger)

        epoch_time = time.time() - epoch_start_time
        logger.store(**{'Time/Epoch': epoch_time})
        logger.store(**{'Time/UpdateDynamics': update_dynamics_time})

        # Rollout time is the epoch time minus the time spent inside the callbacks.
        rollout_time = epoch_time - update_dynamics_time
        if use_eval:
            logger.store(**{'Time/Eval': eval_time})
            rollout_time -= eval_time
        if use_actor_critic:
            logger.store(**{'Time/UpdateActorCritic': update_actor_critic_time})
            rollout_time -= update_actor_critic_time
        logger.store(**{'Time/Rollout': rollout_time})

        return current_step

    def _log_value(
        self,
        reward: torch.Tensor,
        cost: torch.Tensor,
        info: dict[str, Any],
    ) -> None:
        """Accumulate the episode return, cost and length for one environment step.

        .. note::
            OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost
            will be stored in ``info['original_reward']`` and ``info['original_cost']``.
            When those keys are absent, ``reward`` and ``cost`` are used instead.

        Args:
            reward (torch.Tensor): The immediate step reward.
            cost (torch.Tensor): The immediate step cost.
            info (dict[str, Any]): Some information logged by the environment.
        """
        self._ep_ret += info.get('original_reward', reward).cpu()
        self._ep_cost += info.get('original_cost', cost).cpu()
        # Counts one step when the environment does not report ``num_step``.
        self._ep_len += info.get('num_step', 1)

    def _log_metrics(self, logger: Logger) -> None:
        """Store the current episode accumulators in the logger.

        Also marks that metrics have been stored at least once (``_first_log``).

        Args:
            logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
        """
        self._first_log = True
        logger.store(
            {
                'Metrics/EpRet': self._ep_ret,
                'Metrics/EpCost': self._ep_cost,
                'Metrics/EpLen': self._ep_len,
            },
        )

    def _reset_log(self) -> None:
        """Reset the episode return, cost and length accumulators to zero."""
        self._ep_ret = torch.zeros(1)
        self._ep_cost = torch.zeros(1)
        self._ep_len = torch.zeros(1)