Source code for mlbench_core.utils.pytorch.checkpoint

import enum
import json
import os
import shutil

import dill
import torch


class CheckpointFreq(enum.IntEnum):
    # Checkpointing frequency: every epoch (ALL), only new best epochs (BEST),
    # or never (NONE)
    ALL = 1
    BEST = 2
    NONE = 3


class Checkpointer(object):
    """A class for handling checkpoint saving and loading.

    Args:
        ckpt_run_dir (str): The path of the checkpoint directory.
        rank (int): The rank of the current worker.
        freq (int): The frequency of checkpointing. Default: `CheckpointFreq.BEST`
        save_stats (bool): Save stats to additional text files. Default: `True`
    """

    def __init__(self, ckpt_run_dir, rank, freq=CheckpointFreq.BEST, save_stats=True):
        self.dirname = ckpt_run_dir
        self.rank = rank
        self.freq = freq
        self.save_stats = save_stats

    def save(self, tracker, model, optimizer, scheduler, is_best):
        """Saves a checkpoint

        Args:
            tracker (:obj:`mlbench_core.utils.pytorch.helpers.Tracker`): The
                metrics tracker object
            model (:obj:`torch.nn.Module`): a pytorch model to be trained and
                validated.
            optimizer (:obj:`torch.optim.Optimizer`, optional): an optimizer for
                the given model.
            scheduler (:obj:`mlbench_core.lr_scheduler.pytorch.lr.*`, optional):
                a scheduler for hyper-parameters.
            is_best (bool): Whether the current model is a new best scoring one
        """
        state = {
            "tracker": tracker,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "scheduler": scheduler.state_dict() if scheduler is not None else None,
            "freq": self.freq,
        }

        filename = "{epoch}_{rank}.pth.tar".format(
            epoch=tracker.current_epoch, rank=self.rank
        )
        checkpoint_path = os.path.join(self.dirname, filename)
        best_model_path = os.path.join(self.dirname, "model_best.pth.tar")

        if self.freq == CheckpointFreq.ALL:
            torch.save(state, checkpoint_path, pickle_module=dill)

            if is_best:
                shutil.copyfile(checkpoint_path, best_model_path)
        elif self.freq == CheckpointFreq.BEST:
            torch.save(state, best_model_path, pickle_module=dill)
        elif self.freq != CheckpointFreq.NONE:
            raise NotImplementedError

        self._maybe_save_stats(tracker.records, tracker.current_epoch, self.rank)

    def _maybe_save_stats(self, records, epoch, rank):
        """Save the records in the tracker."""
        if self.save_stats:
            filename = os.path.join(self.dirname, "{}_{}.json".format(epoch, rank))

            with open(filename, "w") as f:
                json.dump(records, f)

    @staticmethod
    def load(ckpt_run_dir, rank, model, optimizer, scheduler):
        """Loads a checkpoint

        Args:
            ckpt_run_dir (str): Folder path of checkpoint directory
            rank (int): The rank of the current worker
            model (:obj:`torch.nn.Module`): a pytorch model to be trained and
                validated.
            optimizer (:obj:`torch.optim.Optimizer`, optional): an optimizer for
                the given model.
            scheduler (:obj:`mlbench_core.lr_scheduler.pytorch.lr.*`, optional):
                a scheduler for hyper-parameters.

        Returns:
            A tuple of `(Checkpointer, model, optimizer, scheduler, tracker)`
        """
        checkpoint_path = determine_restore_ckpt_path(rank, ckpt_run_dir)

        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "No checkpoint found at '{}' for rank '{}'".format(ckpt_run_dir, rank)
            )

        checkpoint = torch.load(checkpoint_path, pickle_module=dill)

        model.load_state_dict(checkpoint["model"])

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])

        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])

        tracker = checkpoint["tracker"]
        freq = checkpoint["freq"]

        checkpointer = Checkpointer(ckpt_run_dir, rank, freq)

        return checkpointer, model, optimizer, scheduler, tracker

    @staticmethod
    def load_model_by_epoch(ckpt_run_dir, rank, epoch, model):
        """Loads the model state saved in the checkpoint of a given epoch

        Args:
            ckpt_run_dir (str): Folder path of checkpoint directory
            rank (int): The rank of the current worker
            epoch (int): Epoch of the model to be loaded.
            model (:obj:`torch.nn.Module`): a pytorch model to be trained and
                validated.

        Returns:
            `model` with the loaded state
        """
        checkpoint_path = os.path.join(
            ckpt_run_dir, "{}_{}.pth.tar".format(epoch, rank)
        )

        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "No checkpoint found at '{}' for rank '{}'".format(ckpt_run_dir, rank)
            )

        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model"])
        return model

    @staticmethod
    def checkpoint_exists(ckpt_run_dir, rank, epoch):
        """Check that a checkpoint exists for the given epoch and rank.

        Args:
            ckpt_run_dir (str): Folder path of checkpoint directory
            rank (int): The rank of the current worker
            epoch (int): Epoch of the checkpoint to look for.

        Raises:
            FileNotFoundError: If no matching checkpoint file exists.
        """
        checkpoint_path = os.path.join(
            ckpt_run_dir, "{}_{}.pth.tar".format(epoch, rank)
        )

        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "No checkpoint found at '{}' for rank '{}'".format(ckpt_run_dir, rank)
            )


def determine_restore_ckpt_path(rank, checkpoint_root):
    """Determine the checkpoint path to restore.

    Args:
        rank (int): The rank of the current worker
        checkpoint_root (str): Folder path of checkpoint directory

    Returns:
        The path of the newest checkpoint for this worker
    """
    ckpt_ids = os.listdir(checkpoint_root)
    ckpt_ids = list(filter(lambda x: x.endswith(".pth.tar"), ckpt_ids))
    ckpt_ids = list(set(ckpt_ids) - set(["model_best.pth.tar"]))
    # Keep only files of the form "<epoch>_<rank>.pth.tar" for this rank
    ckpt_ids = filter(
        lambda x: x.split("_")[1][: -len(".pth.tar")] == str(rank), ckpt_ids
    )
    # Newest checkpoint first, ordered by epoch number
    latest = sorted(ckpt_ids, reverse=True, key=lambda x: int(x.split("_")[0]))

    path = os.path.join(checkpoint_root, latest[0])

    return path
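
A minimal usage sketch of this module, not part of the source above: it assumes a toy torch.nn.Linear model, an SGD optimizer, a temporary checkpoint directory, and a SimpleNamespace stand-in for the tracker that only carries the current_epoch and records attributes read by save(); a real run would pass an mlbench_core Tracker instead.

import tempfile
from types import SimpleNamespace

import torch

from mlbench_core.utils.pytorch.checkpoint import Checkpointer, CheckpointFreq

ckpt_dir = tempfile.mkdtemp()  # stand-in for the run's checkpoint directory
rank = 0

model = torch.nn.Linear(10, 2)  # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint every epoch; is_best=True also updates model_best.pth.tar
checkpointer = Checkpointer(ckpt_dir, rank, freq=CheckpointFreq.ALL)

# Stand-in tracker: save() only reads current_epoch and records from it
tracker = SimpleNamespace(current_epoch=0, records=[])
checkpointer.save(tracker, model, optimizer, scheduler=None, is_best=True)

# Restore the newest checkpoint for this rank; load() also returns the
# tracker object that was stored inside the checkpoint
checkpointer, model, optimizer, scheduler, tracker = Checkpointer.load(
    ckpt_dir, rank, model, optimizer, scheduler=None
)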