Source code for omnisafe.utils.model

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the helper functions for the model."""

from __future__ import annotations

import numpy as np
from torch import nn

from omnisafe.typing import Activation, InitFunction


[docs]def initialize_layer(init_function: InitFunction, layer: nn.Linear) -> None: """Initialize the layer with the given initialization function. The ``init_function`` can be chosen from: ``kaiming_uniform``, ``xavier_normal``, ``glorot``, ``xavier_uniform``, ``orthogonal``. Args: init_function (InitFunction): The initialization function. layer (nn.Linear): The layer to be initialized. """ if init_function == 'kaiming_uniform': nn.init.kaiming_uniform_(layer.weight, a=np.sqrt(5)) elif init_function == 'xavier_normal': nn.init.xavier_normal_(layer.weight) elif init_function in ['glorot', 'xavier_uniform']: nn.init.xavier_uniform_(layer.weight) elif init_function == 'orthogonal': nn.init.orthogonal_(layer.weight, gain=np.sqrt(2)) else: raise TypeError(f'Invalid initialization function: {init_function}')
[docs]def get_activation( activation: Activation, ) -> type[nn.Identity | nn.ReLU | nn.Sigmoid | nn.Softplus | nn.Tanh]: """Get the activation function. The ``activation`` can be chosen from: ``identity``, ``relu``, ``sigmoid``, ``softplus``, ``tanh``. Args: activation (Activation): The activation function. Returns: The activation function, ranging from ``nn.Identity``, ``nn.ReLU``, ``nn.Sigmoid``, ``nn.Softplus`` to ``nn.Tanh``. """ activations = { 'identity': nn.Identity, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'softplus': nn.Softplus, 'tanh': nn.Tanh, } assert activation in activations return activations[activation]
[docs]def build_mlp_network( sizes: list[int], activation: Activation, output_activation: Activation = 'identity', weight_initialization_mode: InitFunction = 'kaiming_uniform', ) -> nn.Module: """Build the MLP network. Examples: >>> build_mlp_network([64, 64, 64], 'relu', 'tanh') Sequential( (0): Linear(in_features=64, out_features=64, bias=True) (1): ReLU() (2): Linear(in_features=64, out_features=64, bias=True) (3): ReLU() (4): Linear(in_features=64, out_features=64, bias=True) (5): Tanh() ) Args: sizes (list of int): The sizes of the layers. activation (Activation): The activation function. output_activation (Activation, optional): The output activation function. Defaults to ``identity``. weight_initialization_mode (InitFunction, optional): Weight initialization mode. Defaults to ``'kaiming_uniform'``. Returns: The MLP network. """ activation_fn = get_activation(activation) output_activation_fn = get_activation(output_activation) layers = [] for j in range(len(sizes) - 1): act_fn = activation_fn if j < len(sizes) - 2 else output_activation_fn affine_layer = nn.Linear(sizes[j], sizes[j + 1]) initialize_layer(weight_initialization_mode, affine_layer) layers += [affine_layer, act_fn()] return nn.Sequential(*layers)