Source code for omnisafe.models.actor.gaussian_learning_actor
# Copyright 2023 OmniSafe Team. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Implementation of GaussianLearningActor."""from__future__importannotationsimporttorchimporttorch.nnasnnfromtorch.distributionsimportDistribution,Normalfromomnisafe.models.actor.gaussian_actorimportGaussianActorfromomnisafe.typingimportActivation,InitFunction,OmnisafeSpacefromomnisafe.utils.modelimportbuild_mlp_network# pylint: disable-next=too-many-instance-attributes
class GaussianLearningActor(GaussianActor):
    """Implementation of GaussianLearningActor.

    GaussianLearningActor is a Gaussian actor with a learnable standard deviation. It is used in
    on-policy algorithms such as ``PPO``, ``TRPO`` and so on.

    The mean of the action distribution is produced by an MLP from the observation, while the
    log standard deviation is a free parameter that does not depend on the observation.

    Args:
        obs_space (OmnisafeSpace): Observation space.
        act_space (OmnisafeSpace): Action space.
        hidden_sizes (list of int): List of hidden layer sizes.
        activation (Activation, optional): Activation function. Defaults to ``'relu'``.
        weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults
            to ``'kaiming_uniform'``.
    """

    # Distribution produced by the most recent ``predict``/``forward`` call.
    _current_dist: Normal

    def __init__(
        self,
        obs_space: OmnisafeSpace,
        act_space: OmnisafeSpace,
        hidden_sizes: list[int],
        activation: Activation = 'relu',
        weight_initialization_mode: InitFunction = 'kaiming_uniform',
    ) -> None:
        """Initialize an instance of :class:`GaussianLearningActor`."""
        super().__init__(obs_space, act_space, hidden_sizes, activation, weight_initialization_mode)

        # Network mapping an observation to the mean of the action distribution.
        self.mean: nn.Module = build_mlp_network(
            sizes=[self._obs_dim, *self._hidden_sizes, self._act_dim],
            activation=activation,
            weight_initialization_mode=weight_initialization_mode,
        )
        # Learnable log standard deviation with ``self._act_dim`` entries. It starts at zero,
        # i.e. an initial standard deviation of ``exp(0) == 1`` for every entry.
        self.log_std: nn.Parameter = nn.Parameter(torch.zeros(self._act_dim), requires_grad=True)

    def _distribution(self, obs: torch.Tensor) -> Normal:
        """Get the distribution of the actor.

        .. warning::
            This method is not supposed to be called by users. You should call :meth:`forward`
            instead.

        Args:
            obs (torch.Tensor): Observation from environments.

        Returns:
            The normal distribution of the mean and standard deviation from the actor.
        """
        mean = self.mean(obs)
        std = torch.exp(self.log_std)
        return Normal(mean, std)

    def predict(self, obs: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Predict the action given observation.

        The predicted action depends on the ``deterministic`` flag.

        - If ``deterministic`` is ``True``, the predicted action is the mean of the distribution.
        - If ``deterministic`` is ``False``, the predicted action is sampled from the distribution.

        Args:
            obs (torch.Tensor): Observation from environments.
            deterministic (bool, optional): Whether to use deterministic policy. Defaults to False.

        Returns:
            The mean of the distribution if deterministic is True, otherwise the sampled action.
        """
        self._current_dist = self._distribution(obs)
        # Mark the cached distribution as fresh so that ``log_prob`` may be called next.
        self._after_inference = True
        if deterministic:
            return self._current_dist.mean
        return self._current_dist.rsample()

    def forward(self, obs: torch.Tensor) -> Distribution:
        """Forward method.

        Args:
            obs (torch.Tensor): Observation from environments.

        Returns:
            The current distribution.
        """
        self._current_dist = self._distribution(obs)
        # Mark the cached distribution as fresh so that ``log_prob`` may be called next.
        self._after_inference = True
        return self._current_dist

    def log_prob(self, act: torch.Tensor) -> torch.Tensor:
        """Compute the log probability of the action given the current distribution.

        .. warning::
            You must call :meth:`forward` or :meth:`predict` before calling this method.

        Args:
            act (torch.Tensor): Action from :meth:`predict` or :meth:`forward` .

        Returns:
            Log probability of the action, summed over the last dimension of ``act``.

        Raises:
            AssertionError: If neither :meth:`predict` nor :meth:`forward` has been called since
                the previous call to this method.
        """
        # Raised explicitly (rather than via ``assert``) so that the call-order check is still
        # enforced when Python runs with ``-O``; the exception type and message are unchanged.
        if not self._after_inference:
            raise AssertionError('log_prob() should be called after predict() or forward()')
        # Consume the flag: each inference call permits exactly one ``log_prob`` call.
        self._after_inference = False
        return self._current_dist.log_prob(act).sum(dim=-1)

    @property
    def std(self) -> float:
        """Standard deviation of the distribution, averaged over all entries of ``log_std``."""
        return torch.exp(self.log_std).mean().item()

    @std.setter
    def std(self, std: float) -> None:
        # ``not std > 0`` also rejects NaN. Without this check, ``torch.log`` would silently
        # write NaN/-inf into ``log_std`` and the failure would only surface later.
        if not std > 0:
            raise ValueError(f'std must be positive, got {std!r}')
        device = self.log_std.device
        # Overwrite every entry in place without recording the write in the autograd graph.
        with torch.no_grad():
            self.log_std.fill_(torch.log(torch.tensor(std, device=device)))