Now let's get to examples from real world. These code fragments taken from official tutorials and popular repositories.
Learn how to improve code and how einops
can help you.
Left: as it was, Right: improved version
# start from importing some stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
conv_net_old = Net()
conv_net_new = nn.Sequential(
nn.Conv2d(1, 10, kernel_size=5),
nn.MaxPool2d(kernel_size=2),
nn.ReLU(),
nn.Conv2d(10, 20, kernel_size=5),
nn.MaxPool2d(kernel_size=2),
nn.ReLU(),
nn.Dropout2d(),
Rearrange('b c h w -> b (c h w)'),
nn.Linear(320, 50),
nn.ReLU(),
nn.Dropout(),
nn.Linear(50, 10),
nn.LogSoftmax(dim=1)
)
Reasons to prefer new implementation:
conv_net_new[:-1]
. One more reason to prefer nn.Sequential
class SuperResolutionNetOld(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionNetOld, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x
def SuperResolutionNetNew(upscale_factor):
return nn.Sequential(
nn.Conv2d(1, 64, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, padding=1),
Rearrange('b (h2 w2) h w -> b (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor),
)
Here is the difference:
Original code is already good - first line shows what kind of input is expected
def gram_matrix_old(y):
(b, ch, h, w) = y.size()
features = y.view(b, ch, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (ch * h * w)
return gram
def gram_matrix_new(y):
b, ch, h, w = y.shape
return torch.einsum('bchw,bdhw->bcd', [y, y]) / (h * w)
It would be great to use just 'b c1 h w,b c2 h w->b c1 c2'
, but einsum supports only one-letter axes
All we did here is just made information about shapes explicit to skip deciphering
class RNNModelOld(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
super(RNNModel, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)
def forward(self, input, hidden):
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
class RNNModelNew(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
super(RNNModel, self).__init__()
self.drop = nn.Dropout(p=dropout)
self.encoder = nn.Embedding(ntoken, ninp)
self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)
def forward(self, input, hidden):
t, b = input.shape
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = rearrange(self.drop(output), 't b nhid -> (t b) nhid')
decoded = rearrange(self.decoder(output), '(t b) token -> t b token', t=t, b=b)
return decoded, hidden
def channel_shuffle_old(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
# transpose
# - contiguous() required if transpose() is used before view().
# See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
def channel_shuffle_new(x, groups):
return rearrange(x, 'b (c1 c2) h w -> b (c2 c1) h w', c1=groups)
While progress is obvious, this is not the limit. As you'll see below, we don't even need to write these couple of lines.
from collections import OrderedDict
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
# transpose
# - contiguous() required if transpose() is used before view().
# See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class ShuffleUnitOld(nn.Module):
def __init__(self, in_channels, out_channels, groups=3,
grouped_conv=True, combine='add'):
super(ShuffleUnitOld, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.grouped_conv = grouped_conv
self.combine = combine
self.groups = groups
self.bottleneck_channels = self.out_channels // 4
# define the type of ShuffleUnit
if self.combine == 'add':
# ShuffleUnit Figure 2b
self.depthwise_stride = 1
self._combine_func = self._add
elif self.combine == 'concat':
# ShuffleUnit Figure 2c
self.depthwise_stride = 2
self._combine_func = self._concat
# ensure output of concat has the same channels as
# original output channels.
self.out_channels -= self.in_channels
else:
raise ValueError("Cannot combine tensors with \"{}\"" \
"Only \"add\" and \"concat\" are" \
"supported".format(self.combine))
# Use a 1x1 grouped or non-grouped convolution to reduce input channels
# to bottleneck channels, as in a ResNet bottleneck module.
# NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
self.first_1x1_groups = self.groups if grouped_conv else 1
self.g_conv_1x1_compress = self._make_grouped_conv1x1(
self.in_channels,
self.bottleneck_channels,
self.first_1x1_groups,
batch_norm=True,
relu=True
)
# 3x3 depthwise convolution followed by batch normalization
self.depthwise_conv3x3 = conv3x3(
self.bottleneck_channels, self.bottleneck_channels,
stride=self.depthwise_stride, groups=self.bottleneck_channels)
self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
# Use 1x1 grouped convolution to expand from
# bottleneck_channels to out_channels
self.g_conv_1x1_expand = self._make_grouped_conv1x1(
self.bottleneck_channels,
self.out_channels,
self.groups,
batch_norm=True,
relu=False
)
@staticmethod
def _add(x, out):
# residual connection
return x + out
@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)
def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
batch_norm=True, relu=False):
modules = OrderedDict()
conv = conv1x1(in_channels, out_channels, groups=groups)
modules['conv1x1'] = conv
if batch_norm:
modules['batch_norm'] = nn.BatchNorm2d(out_channels)
if relu:
modules['relu'] = nn.ReLU()
if len(modules) > 1:
return nn.Sequential(modules)
else:
return conv
def forward(self, x):
# save for combining later with output
residual = x
if self.combine == 'concat':
residual = F.avg_pool2d(residual, kernel_size=3,
stride=2, padding=1)
out = self.g_conv_1x1_compress(x)
out = channel_shuffle(out, self.groups)
out = self.depthwise_conv3x3(out)
out = self.bn_after_depthwise(out)
out = self.g_conv_1x1_expand(out)
out = self._combine_func(residual, out)
return F.relu(out)
class ShuffleUnitNew(nn.Module):
def __init__(self, in_channels, out_channels, groups=3,
grouped_conv=True, combine='add'):
super().__init__()
first_1x1_groups = groups if grouped_conv else 1
bottleneck_channels = out_channels // 4
self.combine = combine
if combine == 'add':
# ShuffleUnit Figure 2b
self.left = Rearrange('...->...') # identity
depthwise_stride = 1
else:
# ShuffleUnit Figure 2c
self.left = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
depthwise_stride = 2
# ensure output of concat has the same channels as original output channels.
out_channels -= in_channels
assert out_channels > 0
self.right = nn.Sequential(
# Use a 1x1 grouped or non-grouped convolution to reduce input channels
# to bottleneck channels, as in a ResNet bottleneck module.
conv1x1(in_channels, bottleneck_channels, groups=first_1x1_groups),
nn.BatchNorm2d(bottleneck_channels),
nn.ReLU(inplace=True),
# channel shuffle
Rearrange('b (c1 c2) h w -> b (c2 c1) h w', c1=groups),
# 3x3 depthwise convolution followed by batch
conv3x3(bottleneck_channels, bottleneck_channels,
stride=depthwise_stride, groups=bottleneck_channels),
nn.BatchNorm2d(bottleneck_channels),
# Use 1x1 grouped convolution to expand from
# bottleneck_channels to out_channels
conv1x1(bottleneck_channels, out_channels, groups=groups),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
if self.combine == 'add':
combined = self.left(x) + self.right(x)
else:
combined = torch.cat([self.left(x), self.right(x)], dim=1)
return F.relu(combined, inplace=True)
Rewriting the code helped to identify:
Other comments:
class ResNetOld(nn.Module):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNetOld, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def make_layer(inplanes, planes, block, n_blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
# output size won't match input, so adjust residual
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
return nn.Sequential(
block(inplanes, planes, stride, downsample),
*[block(planes * block.expansion, planes) for _ in range(1, n_blocks)]
)
def ResNetNew(block, layers, num_classes=1000):
e = block.expansion
resnet = nn.Sequential(
Rearrange('b c h w -> b c h w', c=3, h=224, w=224),
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
make_layer(64, 64, block, layers[0], stride=1),
make_layer(64 * e, 128, block, layers[1], stride=2),
make_layer(128 * e, 256, block, layers[2], stride=2),
make_layer(256 * e, 512, block, layers[3], stride=2),
# combined AvgPool and view in one averaging operation
Reduce('b c h w -> b c', 'mean'),
nn.Linear(512 * e, num_classes),
)
# initialization
for m in resnet.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return resnet
Changes:
make_layer
doesn't use internal state (that's quite faulty place)class RNNOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
bidirectional=bidirectional, dropout=dropout)
self.fc = nn.Linear(hidden_dim*2, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
#x = [sent len, batch size]
embedded = self.dropout(self.embedding(x))
#embedded = [sent len, batch size, emb dim]
output, (hidden, cell) = self.rnn(embedded)
#output = [sent len, batch size, hid dim * num directions]
#hidden = [num layers * num directions, batch size, hid dim]
#cell = [num layers * num directions, batch size, hid dim]
#concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
#and apply dropout
hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
#hidden = [batch size, hid dim * num directions]
return self.fc(hidden.squeeze(0))
class RNNNew(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
bidirectional=bidirectional, dropout=dropout)
self.dropout = nn.Dropout(dropout)
self.directions = 2 if bidirectional else 1
self.fc = nn.Linear(hidden_dim * self.directions, output_dim)
def forward(self, x):
#x = [sent len, batch size]
embedded = self.dropout(self.embedding(x))
#embedded = [sent len, batch size, emb dim]
output, (hidden, cell) = self.rnn(embedded)
hidden = rearrange(hidden, '(layer dir) b c -> layer b (dir c)',
dir=self.directions)
# take the final layer's hidden
return self.fc(self.dropout(hidden[-1]))
class FastTextOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.fc = nn.Linear(embedding_dim, output_dim)
def forward(self, x):
#x = [sent len, batch size]
embedded = self.embedding(x)
#embedded = [sent len, batch size, emb dim]
embedded = embedded.permute(1, 0, 2)
#embedded = [batch size, sent len, emb dim]
pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
#pooled = [batch size, embedding_dim]
return self.fc(pooled)
def FastTextNew(vocab_size, embedding_dim, output_dim):
return nn.Sequential(
Rearrange('t b -> t b'),
nn.Embedding(vocab_size, embedding_dim),
Reduce('t b c -> b c', 'mean'),
nn.Linear(embedding_dim, output_dim),
Rearrange('b c -> b c'),
)
Some comments on new code:
Rearrange('b t -> t b'),
class CNNOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0],embedding_dim))
self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1],embedding_dim))
self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2],embedding_dim))
self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
#x = [sent len, batch size]
x = x.permute(1, 0)
#x = [batch size, sent len]
embedded = self.embedding(x)
#embedded = [batch size, sent len, emb dim]
embedded = embedded.unsqueeze(1)
#embedded = [batch size, 1, sent len, emb dim]
conved_0 = F.relu(self.conv_0(embedded).squeeze(3))
conved_1 = F.relu(self.conv_1(embedded).squeeze(3))
conved_2 = F.relu(self.conv_2(embedded).squeeze(3))
#conv_n = [batch size, n_filters, sent len - filter_sizes[n]]
pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
#pooled_n = [batch size, n_filters]
cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))
#cat = [batch size, n_filters * len(filter_sizes)]
return self.fc(cat)
class CNNNew(nn.Module):
def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.convs = nn.ModuleList([
nn.Conv1d(embedding_dim, n_filters, kernel_size=size) for size in filter_sizes
])
self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = rearrange(x, 't b -> t b')
emb = rearrange(self.embedding(x), 't b c -> b c t')
pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
concatenated = rearrange(pooled, 'filter b c -> b (filter c)')
return self.fc(self.dropout(F.relu(concatenated)))
class HighwayConv1dOld(nn.Conv1d):
def forward(self, inputs):
L = super(HighwayConv1dOld, self).forward(inputs)
H1, H2 = torch.chunk(L, 2, 1) # chunk at the feature dim
torch.sigmoid_(H1)
return H1 * H2 + (1.0 - H1) * inputs
class HighwayConv1dNew(nn.Conv1d):
def forward(self, inputs):
L = super().forward(inputs)
H1, H2 = rearrange(L, 'b (split c) t -> split b c t', split=2)
torch.sigmoid_(H1)
return H1 * H2 + (1.0 - H1) * inputs
class CBHG_Old(nn.Module):
"""CBHG module: a recurrent neural network composed of:
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units
"""
def __init__(self, in_dim, K=16, projections=[128, 128]):
super(CBHG, self).__init__()
self.in_dim = in_dim
self.relu = nn.ReLU()
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
in_sizes = [K * in_dim] + projections[:-1]
activations = [self.relu] * (len(projections) - 1) + [None]
self.conv1d_projections = nn.ModuleList(
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
for (in_size, out_size, ac) in zip(
in_sizes, projections, activations)])
self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
self.highways = nn.ModuleList(
[Highway(in_dim, in_dim) for _ in range(4)])
self.gru = nn.GRU(
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
def forward_old(self, inputs):
# (B, T_in, in_dim)
x = inputs
# Needed to perform conv1d on time-axis
# (B, in_dim, T_in)
if x.size(-1) == self.in_dim:
x = x.transpose(1, 2)
T = x.size(-1)
# (B, in_dim*K, T_in)
# Concat conv1d bank outputs
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
# (B, T_in, in_dim)
# Back to the original shape
x = x.transpose(1, 2)
if x.size(-1) != self.in_dim:
x = self.pre_highway(x)
# Residual connection
x += inputs
for highway in self.highways:
x = highway(x)
# (B, T_in, in_dim*2)
outputs, _ = self.gru(x)
return outputs
def forward_new(self, inputs, input_lengths=None):
x = rearrange(inputs, 'b t c -> b c t')
_, _, T = x.shape
# Concat conv1d bank outputs
x = rearrange([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks],
'bank b c t -> b (bank c) t', c=self.in_dim)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
x = rearrange(x, 'b c t -> b t c')
if x.size(-1) != self.in_dim:
x = self.pre_highway(x)
# Residual connection
x += inputs
for highway in self.highways:
x = highway(x)
# (B, T_in, in_dim*2)
outputs, _ = self.gru(self.highways(x))
return outputs
There is still a large room for improvements, but in this example only forward function was changed
Good news: there is no more need to guess order of dimensions. Neither for inputs nor for outputs
class Attention(nn.Module):
def __init__(self):
super(Attention, self).__init__()
def forward(self, K, V, Q):
A = torch.bmm(K.transpose(1,2), Q) / np.sqrt(Q.shape[1])
A = F.softmax(A, 1)
R = torch.bmm(V, A)
return torch.cat((R, Q), dim=1)
def attention(K, V, Q):
_, n_channels, _ = K.shape
A = torch.einsum('bct,bcl->btl', [K, Q])
A = F.softmax(A * n_channels ** (-0.5), 1)
R = torch.einsum('bct,btl->bcl', [V, A])
return torch.cat((R, Q), dim=1)
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
class MultiHeadAttentionOld(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
self.layer_norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
class MultiHeadAttentionNew(nn.Module):
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)
self.dropout = nn.Dropout(p=dropout)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, q, k, v, mask=None):
residual = q
q = rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
k = rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
v = rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
attn = torch.einsum('hblk,hbtk->hblt', [q, k]) / np.sqrt(q.shape[-1])
if mask is not None:
attn = attn.masked_fill(mask[None], -np.inf)
attn = torch.softmax(attn, dim=3)
output = torch.einsum('hblt,hbtv->hblv', [attn, v])
output = rearrange(output, 'head b l v -> b l (head v)')
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
Benefits of new implementation
SAGANs are currently SotA for image generation, and can be simplified using same tricks.
class Self_Attn_Old(nn.Module):
""" Self attention Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn_Old,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1) #
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
m_batchsize,C,width ,height = x.size()
proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
energy = torch.bmm(proj_query,proj_key) # transpose check
attention = self.softmax(energy) # BX (N) X (N)
proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N
out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)
out = self.gamma*out + x
return out,attention
class Self_Attn_New(nn.Module):
""" Self attention Layer"""
def __init__(self, in_dim):
super().__init__()
self.query_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros([1]))
def forward(self, x):
proj_query = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
proj_key = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
proj_value = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')
energy = torch.bmm(proj_query, proj_key)
attention = F.softmax(energy, dim=2)
out = torch.bmm(attention, proj_value)
out = x + self.gamma * rearrange(out, 'b (h w) c -> b c h w',
**parse_shape(x, 'b c h w'))
return out, attention
While this example was considered to be simplistic, I had to analyze surrounding code to understand what kind of input was expected. You can try yourself.
Additionally now the code works with any dtype, not only double; and new code supports using GPU.
class SequencePredictionOld(nn.Module):
def __init__(self):
super(SequencePredictionOld, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 51)
self.linear = nn.Linear(51, 1)
def forward(self, input, future = 0):
outputs = []
h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
output = self.linear(h_t2)
outputs += [output]
for i in range(future):# if we should predict the future
h_t, c_t = self.lstm1(output, (h_t, c_t))
h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
output = self.linear(h_t2)
outputs += [output]
outputs = torch.stack(outputs, 1).squeeze(2)
return outputs
class SequencePredictionNew(nn.Module):
def __init__(self):
super(SequencePredictionNew, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 51)
self.linear = nn.Linear(51, 1)
def forward(self, input, future=0):
b, t = input.shape
h_t, c_t, h_t2, c_t2 = torch.zeros(4, b, 51, dtype=self.linear.weight.dtype,
device=self.linear.weight.device)
outputs = []
for input_t in rearrange(input, 'b t -> t b ()'):
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
output = self.linear(h_t2)
outputs += [output]
for i in range(future): # if we should predict the future
h_t, c_t = self.lstm1(output, (h_t, c_t))
h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
output = self.linear(h_t2)
outputs += [output]
return rearrange(outputs, 't b () -> b t')
class SpacialTransformOld(nn.Module):
def __init__(self):
super(Net, self).__init__()
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
class SpacialTransformNew(nn.Module):
def __init__(self):
super(Net, self).__init__()
# Spatial transformer localization-network
linear = nn.Linear(32, 3 * 2)
# Initialize the weights/bias with identity transformation
linear.weight.data.zero_()
linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
self.compute_theta = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
Rearrange('b c h w -> b (c h w)', h=3, w=3),
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
linear,
Rearrange('b (row col) -> b row col', row=2, col=3),
)
# Spatial transformer network forward function
def stn(self, x):
grid = F.affine_grid(self.compute_theta(x), x.size())
return F.grid_sample(x, grid)
That's a good old depth-to-space written manually!
Since GLOW is revertible, it will frequently rely on rearrange
-like operations.
def unsqueeze2d_old(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
factor2 = factor ** 2
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert C % (factor2) == 0, "{}".format(C)
x = input.view(B, C // factor2, factor, factor, H, W)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(B, C // (factor2), H * factor, W * factor)
return x
def squeeze2d_old(input, factor=2):
assert factor >= 1 and isinstance(factor, int)
if factor == 1:
return input
size = input.size()
B = size[0]
C = size[1]
H = size[2]
W = size[3]
assert H % factor == 0 and W % factor == 0, "{}".format((H, W))
x = input.view(B, C, H // factor, factor, W // factor, factor)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(B, C * factor * factor, H // factor, W // factor)
return x
def unsqueeze2d_new(input, factor=2):
return rearrange(input, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=factor, w2=factor)
def squeeze2d_new(input, factor=2):
return rearrange(input, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=factor, w2=factor)
squeeze
isn't very helpful: which dimension is squeezed? There is torch.squeeze
, but it's very different.einops
anywaydef YOLO_prediction_old(input, num_classes, num_anchors, anchors, stride_h, stride_w):
bs = input.size(0)
in_h = input.size(2)
in_w = input.size(3)
scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]
prediction = input.view(bs, num_anchors,
5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
# Get outputs
x = torch.sigmoid(prediction[..., 0]) # Center x
y = torch.sigmoid(prediction[..., 1]) # Center y
w = prediction[..., 2] # Width
h = prediction[..., 3] # Height
conf = torch.sigmoid(prediction[..., 4]) # Conf
pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
# Calculate offsets for each grid
grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat(
bs * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat(
bs * num_anchors, 1, 1).view(y.shape).type(FloatTensor)
# Calculate anchor w, h
anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
# Add offset and scale with anchors
pred_boxes = FloatTensor(prediction[..., :4].shape)
pred_boxes[..., 0] = x.data + grid_x
pred_boxes[..., 1] = y.data + grid_y
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
# Results
_scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
output = torch.cat((pred_boxes.view(bs, -1, 4) * _scale,
conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1)
return output
def YOLO_prediction_new(input, num_classes, num_anchors, anchors, stride_h, stride_w):
raw_predictions = rearrange(input, 'b (anchor prediction) h w -> prediction b anchor h w',
anchor=num_anchors, prediction=5 + num_classes)
anchors = torch.FloatTensor(anchors).to(input.device)
anchor_sizes = rearrange(anchors, 'anchor dim -> dim () anchor () ()')
_, _, _, in_h, in_w = raw_predictions.shape
grid_h = rearrange(torch.arange(in_h).float(), 'h -> () () h ()').to(input.device)
grid_w = rearrange(torch.arange(in_w).float(), 'w -> () () () w').to(input.device)
predicted_bboxes = torch.zeros_like(raw_predictions)
predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w # center x
predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h # center y
predicted_bboxes[2:4] = (raw_predictions[2:4].exp()) * anchor_sizes # bbox width and height
predicted_bboxes[4] = raw_predictions[4].sigmoid() # confidence
predicted_bboxes[5:] = raw_predictions[5:].sigmoid() # class predictions
# merging all predicted bboxes for each image
return rearrange(predicted_bboxes, 'prediction b anchor h w -> b (anchor h w) prediction')
We changed and fixed a lot:
.data
, but this has no real effect, as some branches preserve gradient till the endNext time you need to output drawings of you generative models, you can use this trick
device = 'cpu'
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
padded = F.pad(fake_batch[:64], [1, 1, 1, 1])
plt.imshow(rearrange(padded, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8).cpu())
Better code is a vague term; to be specific, code is expected to be:
Provided examples show how to improve these criteria for deep learning code. And einops
helps a lot.