Source code for mlbench_core.utils.pytorch

import os
from contextlib import contextmanager

import torch
import torch.distributed as dist

from .helpers import config_logging, config_path, config_pytorch
from .topology import FCGraph

__all__ = ["initialize_backends", "FCGraph"]


@contextmanager
def initialize_backends(
    comm_backend="mpi",
    hosts=None,
    rank=-1,
    logging_level="INFO",
    logging_file="/mlbench.log",
    use_cuda=False,
    seed=None,
    cudnn_deterministic=False,
    ckpt_run_dir="/checkpoints",
    delete_existing_ckpts=False,
):
    """Initializes the backends.

    Sets up logging, sets up pytorch and configures paths correctly.

    Args:
        comm_backend (str): Distributed communication backend to use
            (e.g. ``mpi``, ``gloo`` or ``nccl``). Default: ``mpi``
        hosts (str): Comma-separated list of hosts participating in the run.
            Default: ``None``
        rank (int): Rank of the current node. Default: ``-1``
        logging_level (str): Logging level to use. Default: ``INFO``
        logging_file (str): Path of the file to log to. Default: ``/mlbench.log``
        use_cuda (bool): Whether to train on GPUs. Default: ``False``
        seed (int): Random seed to use. Default: ``None``
        cudnn_deterministic (bool): Whether to make cuDNN deterministic.
            Default: ``False``
        ckpt_run_dir (str): Directory to save checkpoints to.
            Default: ``/checkpoints``
        delete_existing_ckpts (bool): Whether to delete existing checkpoints.
            Default: ``False``

    Yields:
        (int, int, :obj:`FCGraph`): The rank, world size and communication
        graph of the current node.
    """
    if not (dist.is_available() and dist.is_initialized()):
        if comm_backend in [dist.Backend.GLOO, dist.Backend.NCCL]:
            if comm_backend == dist.Backend.NCCL:
                assert (
                    torch.cuda.is_available()
                ), "Invalid use of NCCL backend without CUDA support available"

            # GLOO/NCCL rendezvous via environment variables: the first host
            # acts as the master node.
            hosts = hosts.split(",")
            os.environ["MASTER_ADDR"] = hosts[0]
            os.environ["MASTER_PORT"] = "29500"
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(len(hosts))

        dist.init_process_group(comm_backend)

    config_logging(logging_level, logging_file)

    rank, world_size, graph = config_pytorch(use_cuda, seed, cudnn_deterministic)

    config_path(ckpt_run_dir, delete_existing_ckpts)

    yield rank, world_size, graph

    dist.destroy_process_group()
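
A minimal usage sketch of the context manager above, assuming a GLOO run; the host addresses, log and checkpoint paths, and the body of the ``with`` block are illustrative placeholders, not part of mlbench_core:

# Usage sketch; hosts and paths below are hypothetical values.
from mlbench_core.utils.pytorch import initialize_backends

with initialize_backends(
    comm_backend="gloo",            # CPU-capable backend; "nccl" would require CUDA
    hosts="10.0.0.1,10.0.0.2",      # first host becomes MASTER_ADDR
    rank=0,                         # this node's rank within the world
    logging_file="/tmp/mlbench.log",
    use_cuda=False,
    seed=42,
    ckpt_run_dir="/tmp/checkpoints",
) as (rank, world_size, graph):
    # Inside the block, the process group, logging and checkpoint paths are
    # configured; rank/world_size/graph describe this node's place in the run.
    print(rank, world_size, graph)
# On exit, the context manager destroys the process group again.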