Source code for omnisafe.utils.distributed

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools of torch.distributed for multi-processing."""

from __future__ import annotations

import os
import subprocess
import sys
from typing import Any

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def setup_distributed() -> None:
    """Limit the number of Torch threads used by each worker process.

    When several processes run side by side, every one of them would
    otherwise spin up the full default number of intra-op threads and compete
    for the same CPU cores. If more than one thread is configured and more
    than one process is active, the thread count is divided evenly between
    the processes (never dropping below one) and the change is printed.
    """
    old_num_threads = torch.get_num_threads()

    # Nothing to share out when only one thread is configured.
    if old_num_threads <= 1:
        return

    num_procs = world_size()
    # A single process may keep every thread for itself.
    if num_procs <= 1:
        return

    fair_share = int(old_num_threads / num_procs)
    torch.set_num_threads(max(fair_share, 1))
    print(
        f'Proc {get_rank()}: Decreased number of Torch threads from '
        f'{old_num_threads} to {torch.get_num_threads()}',
        flush=True,
    )
def get_rank() -> int:
    """Return the rank of the calling process.

    The ``MASTER_ADDR`` environment variable is used as the marker for a
    distributed launch. When it is absent the process is treated as the only
    one and rank ``0`` is reported without touching ``torch.distributed``.

    Examples:
        >>> # In process 0
        >>> get_rank()
        0

    Returns:
        The rank of the calling process.
    """
    launched_distributed = os.getenv('MASTER_ADDR') is not None
    if launched_distributed:
        return dist.get_rank()
    return 0
def world_size() -> int:
    """Return the number of active processes.

    The ``MASTER_ADDR`` environment variable is used as the marker for a
    distributed launch. When it is absent a world of size ``1`` is reported
    without touching ``torch.distributed``.

    Returns:
        The number of active processes.
    """
    launched_distributed = os.getenv('MASTER_ADDR') is not None
    if launched_distributed:
        return dist.get_world_size()
    return 1
# Short module-level aliases for the torch.distributed collective operations,
# so that callers can use them through this module.

# Reductions.
reduce = dist.reduce
all_reduce = dist.all_reduce

# Collection of per-process tensors.
gather = dist.gather
all_gather = dist.all_gather

# Distribution from one process to the others.
broadcast = dist.broadcast
scatter = dist.scatter
def _count_visible_devices(spec: str) -> int:
    """Count the non-empty, comma-separated entries of a ``CUDA_VISIBLE_DEVICES`` value."""
    return sum(1 for entry in spec.split(',') if entry.strip())


def fork(
    parallel: int,
    device: str = 'cpu',
    manual_args: list[str] | None = None,
) -> bool:
    """The entrance method of multi-processing.

    Re-launches the current script as ``parallel`` worker processes through
    ``torchrun``. The call behaves differently depending on who makes it:

    * Inside a worker (``MASTER_ADDR`` is set) the process group is
      initialised once, guarded by the ``IN_DIST`` environment variable, and
      ``False`` is returned so the worker carries on with its job.
    * In the launching process (``MASTER_ADDR`` unset and ``parallel > 1``)
      the workers are spawned, the call blocks until they finish, and ``True``
      is returned.
    * Otherwise nothing happens and ``False`` is returned.

    Adapted from the Baselines function of the `same name
    <https://github.com/openai/baselines/blob/master/baselines/common/mpi_fork.py>`_.

    Args:
        parallel (int): The number of processes to launch.
        device (str, optional): The device to be used, e.g. ``'cpu'``,
            ``'cuda'`` or ``'cuda:1'``. A device without an explicit index is
            treated as index 0. Defaults to 'cpu'.
        manual_args (list of str or None, optional): The arguments to be passed
            to the new processes. When None, ``sys.argv`` is used. Defaults to
            None.

    Returns:
        True if this call spawned the worker processes (the caller is the
        parent launcher), False otherwise.

    Raises:
        AssertionError: If fewer GPUs are visible than processes requested.
        subprocess.CalledProcessError: If the ``torchrun`` command exits with a
            non-zero status.
    """
    backend = 'gloo' if device == 'cpu' else 'nccl'
    if os.getenv('MASTER_ADDR') is not None and os.getenv('IN_DIST') is None:
        # Running inside a torchrun worker: join the process group exactly once.
        dist.init_process_group(backend=backend)
        os.environ['IN_DIST'] = '1'

    # Either a single process was requested or the workers are already running.
    if parallel <= 1 or os.getenv('MASTER_ADDR') is not None:
        return False

    # Workers are not running yet: this is the parent, which spawns them.
    if device != 'cpu':
        # A bare device name such as 'cuda' carries no index; start at 0.
        initial_device = int(device.rsplit(':', 1)[-1]) if ':' in device else 0
        os.environ['USE_DISTRIBUTED'] = '1'
        if os.getenv('CUDA_VISIBLE_DEVICES') is None:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
                str(initial_device + i) for i in range(parallel)
            )
        # Count the listed entries rather than inferring the count from the
        # string length, which breaks for multi-digit ids such as "10,11".
        num_gpu = _count_visible_devices(os.environ['CUDA_VISIBLE_DEVICES'])
        assert num_gpu >= parallel, (
            f'Please make sure you have enough available GPUs to run Parallel {parallel}, '
            f'current available Devices are {num_gpu}.'
        )

    env = os.environ.copy()
    env.update(MKL_NUM_THREADS='1', OMP_NUM_THREADS='1', IN_MPI='1')
    args = [
        'torchrun',
        '--rdzv_backend',
        'c10d',
        '--rdzv_endpoint',
        'localhost:0',
        '--nproc_per_node',
        str(parallel),
    ]
    script_args = manual_args if manual_args is not None else sys.argv
    args += script_args
    print(script_args)
    # Blocks until every worker has exited.
    subprocess.check_call(args, env=env)  # noqa: S603
    return True
def avg_tensor(value: torch.Tensor) -> None:
    """Average a torch tensor over all processes, in place.

    The result is written back into ``value`` through slice assignment, so the
    caller observes the update without a return value. When more than one
    process is active the tensor must have at least one dimension.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor([1.0])
        >>> # In process 1
        >>> x = torch.tensor([2.0])
        >>> avg_tensor(x)
        >>> x
        tensor([1.5000])

    Args:
        value (torch.Tensor): The value to be averaged.
    """
    assert isinstance(value, torch.Tensor)

    # With a single process the tensor already equals its own average.
    if world_size() <= 1:
        return

    assert value.ndim > 0
    averaged = dist_avg(value)
    value[:] = averaged[:]
def avg_grads(module: torch.nn.Module) -> None:
    """Average contents of gradient buffers across all processes.

    Every parameter of ``module`` that currently holds a gradient has that
    gradient replaced, in place, by its average over all processes.
    Parameters without a gradient are skipped.

    .. note::
        This function only has an effect when the training is multi-processing.

    Examples:
        >>> # In process 0
        >>> model = torch.nn.Linear(1, 1, bias=False)
        >>> model.weight.grad = torch.tensor([[2.]])
        >>> # In process 1
        >>> model = torch.nn.Linear(1, 1, bias=False)
        >>> model.weight.grad = torch.tensor([[4.]])
        >>> avg_grads(model)
        >>> model.weight.grad
        tensor([[3.]])

    Args:
        module (torch.nn.Module): The module in which grad need to be averaged.
    """
    if world_size() <= 1:
        return

    for parameter in module.parameters():
        grad = parameter.grad
        if grad is None:
            continue
        # copy_ also works for 0-dim (scalar) gradients, which slice
        # assignment (``grad[:] = ...``) rejects.
        grad.copy_(dist_avg(grad))
def sync_params(module: torch.nn.Module) -> None:
    """Overwrite the parameters of ``module`` with those held by process 0.

    Each parameter's data tensor is broadcast from the process with rank 0,
    so afterwards every process holds the same parameter values.

    .. note::
        This function only has an effect when the training is multi-processing.

    Examples:
        >>> # In process 0
        >>> model = torch.nn.Linear(1, 1)
        >>> model.weight.data = torch.tensor([[1.]])
        >>> # In process 1
        >>> model = torch.nn.Linear(1, 1)
        >>> model.weight.data = torch.tensor([[2.]])
        >>> sync_params(model)
        >>> model.weight.data
        tensor([[1.]])

    Args:
        module (torch.nn.Module): The module to be synchronized.
    """
    # A single process has nobody to synchronise with.
    if world_size() <= 1:
        return

    for parameter in module.parameters():
        broadcast(parameter.data, src=0)
def avg_params(module: torch.nn.Module) -> None:
    """Average contents of all parameters across all processes.

    Every parameter's data tensor is replaced, in place, by its average over
    all processes. Nothing happens when only one process is active.

    Examples:
        >>> # In process 0
        >>> model = torch.nn.Linear(1, 1)
        >>> model.weight.data = torch.tensor([[1.]])
        >>> # In process 1
        >>> model = torch.nn.Linear(1, 1)
        >>> model.weight.data = torch.tensor([[2.]])
        >>> avg_params(model)
        >>> model.weight.data
        tensor([[1.5]])

    Args:
        module (torch.nn.Module): The module in which parameters need to be averaged.
    """
    if world_size() <= 1:
        return

    for parameter in module.parameters():
        param_tensor = parameter.data
        # copy_ also works for 0-dim (scalar) parameters, which slice
        # assignment (``param_tensor[:] = ...``) rejects.
        param_tensor.copy_(dist_avg(param_tensor))
def dist_avg(value: np.ndarray | torch.Tensor | float) -> torch.Tensor:
    """Average a value over all distributed processes.

    The value is summed across processes and divided by the number of
    processes.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor(1.0)
        >>> # In process 1
        >>> x = torch.tensor(2.0)
        >>> dist_avg(x)
        tensor(1.5)

    Args:
        value (np.ndarray, torch.Tensor, int, or float): value to be averaged.

    Returns:
        Averaged tensor.
    """
    total = dist_sum(value)
    num_procs = world_size()
    return total / num_procs
def dist_max(value: np.ndarray | torch.Tensor | float) -> torch.Tensor:
    """Reduce a value over all distributed processes with ``ReduceOp.MAX``.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor(1.0)
        >>> # In process 1
        >>> x = torch.tensor(2.0)
        >>> dist_max(x)
        tensor(2.)

    Args:
        value (np.ndarray, torch.Tensor, int, or float): value whose maximum
            over the processes is wanted.

    Returns:
        Maximum tensor.
    """
    return dist_op(value, operation=ReduceOp.MAX)
def dist_min(value: np.ndarray | torch.Tensor | float) -> torch.Tensor:
    """Reduce a value over all distributed processes with ``ReduceOp.MIN``.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor(1.0)
        >>> # In process 1
        >>> x = torch.tensor(2.0)
        >>> dist_min(x)
        tensor(1.)

    Args:
        value (np.ndarray, torch.Tensor, int, or float): value whose minimum
            over the processes is wanted.

    Returns:
        Minimum tensor.
    """
    return dist_op(value, operation=ReduceOp.MIN)
def dist_sum(value: np.ndarray | torch.Tensor | float) -> torch.Tensor:
    """Reduce a value over all distributed processes with ``ReduceOp.SUM``.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor(1.0)
        >>> # In process 1
        >>> x = torch.tensor(2.0)
        >>> dist_sum(x)
        tensor(3.)

    Args:
        value (np.ndarray, torch.Tensor, int, or float): The value to be summed.

    Returns:
        Summed tensor.
    """
    return dist_op(value, operation=ReduceOp.SUM)
def dist_op(value: np.ndarray | torch.Tensor | float, operation: Any) -> torch.Tensor:
    """Apply a reduction to a value across all distributed processes.

    The value is converted to a ``float32`` tensor. With a single process the
    converted tensor is returned directly. Otherwise it is combined with the
    values of the other processes through ``all_reduce``; Python/NumPy scalars
    are wrapped in a one-element tensor for the reduction and unwrapped again
    afterwards.

    The reduction runs on a private copy, so a tensor or array passed in by
    the caller is never overwritten.

    .. note::
        The operation can be ``ReduceOp.SUM``, ``ReduceOp.MAX``, ``ReduceOp.MIN``.
        corresponding to :meth:`dist_sum`, :meth:`dist_max`, :meth:`dist_min`, respectively.

    Args:
        value (np.ndarray, torch.Tensor, int, or float): The value to be operated.
        operation (ReduceOp): operation type.

    Returns:
        Operated (SUM, MAX, MIN) tensor.
    """
    if world_size() == 1:
        return torch.as_tensor(value, dtype=torch.float32)

    is_scalar = np.isscalar(value)
    reduced = torch.as_tensor([value] if is_scalar else value, dtype=torch.float32)
    if isinstance(value, (torch.Tensor, np.ndarray)):
        # as_tensor may share storage with the caller's tensor/array, and
        # all_reduce writes its result in place; reduce a copy instead so the
        # input is left untouched, as it is in the single-process path.
        reduced = reduced.clone()
    all_reduce(reduced, op=operation)
    return reduced[0] if is_scalar else reduced
def dist_statistics_scalar(
    value: torch.Tensor,
    with_min_and_max: bool = False,
) -> tuple[torch.Tensor, ...]:
    r"""Get mean/std and optional min/max of the elements of a tensor across all processes.

    The statistics are taken over every element held by every process: the
    element sum and element count are each summed across processes to obtain
    the mean, and the squared deviations from that mean are summed the same
    way to obtain the (population) standard deviation. The minimum and maximum
    are first computed locally and then reduced across processes, so each is a
    single scalar.

    Examples:
        >>> # In process 0
        >>> x = torch.tensor([1.0])
        >>> # In process 1
        >>> x = torch.tensor([2.0])
        >>> dist_statistics_scalar(x)
        (tensor(1.5), tensor(0.5))

    Args:
        value (torch.Tensor): Value to be operated.
        with_min_and_max (bool, optional): whether to return min and max.
            Defaults to False.

    Returns:
        A tuple of the [mean, std] or [mean, std, min, max] of the input tensor.
    """
    # numel counts every element, matching torch.sum which also reduces over
    # every element (len() would only count the first dimension and fails on
    # 0-dim tensors).
    local_n = value.numel()

    global_sum = dist_sum(torch.sum(value))
    global_n = dist_sum(torch.tensor(local_n).to(os.getenv('OMNISAFE_DEVICE', 'cpu')))
    mean = global_sum / global_n

    global_sum_sq = dist_sum(torch.sum((value - mean) ** 2))
    std = torch.sqrt(global_sum_sq / global_n)  # compute global std

    if not with_min_and_max:
        return mean, std

    if local_n > 0:
        local_min = torch.min(value)
        local_max = torch.max(value)
    else:
        # An empty tensor has no extrema; use the identity elements of
        # min/max so this process does not influence the global result.
        local_min = torch.tensor(float('inf'), device=value.device)
        local_max = torch.tensor(float('-inf'), device=value.device)

    # Reduce the local scalars rather than the whole tensor: an element-wise
    # reduction would not yield a single minimum/maximum.
    global_min = dist_min(local_min)
    global_max = dist_max(local_max)
    return mean, std, global_min, global_max