"""Contains definitions for Residual Networks.
Residual networks were originally proposed in :cite:`he2016deep` . Then they improve the :cite:`he2016identity`
Here we refer to the settings in :cite:`he2016deep` as `v1` and :cite:`he2016identity` as `v2`.
Since `torchvision resnet <https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py>`_
has already implemented.
* ResNet-18
* ResNet-34
* ResNet-50
* ResNet-101
* ResNet-152
for image net. Here we only implemented the remaining models
* ResNet-20
* ResNet-32
* ResNet-44
* ResNet-56
for CIFAR-10 dataset. Besides, their implementation uses projection shortcut by default.
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from mlbench_core.controlflow.pytorch.helpers import convert_dtype
_DEFAULT_RESNETCIFAR_VERSION = 1
def batch_norm(num_features):
"""Create a batch normalization layer.
See the Disclaimers in Kaiming's
`repository <https://github.com/KaimingHe/deep-residual-networks/tree/a7026cb6d478e131b765b898c312e25f9f6dc031>`_.
* compute the mean and variance on a sufficiently large traing batch instead of moving average;
* learn gamma and beta in affine function.
:param num_features: number of features passed to batch normalization
:type num_features: int
"""
return nn.BatchNorm2d(
num_features=num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
)
def conv3x3(in_channels, out_channels, stride=1):
"""3x3 convolution with padding."""
return nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
)
class BasicBlockV1(nn.Module):
"""The basic block in :cite:`he2016deep` is used for shallower ResNets.
The activation functions (ReLU and BN) are viewed as post-activation of the weight layer.
.. note::
This class is similar to `BasicBlock` in
`resnet <https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py>`_.
but with different nn.BatchNorm2d configuration.
"""
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
"""
:param in_channels: input channels
:type in_channels: int
:param out_channels: output channels
:type out_channels: int
        :param stride: stride of the first convolution, defaults to 1
        :type stride: int, optional
        :param downsample: module for the projection shortcut, or None for the identity shortcut, defaults to None
        :type downsample: nn.Module or None, optional
"""
super(BasicBlockV1, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn1 = batch_norm(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(out_channels, out_channels)
self.bn2 = batch_norm(out_channels)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = self.downsample(x) if self.downsample is not None else x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# Shortcut connection.
out += residual
out = self.relu(out)
return out
class BasicBlockV2(nn.Module):
"""The basic block in :cite:`he2016identity` is used for shallower ResNets.
The activation functions (ReLU and BN) are viewed as pre-activation of the weight layer.
"""
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
"""
:param in_channels: input channels
:type in_channels: int
:param out_channels: output channels
:type out_channels: int
        :param stride: stride of the first convolution, defaults to 1
        :type stride: int, optional
        :param downsample: module for the projection shortcut, or None for the identity shortcut, defaults to None
        :type downsample: nn.Module or None, optional
"""
super(BasicBlockV2, self).__init__()
self.bn1 = batch_norm(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn2 = batch_norm(out_channels)
self.conv2 = conv3x3(out_channels, out_channels)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = self.downsample(x) if self.downsample is not None else x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
# Shortcut connection.
out += residual
return out
# class BottleneckBlockV1(nn.Module):
# """Bottleneck building block proposed in :cite:`he2016deep` (post-activation)."""
# pass
# class BottleneckBlockV2(nn.Module):
# """Bottleneck building block proposed in :cite:`he2016identity` (post-activation)."""
# pass
class ResNetCIFAR(nn.Module):
    """ResNet implementation for CIFAR-10.

    The structure follows :cite:`he2016deep` and :cite:`he2016identity`: the networks have
    ``6 * n + 2`` layers, where fixing ``n = 3, 5, 7, 9`` gives ResNet-20, -32, -44 and -56
    respectively. The input image is assumed to have a shape of 32x32 pixels.

    Args:
        resnet_size (int): Number of layers, of the form ``6 * n + 2``
        bottleneck (bool): Whether to use a bottleneck block (``Not Implemented``)
        num_classes (int): Number of output classes
        version (int): ResNet version (1 or 2). Default: ``1``
    """
def __init__(
self, resnet_size, bottleneck, num_classes, version=_DEFAULT_RESNETCIFAR_VERSION
):
super(ResNetCIFAR, self).__init__()
if resnet_size % 6 != 2:
raise ValueError(
"The resnet_size should be (6 * num_blocks + 2). Got {}.".format(
resnet_size
)
)
num_blocks = (resnet_size - 2) // 6
if version not in (1, 2):
raise ValueError("Resnet version should be 1 or 2, got {}.".format(version))
if bottleneck:
raise NotImplementedError
else:
if version == 1:
block = BasicBlockV1
elif version == 2:
block = BasicBlockV2
else:
raise NotImplementedError
# The first layer
if version == 1 or version == 2:
self.prep = nn.Sequential(
conv3x3(in_channels=3, out_channels=16, stride=1),
batch_norm(num_features=16),
nn.ReLU(),
)
else:
raise NotImplementedError
# 6n layers
self.conv_1 = self._make_layer(
block, in_channels=16, out_channels=16, num_blocks=num_blocks, init_stride=1
)
self.conv_2 = self._make_layer(
block, in_channels=16, out_channels=32, num_blocks=num_blocks, init_stride=2
)
self.conv_3 = self._make_layer(
block, in_channels=32, out_channels=64, num_blocks=num_blocks, init_stride=2
)
# Add an average pooling layer:
# the output of conv_3 has shape H=W=8
# the output average pooling will be (batch_size, channels, 1, 1)
self.avgpool = nn.AvgPool2d(8, stride=8)
self.classifier = nn.Linear(in_features=64, out_features=num_classes, bias=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, in_channels, out_channels, num_blocks, init_stride=1):
"""Create a block of 2*n depth.
.. note::
In :cite:`he2016deep` there are two types of shortcuts: identity and projection. Here we use the following:
* identity shortcut for same number of channels
* projection shortcut for increasing number of channels
"""
# by the design of ResNet, if the init_stride > 1, then out_channels > in_channels.
# For project shortcut, the extra channels are created using a 1*1 convolution.
# For identity shortcut, zeros are padded.
if init_stride == 1:
downsample = None
else:
# Use projection when the dimension of channel increases.
downsample = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=init_stride,
bias=False,
),
batch_norm(num_features=out_channels),
)
# Maybe use downsample in the first block.
layers = [
block(in_channels, out_channels, stride=init_stride, downsample=downsample)
]
for _ in range(1, num_blocks):
layers.append(block(out_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = self.prep(x)
x = self.conv_1(x)
x = self.conv_2(x)
x = self.conv_3(x)
x = self.avgpool(x)
        # After average pooling the spatial plane has shape (1, 1); flatten to (batch_size, 64).
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
""" Version 2 of ResNet. """
class PreActBlock(nn.Module):
r""""Pre-activation Resnet Block used in ResNet-18"""
def __init__(self, in_channels, out_channels, stride=1):
super(PreActBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
)
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
)
)
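        # When stride == 1 and in_channels == out_channels, no `shortcut` attribute
        # is created; forward() then falls back to the identity via hasattr().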
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
return out + shortcut
class ResNet18CIFAR10(nn.Module):
    """Pre-activation ResNet-18-style model for CIFAR-10.

    The network stacks four stages of pre-activation blocks (:class:`PreActBlock`,
    following :cite:`he2016identity`) after a 3x3 convolutional stem, and the classifier
    head concatenates global average and max pooling. The input image is assumed to have
    a shape of 32x32 pixels.

    Args:
        layers (:obj:`list` of :obj:`int`): Number of ResNet blocks per stage. Must contain 4 elements.
        num_classes (int): Number of output classes. Default: ``1000``
    """
def __init__(self, layers, num_classes=1000):
super(ResNet18CIFAR10, self).__init__()
self.prep = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.layers = nn.Sequential(
self._make_layer(64, 64, layers[0], stride=1),
self._make_layer(64, 128, layers[1], stride=2),
self._make_layer(128, 256, layers[2], stride=2),
self._make_layer(256, 256, layers[3], stride=2),
)
self.classifier = nn.Linear(512, num_classes)
def _make_layer(self, in_channels, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(
PreActBlock(
in_channels=in_channels, out_channels=out_channels, stride=stride
)
)
in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.prep(x)
x = self.layers(x)
x_avg = F.adaptive_avg_pool2d(x, (1, 1))
x_avg = x_avg.view(x_avg.size(0), -1)
x_max = F.adaptive_max_pool2d(x, (1, 1))
x_max = x_max.view(x_max.size(0), -1)
x = torch.cat([x_avg, x_max], dim=-1)
x = self.classifier(x)
return x
def resnet18_bkj(num_classes):
"""Constructs a ResNet-18 model from DAWN.
This
`implementation <https://github.com/bkj/basenet/blob/49b2b61e5b9420815c64227c5a10233267c1fb14/examples/cifar10.py>`_
comes from which gives results in
`DAWNBench <https://github.com/stanford-futuredata/dawn-bench-entries/blob/master/CIFAR10/train/basenet.json>`_.
Args:
num_classes (int): The number of output classes.
"""
model = ResNet18CIFAR10([2, 2, 2, 2], num_classes=num_classes)
return model
def get_resnet_model(model, version, dtype, num_classes=1000, use_cuda=False):
"""Create a resnet model
Args:
model (str): The name of the model, e.g. `resnet18`
version (int): The resnet version to use, `1`or `2`
num_classes (int): The number of output classes. Default: `1000`
use_cuda (bool): Whether to train on the GPU or not. Default: `False`
Returns:
A `torch.nn.Module` Resnet Model
"""
if model == "resnet18":
model = resnet18_bkj(num_classes)
elif model in ["resnet20", "resnet32", "resnet44", "resnet56", "resnet110"]:
resnet_size = int(model[len("resnet") :])
model = ResNetCIFAR(resnet_size, False, 10, version=version)
else:
        raise NotImplementedError("{} (version {}) is not implemented.".format(model, version))
model = convert_dtype(dtype, model)
if use_cuda:
model = model.cuda()
return model
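

if __name__ == "__main__":
    # A minimal smoke-test sketch. The "fp32" dtype string is an assumption about
    # what `convert_dtype` accepts; adjust it to your mlbench_core version.
    net = get_resnet_model("resnet20", version=1, dtype="fp32", num_classes=10)
    out = net(torch.randn(2, 3, 32, 32))
    print(out.shape)  # expected: torch.Size([2, 10])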