# Copyright 2023 OmniSafe Team. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Implementation of VCritic."""from__future__importannotationsimporttorchimporttorch.nnasnnfromomnisafe.models.baseimportCriticfromomnisafe.typingimportActivation,InitFunction,OmnisafeSpacefromomnisafe.utils.modelimportbuild_mlp_network
[docs]classVCritic(Critic):"""Implementation of VCritic. A V-function approximator that uses a multi-layer perceptron (MLP) to map observations to V-values. This class is an inherit class of :class:`Critic`. You can design your own V-function approximator by inheriting this class or :class:`Critic`. Args: obs_dim (int): Observation dimension. act_dim (int): Action dimension. hidden_sizes (list of int): List of hidden layer sizes. activation (Activation, optional): Activation function. Defaults to ``'relu'``. weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults to ``'kaiming_uniform'``. num_critics (int, optional): Number of critics. Defaults to 1. """def__init__(self,obs_space:OmnisafeSpace,act_space:OmnisafeSpace,hidden_sizes:list[int],activation:Activation='relu',weight_initialization_mode:InitFunction='kaiming_uniform',num_critics:int=1,)->None:"""Initialize an instance of :class:`VCritic`."""super().__init__(obs_space,act_space,hidden_sizes,activation,weight_initialization_mode,num_critics,use_obs_encoder=False,)self.net_lst:list[nn.Module]self.net_lst=[]foridxinrange(self._num_critics):net=build_mlp_network(sizes=[self._obs_dim,*self._hidden_sizes,1],activation=self._activation,weight_initialization_mode=self._weight_initialization_mode,)self.net_lst.append(net)self.add_module(f'critic_{idx}',net)
[docs]defforward(self,obs:torch.Tensor,)->list[torch.Tensor]:"""Forward function. Specifically, V function approximator maps observations to V-values. Args: obs (torch.Tensor): Observations from environments. Returns: The V critic value of observation. """res=[]forcriticinself.net_lst:res.append(torch.squeeze(critic(obs),-1))returnres