Source code for omnisafe.common.normalizer

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Normalizer."""

from __future__ import annotations

from typing import Any, Mapping

import torch
import torch.nn as nn


[docs]class Normalizer(nn.Module): """Calculate normalized raw_data from running mean and std. References: - Title: Updating Formulae and a Pairwise Algorithm for Computing Sample Variances - Author: Tony F. Chan, Gene H. Golub, Randall J. LeVeque - URL: `Normalizer <http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf>`_ """ _mean: torch.Tensor # running mean _sumsq: torch.Tensor # running sum of squares _var: torch.Tensor # running variance _std: torch.Tensor # running standard deviation _count: torch.Tensor # number of samples _clip: torch.Tensor # clip value def __init__(self, shape: tuple[int, ...], clip: float = 1e6) -> None: """Initialize an instance of :class:`Normalizer`.""" super().__init__() if shape == (): self.register_buffer('_mean', torch.tensor(0.0)) self.register_buffer('_sumsq', torch.tensor(0.0)) self.register_buffer('_var', torch.tensor(0.0)) self.register_buffer('_std', torch.tensor(0.0)) self.register_buffer('_count', torch.tensor(0)) self.register_buffer('_clip', clip * torch.tensor(1.0)) else: self.register_buffer('_mean', torch.zeros(*shape)) self.register_buffer('_sumsq', torch.zeros(*shape)) self.register_buffer('_var', torch.zeros(*shape)) self.register_buffer('_std', torch.zeros(*shape)) self.register_buffer('_count', torch.tensor(0)) self.register_buffer('_clip', clip * torch.ones(*shape)) self._shape: tuple[int, ...] = shape self._first: bool = True @property def shape(self) -> tuple[int, ...]: """Return the shape of the normalize.""" return self._shape @property def mean(self) -> torch.Tensor: """Return the mean of the normalize.""" return self._mean @property def std(self) -> torch.Tensor: """Return the std of the normalize.""" return self._std
[docs] def forward(self, data: torch.Tensor) -> torch.Tensor: """Normalize the data. Args: data (torch.Tensor): The raw data to be normalized. Returns: The normalized data. """ return self.normalize(data)
[docs] def normalize(self, data: torch.Tensor) -> torch.Tensor: """Normalize the data. .. hint:: - If the data is the first data, the data will be used to initialize the mean and std. - If the data is not the first data, the data will be normalized by the mean and std. - Update the mean and std by the data. Args: data (torch.Tensor): The raw data to be normalized. Returns: The normalized data. """ data = data.to(self._mean.device) self._push(data) if self._count <= 1: return data output = (data - self._mean) / self._std return torch.clamp(output, -self._clip, self._clip)
[docs] def _push(self, raw_data: torch.Tensor) -> None: """Update the mean and std by the raw_data. Args: raw_data (torch.Tensor): The raw data to be normalized. """ if raw_data.shape == self._shape: raw_data = raw_data.unsqueeze(0) assert raw_data.shape[1:] == self._shape, 'data shape must be equal to (batch_size, *shape)' if self._first: self._mean = torch.mean(raw_data, dim=0) self._sumsq = torch.sum((raw_data - self._mean) ** 2, dim=0) self._count = torch.tensor( raw_data.shape[0], dtype=self._count.dtype, device=self._count.device, ) self._first = False else: count_raw = raw_data.shape[0] count = self._count + count_raw mean_raw = torch.mean(raw_data, dim=0) delta = mean_raw - self._mean self._mean += delta * count_raw / count sumq_raw = torch.sum((raw_data - mean_raw) ** 2, dim=0) self._sumsq += sumq_raw + delta**2 * self._count * count_raw / count self._count = count self._var = self._sumsq / (self._count - 1) self._std = torch.sqrt(self._var) self._std = torch.max(self._std, 1e-2 * torch.ones_like(self._std))
[docs] def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False, ) -> Any: """Load the state_dict to the normalizer. Args: state_dict (Mapping[str, Any]): The state_dict to be loaded. strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`. Defaults to True. Returns: The loaded normalizer. """ self._first = False return super().load_state_dict(state_dict, strict, assign)