PyTorch Skip Connection
Created | |
---|---|
Tags | ML Coding |
from torch import nn
import torch.nn.functional as F
import torchvision
# Basic residual block of ResNet.
# It is generic in the sense that it can also be used when the features are downsampled.
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block in the style of ResNet.

    The block computes ``relu(bn(conv2(relu(bn(conv1(x))))) + shortcut(x))``,
    where ``shortcut`` is ``downsample`` when one is given and the identity
    otherwise.
    """

    def __init__(self, in_channels, out_channels, stride=(1, 1), downsample=None):
        """
        Build the two 3x3 convolutions, the batch norm and the shortcut.

        Parameters
        ----------
        in_channels: Number of channels that the input has
        out_channels: Number of channels that the output has
        stride: Indexable pair of strides; ``stride[0]`` is used by the first
            convolution and ``stride[1]`` by the second
        downsample: Optional callable applied to the input before it is added
            to the residual mapping. It must produce a tensor that can be
            added to the output of the second convolution.
        """
        super().__init__()
        # Both convolutions are 3x3 with padding 1 and no bias (a batch norm
        # with its own shift follows each of them).
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride[0],
            padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride[1],
            padding=1, bias=False
        )
        # NOTE(review): one BatchNorm2d instance is applied after both
        # convolutions, so they share affine parameters and running
        # statistics. Reference ResNet implementations use a separate norm
        # per convolution; kept as-is here so the `bn` attribute and its
        # state-dict keys are unchanged -- confirm whether this is intended.
        self.bn = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        residual = x
        # Apply the downsample callable to the shortcut branch so that it
        # matches the main branch before the addition.
        if self.downsample is not None:
            residual = self.downsample(residual)
        out = F.relu(self.bn(self.conv1(x)))
        out = self.bn(self.conv2(out))
        # Note that the residual is added before the final activation.
        out = out + residual
        return F.relu(out)
https://www.analyticsvidhya.com/blog/2021/08/all-you-need-to-know-about-skip-connections/