Source code for omnisafe.algorithms.model_based.base.pets

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Probabilistic Ensembles with Trajectory Sampling algorithm."""


from __future__ import annotations

import os
import time
from typing import Any, Callable

import numpy as np
import torch
from gymnasium.spaces import Box
from gymnasium.utils.save_video import save_video
from matplotlib import pylab

from omnisafe.adapter import ModelBasedAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.algorithms.model_based.base.ensemble import EnsembleDynamicsModel
from omnisafe.algorithms.model_based.planner.cem import CEMPlanner
from omnisafe.common.buffer import OffPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.typing import OmnisafeSpace


[docs]@registry.register # pylint: disable-next=too-many-instance-attributes, too-few-public-methods class PETS(BaseAlgo): """The Probabilistic Ensembles with Trajectory Sampling (PETS) algorithm. References: - Title: Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models - Authors: Kurtland Chua, Roberto Calandra, Rowan McAllister, Sergey Levine. - URL: `PETS <https://arxiv.org/abs/1805.12114>`_ """
[docs] def _init_env(self) -> None: self._env: ModelBasedAdapter = ModelBasedAdapter( self._env_id, 1, self._seed, self._cfgs, ) self._total_steps: int = int(self._cfgs.train_cfgs.total_steps) self._steps_per_epoch: int = int(self._cfgs.algo_cfgs.steps_per_epoch) self._epochs: int = self._total_steps // self._cfgs.algo_cfgs.steps_per_epoch
[docs] def _init_model(self) -> None: """Initialize dynamics model and planner.""" self._dynamics_state_space: OmnisafeSpace = ( self._env.coordinate_observation_space if self._env.coordinate_observation_space is not None else self._env.observation_space ) assert self._dynamics_state_space is not None and isinstance( self._dynamics_state_space.shape, tuple, ) assert self._env.action_space is not None and isinstance( self._env.action_space.shape, tuple, ) if isinstance(self._env.action_space, Box): self._action_space = self._env.action_space else: raise NotImplementedError self._dynamics: EnsembleDynamicsModel = EnsembleDynamicsModel( model_cfgs=self._cfgs.dynamics_cfgs, device=self._device, state_shape=self._dynamics_state_space.shape, action_shape=self._action_space.shape, actor_critic=None, rew_func=None, cost_func=None, terminal_func=None, ) self._planner: CEMPlanner = CEMPlanner( dynamics=self._dynamics, planner_cfgs=self._cfgs.planner_cfgs, gamma=float(self._cfgs.algo_cfgs.gamma), cost_gamma=float(self._cfgs.algo_cfgs.cost_gamma), dynamics_state_shape=self._dynamics_state_space.shape, action_shape=self._action_space.shape, action_max=1.0, action_min=-1.0, device=self._device, ) self._use_actor_critic: bool = False self._update_dynamics_cycle: int = int(self._cfgs.algo_cfgs.update_dynamics_cycle)
[docs] def _init(self) -> None: """Initialize the algorithm.""" self._dynamics_buf: OffPolicyBuffer = OffPolicyBuffer( obs_space=self._dynamics_state_space, act_space=self._env.action_space, size=self._cfgs.train_cfgs.total_steps, batch_size=self._cfgs.dynamics_cfgs.batch_size, device=self._device, ) env_kwargs: dict[str, Any] = { 'render_mode': 'rgb_array', 'camera_name': 'track', } self._eval_env: ModelBasedAdapter = ModelBasedAdapter( self._env_id, 1, self._seed, self._cfgs, **env_kwargs, ) self._eval_fn: Callable[[int, bool], None] = self._evaluation_single_step
[docs] def _init_log(self) -> None: """Initialize logger. +---------------------------+-------------------------------------------------+ | Things to log | Description | +===========================+=================================================+ | Train/Epoch | Current epoch. | +---------------------------+-------------------------------------------------+ | TotalEnvSteps | Total steps of the experiment. | +---------------------------+-------------------------------------------------+ | Metrics/EpRet | Average return of the epoch. | +---------------------------+-------------------------------------------------+ | Metrics/EpCost | Average cost of the epoch. | +---------------------------+-------------------------------------------------+ | Metrics/EpLen | Average length of the epoch. | +---------------------------+-------------------------------------------------+ | EvalMetrics/EpRet | Average episode return in evaluation. | +---------------------------+-------------------------------------------------+ | EvalMetrics/EpCost | Average episode cost in evaluation. | +---------------------------+-------------------------------------------------+ | EvalMetrics/EpLen | Average episode length in evaluation. | +---------------------------+-------------------------------------------------+ | Loss/DynamicsTrainMseLoss | The training loss of dynamics model. | +---------------------------+-------------------------------------------------+ | Loss/DynamicsValMseLoss | The validation loss of dynamics model. | +---------------------------+-------------------------------------------------+ | Plan/iter | The number of iterations in the planner. | +---------------------------+-------------------------------------------------+ | Plan/last_var_mean | The mean of the last variance in the planner. | +---------------------------+-------------------------------------------------+ | Plan/last_var_max | The max of the last variance in the planner. | +---------------------------+-------------------------------------------------+ | Plan/last_var_min | The min of the last variance in the planner. | +---------------------------+-------------------------------------------------+ | Plan/episode_returns_max | The max of the episode returns in the planner. | +---------------------------+-------------------------------------------------+ | Plan/episode_returns_mean | The mean of the episode returns in the planner. | +---------------------------+-------------------------------------------------+ | Plan/episode_returns_min | The min of the episode returns in the planner. | +---------------------------+-------------------------------------------------+ | Time/Total | The total time of the algorithm. | +---------------------------+-------------------------------------------------+ | Time/Rollout | The time of the rollout. | +---------------------------+-------------------------------------------------+ | Time/UpdateActorCritic | The time of the actor-critic update. | +---------------------------+-------------------------------------------------+ | Time/Eval | The time of the evaluation. | +---------------------------+-------------------------------------------------+ | Time/Epoch | The time of the epoch. | +---------------------------+-------------------------------------------------+ | Time/FPS | The FPS of the algorithm. | +---------------------------+-------------------------------------------------+ | Time/UpdateDynamics | The time of the dynamics update. | +---------------------------+-------------------------------------------------+ """ self._logger = Logger( output_dir=self._cfgs.logger_cfgs.log_dir, exp_name=self._cfgs.exp_name, seed=self._cfgs.seed, use_tensorboard=self._cfgs.logger_cfgs.use_tensorboard, use_wandb=self._cfgs.logger_cfgs.use_wandb, config=self._cfgs, ) self._logger.register_key('Train/Epoch') self._logger.register_key('TotalEnvSteps') self._logger.register_key('Metrics/EpRet', window_length=50) self._logger.register_key('Metrics/EpCost', window_length=50) self._logger.register_key('Metrics/EpLen', window_length=50) if self._cfgs.evaluation_cfgs.use_eval: self._logger.register_key('EvalMetrics/EpRet', window_length=5) self._logger.register_key('EvalMetrics/EpCost', window_length=5) self._logger.register_key('EvalMetrics/EpLen', window_length=5) self._logger.register_key('Loss/DynamicsTrainMseLoss') self._logger.register_key('Loss/DynamicsValMseLoss') self._logger.register_key('Plan/iter') self._logger.register_key('Plan/last_var_mean') self._logger.register_key('Plan/last_var_max') self._logger.register_key('Plan/last_var_min') self._logger.register_key('Plan/episode_returns_max') self._logger.register_key('Plan/episode_returns_mean') self._logger.register_key('Plan/episode_returns_min') self._logger.register_key('Time/Total') self._logger.register_key('Time/Rollout') self._logger.register_key('Time/UpdateDynamics') if self._use_actor_critic: self._logger.register_key('Time/UpdateActorCritic') if self._cfgs.evaluation_cfgs.use_eval: self._logger.register_key('Time/Eval') self._logger.register_key('Time/Epoch') self._logger.register_key('Time/FPS') self._save_model()
[docs] def _save_model(self) -> None: """Save the model.""" # set up model saving what_to_save: dict[str, Any] = { 'dynamics': self._dynamics.ensemble_model, } if self._cfgs.algo_cfgs.obs_normalize: obs_normalizer = self._env.save()['obs_normalizer'] what_to_save['obs_normalizer'] = obs_normalizer self._logger.setup_torch_saver(what_to_save) self._logger.torch_save()
[docs] def learn(self) -> tuple[float, float, float]: """This is main function for algorithm update. It is divided into the following steps: - :meth:`rollout`: collect interactive data from environment. - :meth:`update`: perform actor/critic updates. - :meth:`log`: epoch/update information for visualization and terminal log print. Returns: ep_ret: Average episode return in final epoch. ep_cost: Average episode cost in final epoch. ep_len: Average episode length in final epoch. """ self._logger.log('INFO: Start training') start_time = time.time() current_step = 0 for epoch in range(self._epochs): current_step = self._env.rollout( current_step=current_step, rollout_step=self._steps_per_epoch, use_actor_critic=self._use_actor_critic, act_func=self._select_action, store_data_func=self._store_real_data, update_dynamics_func=self._update_dynamics_model, use_eval=self._cfgs.evaluation_cfgs.use_eval, eval_func=self._eval_fn, logger=self._logger, algo_reset_func=self._algo_reset, update_actor_func=self._update_policy, ) if current_step > self._cfgs.algo_cfgs.start_learning_steps: # update something per epoch # e.g. update lagrange multiplier self._update_epoch() # evaluate episode self._logger.store( { 'Train/Epoch': epoch, 'TotalEnvSteps': current_step, 'Time/Total': time.time() - start_time, }, ) self._logger.dump_tabular() # save model to disk if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0: self._logger.torch_save() ep_ret = self._logger.get_stats('Metrics/EpRet')[0] ep_cost = self._logger.get_stats('Metrics/EpCost')[0] ep_len = self._logger.get_stats('Metrics/EpLen')[0] self._logger.close() return ep_ret, ep_cost, ep_len
[docs] def _algo_reset( self, ) -> None: """Reset the algorithm."""
[docs] def _update_policy( self, current_step: int, # pylint: disable=unused-argument ) -> None: """Update policy."""
[docs] def _update_dynamics_model( self, ) -> None: """Update dynamics model.""" state = self._dynamics_buf.data['obs'][: self._dynamics_buf.size, :] action = self._dynamics_buf.data['act'][: self._dynamics_buf.size, :] reward = self._dynamics_buf.data['reward'][: self._dynamics_buf.size] cost = self._dynamics_buf.data['cost'][: self._dynamics_buf.size] next_state = self._dynamics_buf.data['next_obs'][: self._dynamics_buf.size, :] delta_state = next_state - state assert isinstance(delta_state, torch.Tensor), 'delta_state should be torch.Tensor' inputs = torch.cat((state, action), -1) inputs = torch.reshape(inputs, (inputs.shape[0], -1)) labels = torch.reshape(delta_state, (delta_state.shape[0], -1)) if self._cfgs.dynamics_cfgs.predict_reward: labels = torch.cat(((torch.reshape(reward, (reward.shape[0], -1))), labels), -1) if self._cfgs.dynamics_cfgs.predict_cost: labels = torch.cat(((torch.reshape(cost, (cost.shape[0], -1))), labels), -1) inputs = inputs.cpu().detach().numpy() labels = labels.cpu().detach().numpy() train_mse_losses, val_mse_losses = self._dynamics.train( inputs, labels, holdout_ratio=0.2, ) self._logger.store( **{ 'Loss/DynamicsTrainMseLoss': train_mse_losses.item(), 'Loss/DynamicsValMseLoss': val_mse_losses.item(), }, )
[docs] def _update_epoch(self) -> None: """Update function per epoch."""
[docs] def _select_action( # pylint: disable=unused-argument self, current_step: int, state: torch.Tensor, ) -> torch.Tensor: """Select action. Args: current_step (int): The current step. state (torch.Tensor): The current state. Returns: The selected action. """ assert state.shape[0] == 1, 'state shape should be [1, state_dim]' if current_step < self._cfgs.algo_cfgs.start_learning_steps: action = torch.tensor(self._env.action_space.sample()).to(self._device).unsqueeze(0) else: action, info = self._planner.output_action(state) self._logger.store(**info) assert action.shape == torch.Size( [1, *self._action_space.shape], ), 'action shape should be [batch_size, action_dim]' return action
[docs] def _store_real_data( # pylint: disable=too-many-arguments,unused-argument self, state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, cost: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, next_state: torch.Tensor, info: dict[str, Any], ) -> None: # pylint: disable=too-many-arguments """Store real data in buffer. Args: state (torch.Tensor): The state from the environment. action (torch.Tensor): The action from the agent. reward (torch.Tensor): The reward signal from the environment. cost (torch.Tensor): The cost signal from the environment. terminated (torch.Tensor): The terminated signal from the environment. truncated (torch.Tensor): The truncated signal from the environment. next_state (torch.Tensor): The next state from the environment. info (dict[str, Any]): The information from the environment. """ done = terminated or truncated goal_met = info.get('goal_met', False) if not terminated and not truncated and not goal_met: # pylint: disable-next=line-too-long # if goal_met == true, Current goal position is not related to the last goal position, this huge transition will confuse the dynamics model. self._dynamics_buf.store( obs=state, act=action, reward=reward, cost=cost, next_obs=next_state, done=done, )
[docs] def _evaluation_single_step( # pylint: disable=too-many-locals self, current_step: int, use_real_input: bool = True, ) -> None: """Evaluation dynamics model single step. Args: current_step (int): The current step. use_real_input (bool): Whether to use real input or not. """ obs, _ = self._eval_env.reset() obs_dynamics = obs ep_len, ep_ret, ep_cost = 0, 0, 0 terminated, truncated = torch.tensor([False]), torch.tensor([False]) frames: list[np.ndarray] = [] obs_pred: list[float] = [] obs_true: list[float] = [] reward_pred: list[float] = [] reward_true: list[float] = [] num_episode = 0 while True: if terminated or truncated: print(f'Eval Episode Return: {ep_ret} \t Cost: {ep_cost}') save_replay_path = os.path.join(self._logger.log_dir, 'video-pic') self._logger.store( { 'EvalMetrics/EpRet': ep_ret, 'EvalMetrics/EpCost': ep_cost, 'EvalMetrics/EpLen': ep_len, }, ) save_video( frames, save_replay_path, fps=30, episode_trigger=lambda x: True, episode_index=current_step + num_episode, name_prefix='eval', ) self.draw_picture( timestep=current_step, num_episode=self._cfgs.evaluation_cfgs.num_episode, pred_state=obs_pred, true_state=obs_true, save_replay_path=save_replay_path, name='obs_mean', ) self.draw_picture( timestep=current_step, num_episode=self._cfgs.evaluation_cfgs.num_episode, pred_state=reward_pred, true_state=reward_true, save_replay_path=save_replay_path, name='reward', ) frames = [] obs_pred, obs_true = [], [] reward_pred, reward_true = [], [] ep_len, ep_ret, ep_cost = 0, 0, 0 obs, _ = self._eval_env.reset() num_episode += 1 if num_episode == self._cfgs.evaluation_cfgs.num_episode: break action = self._select_action(current_step, obs) idx = np.random.choice(self._dynamics.elite_model_idxes, size=1)[0] traj = self._dynamics.imagine( states=obs_dynamics, horizon=1, idx=idx, actions=action.unsqueeze(0), ) pred_next_obs_mean = traj['states'][0][0].mean() pred_reward = traj['rewards'][0][0] obs, reward, cost, terminated, truncated, info = self._eval_env.step(action) obs_dynamics = obs if use_real_input else traj['states'][0][0] true_next_obs_mean = obs.mean() obs_pred.append(pred_next_obs_mean.item()) obs_true.append(true_next_obs_mean.item()) reward_pred.append(pred_reward.item()) reward_true.append(reward.item()) ep_ret += reward.cpu().numpy().item() ep_cost += cost.cpu().numpy().item() ep_len += info['num_step'] frames.append(self._eval_env.render())
[docs] def draw_picture( self, timestep: int, num_episode: int, pred_state: list[float], true_state: list[float], save_replay_path: str = './', name: str = 'reward', ) -> None: """Draw a curve of the predicted value and the ground true value. Args: timestep (int): The current step. num_episode (int): The number of episodes. pred_state (list[float]): The predicted state. true_state (list[float]): The true state. save_replay_path (str): The path for saving replay. name (str): The name of the curve. """ target1 = list(pred_state) target2 = list(true_state) input1 = np.arange(0, np.array(pred_state).shape[0], 1) input2 = np.arange(0, np.array(pred_state).shape[0], 1) pylab.plot(input1, target1, 'r-', label='pred') pylab.plot(input2, target2, 'b-', label='true') pylab.xlabel('Step') pylab.ylabel(name) pylab.xticks(np.arange(0, np.array(pred_state).shape[0], 50)) # set the axis numbers if name == 'reward': pylab.yticks(np.arange(0, 3, 0.2)) else: pylab.yticks(np.arange(0, 1, 0.2)) pylab.legend( loc=3, borderaxespad=2.0, bbox_to_anchor=(0.7, 0.7), ) # set the position of that box for what each line is pylab.grid() # draw grid pylab.savefig( os.path.join( save_replay_path, str(name) + str(timestep) + '_' + str(num_episode) + '.png', ), dpi=200, ) # save as picture pylab.close()