Source code for omnisafe.models.actor_critic.constraint_actor_critic
# Copyright 2023 OmniSafe Team. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Implementation of ConstraintActorCritic."""from__future__importannotationsimporttorchfromtorchimportoptimfromomnisafe.models.actor_critic.actor_criticimportActorCriticfromomnisafe.models.baseimportCriticfromomnisafe.models.critic.critic_builderimportCriticBuilderfromomnisafe.typingimportOmnisafeSpacefromomnisafe.utils.configimportModelConfig
class ConstraintActorCritic(ActorCritic):
    """Actor-critic model extended with a dedicated cost value critic.

    This class builds on :class:`ActorCritic` and attaches one additional
    ``'v'``-type critic, ``cost_critic``, next to the actor and reward critic
    that :meth:`step` reads from ``self``.

    +-----------------+-----------------------------------------------+
    | Model           | Description                                   |
    +=================+===============================================+
    | Actor           | Input is observation. Output is action.       |
    +-----------------+-----------------------------------------------+
    | Reward V Critic | Input is observation. Output is reward value. |
    +-----------------+-----------------------------------------------+
    | Cost V Critic   | Input is observation. Output is cost value.   |
    +-----------------+-----------------------------------------------+

    Args:
        obs_space (OmnisafeSpace): The observation space.
        act_space (OmnisafeSpace): The action space.
        model_cfgs (ModelConfig): The model configurations.
        epochs (int): The number of epochs.

    Attributes:
        actor (Actor): The actor network (used by :meth:`step`).
        reward_critic (Critic): The reward critic network (used by :meth:`step`).
        cost_critic (Critic): The cost critic network created by this class.
        cost_critic_optimizer (optim.Optimizer): Adam optimizer over the cost
            critic's parameters. Only created when ``model_cfgs.critic.lr`` is
            not ``None``.
    """

    def __init__(
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        model_cfgs: ModelConfig,
        epochs: int,
    ) -> None:
        """Initialize an instance of :class:`ConstraintActorCritic`."""
        super().__init__(obs_space, act_space, model_cfgs, epochs)

        critic_cfgs = model_cfgs.critic

        # The cost critic is configured from the same ``critic`` section of the
        # model config that supplies the hidden sizes and activation below.
        cost_critic_builder = CriticBuilder(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=critic_cfgs.hidden_sizes,
            activation=critic_cfgs.activation,
            weight_initialization_mode=model_cfgs.weight_initialization_mode,
            num_critics=1,
            use_obs_encoder=False,
        )
        self.cost_critic: Critic = cost_critic_builder.build_critic('v')
        self.add_module('cost_critic', self.cost_critic)

        # Without a learning rate no optimizer is created for the cost critic.
        learning_rate = critic_cfgs.lr
        if learning_rate is None:
            return

        self.cost_critic_optimizer: optim.Optimizer = optim.Adam(
            self.cost_critic.parameters(),
            lr=learning_rate,
        )

    def step(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Pick an action for ``obs`` and evaluate both critics on it.

        Everything is computed inside ``torch.no_grad()``, so the returned
        tensors are not tracked for gradients.

        Args:
            obs (torch.Tensor): Observation from environments.
            deterministic (bool, optional): Forwarded to ``actor.predict``.
                Defaults to False.

        Returns:
            action: The result of ``actor.predict(obs, deterministic=...)``.
            value_r: First element of the reward critic's output for ``obs``.
            value_c: First element of the cost critic's output for ``obs``.
            log_prob: The actor's log probability of ``action``.
        """
        with torch.no_grad():
            # Each critic call is indexed with ``[0]`` to take its first output.
            reward_value = self.reward_critic(obs)[0]
            cost_value = self.cost_critic(obs)[0]

            # ``log_prob`` is queried right after ``predict``, in this order.
            chosen_action = self.actor.predict(obs, deterministic=deterministic)
            action_log_prob = self.actor.log_prob(chosen_action)

        return chosen_action, reward_value, cost_value, action_log_prob

    def forward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        """Alias of :meth:`step`.

        Args:
            obs (torch.Tensor): Observation from environments.
            deterministic (bool, optional): Forwarded to :meth:`step`.
                Defaults to False.

        Returns:
            The ``(action, value_r, value_c, log_prob)`` tuple produced by
            :meth:`step`.
        """
        return self.step(obs, deterministic=deterministic)