Hello and welcome to Part 4 in the series on autoregressive generative models. We have covered some of the theory behind these models and have implemented the causality and probability machinery that lets a CNN generate new images based on a training set.

In this post we are going to focus on the building blocks used to construct the leading model of this class, pixelSNAIL (as well as some of the other pixelCNN variants). This brings in quite a few ideas from the wider field of deep learning; in some cases I will touch on them only briefly and aim to cover them in more detail in other posts.

Gated Residual Block

In post 3 we discussed the horizontal and vertical streams, which preserve causality and remove the blind spot created by using masked kernels. Now, to build up the architecture, we create a few blocks, each with a specific job, and link them together repeatedly to form a deep network.

Because these architectures are very deep, it is important to have a residual unit as the basic element (read the original paper on residual blocks if this idea is unfamiliar).

The gated residual block (GRB) is relatively simple and forms the building block for the entire model: a couple of causal convolutions followed by the gating function, with a skip connection added from input to output (this provides the residual aspect).

The GRB also adds an optional auxiliary input (shown greyed out in the figures), which could be:

  • A conditioning input, for example the class that an image belongs to. This allows samples to be generated that depend on the chosen class.
  • A conditioning input based on the input image (such as in VQ-VAE 2, where a separate model produces multiple outputs encoding an image at various resolutions).
  • The horizontal and vertical streams themselves: one stream can be fed in as the main input and the other as the auxiliary input in order to combine them.

GRB = (residual connection) + (auxiliary input) + (gating mechanism)

We covered the gating mechanism in the last post.

The Gated Residual Block (GRB) as implemented in the original pixelSNAIL codebase (left) and in the more recent PyTorch version that forms part of the VQ-VAE 2 model (right).
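To make the wiring concrete, here is a minimal Keras sketch of a GRB. It is only illustrative: it uses plain (non-causal) Conv2D layers for brevity, assumes the input already has `filters` channels so the residual addition is valid, and the exact layer ordering differs from the figures above.

from tensorflow.keras.layers import Conv2D, Add, Multiply, Activation, Lambda

def gated_residual_block(x, aux=None, filters=128, kernel_size=3):
    # Two convolutions; the second produces 2*filters channels which are
    # split into the two halves of the gate.
    h = Conv2D(filters, kernel_size, padding="same", activation="elu")(x)
    h = Conv2D(2 * filters, kernel_size, padding="same")(h)
    if aux is not None:
        # Optional auxiliary input (e.g. a class embedding or the other
        # stream), projected with a 1x1 convolution and added in.
        h = Add()([h, Conv2D(2 * filters, 1)(aux)])
    a = Lambda(lambda t: t[..., :filters])(h)
    b = Lambda(lambda t: t[..., filters:])(h)
    gated = Multiply()([Activation("tanh")(a), Activation("sigmoid")(b)])
    return Add()([x, gated])  # residual connection from input to output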

Attention Block

Attention mechanisms tend to be motivated by the need to use “global” information to inform “local” operations: they allow a network to focus itself by bringing in information from a much larger receptive field than the regular convolutional layers of the network can offer. This can be crucial for capturing long-range structure, helping the model to generate an image that is consistent across pixels that are far apart.

The pixelCNN architecture is a natural place for attention: the GRBs at the bottom of the network (close to the input) have a limited receptive field, but may want to use information from pixels far in their past. If an image contains multiple objects, attention can be placed on an object far away to help decide on the current pixels; think about a picture at noon versus one at midnight, where the presence of a sun or moon early in the image generation informs the distribution of object colours and appearances later on.

The pixelSNAIL model was the first to bring attention into the pixelCNN architecture. The idea is implemented using a self-attention block; unfortunately the paper does not give enough detail to reproduce this exactly, so you must refer to the code implementation.

I’m skeptical about how far we can go in interpreting what the attention is actually doing. I’ll give the gist of it here, but I’m not sure there is a concrete way to prove the network is doing what we expect it to.

The ideas of key, query and value are important.

  • The query is associated with a spatial location: it asks (“queries”) how the other locations should influence that location. The feature maps undergo a linear transformation to produce the queries.
  • The key is also a linear embedding of all the spatial locations into some smaller space; think of keys indexing Python dictionaries: they act like a lookup label for a particular spatial location.
  • The value is another linear embedding of the inputs; these are the features of each spatial location that will be fed through the block.

The block works in the following way: for each spatial location, the query at that location is compared to the keys at the other spatial locations using a similarity function; the resulting weights (the attention) are then used to form a weighted combination of the values at those locations.

This self-attention is therefore global, because the spatial maps of the full image are used.

Given the way we have described the attention, it would make sense for the same input feature map $X$ to be used for all of $K,Q,V$, but in pixelSNAIL the $K,V$ have an additional set of features $h(X)$ coming from a series of GRBs. It isn’t clear why this is necessary, since $h(X)$ gets passed alongside the attention block anyway.

The block enforces that the number of output channels equals the number of input channels; this is an aesthetic choice.

We can call this block “multi-headed causal attention”.

The rough outline of what happens in this block is as follows:

  • Each of $K,Q,V$ is passed through an independent linear map that projects it into a smaller $M$-dimensional space (typically 16-dimensional).
  • A number of heads ($N$) are used. The logic here is that the input features may describe different abstract groups of features, which may want to place attention in different places; having multiple attention heads allows the model to pay attention to these different “classes” of features separately. Typically 8 heads are used, and the number of heads times the dimensionality of the projected space (16) must equal the number of input channels (128). So using more channels allows either more attention heads or a larger projected space.
  • The keys and queries are combined to form attention maps through a matrix multiplication: each pair of $M$-dimensional vectors from $K$ and $Q$ has a dot product which represents the strength of the connection between the two spatial locations, i.e. how similar they are.
  • The attention maps now form an $(N,H\times W,H\times W)$ tensor; i.e. for each of the $N$ heads there is an $(H\times W, H\times W)$ attention map whose $(i,j)$th element tells us how much attention spatial location $i$ should pay to spatial location $j$.
  • The attention maps are passed through a softmax, so that for each query location $i$ we have a normalised probability distribution over the spatial locations $j$: each value is in $[0,1]$ and the sum over all $j$ is 1. The attention therefore weights the locations of particular importance to location $i$.
  • Now, for each attention head, the attention map is multiplied by the values, which form an $(H \times W,M)$ matrix giving the embedded vector representation of the features at each location. We end up with another $(H \times W, M)$ matrix containing the same embedded representations, but now weighted by the attention.
  • Finally the head and vector indices are combined to produce the features output by the attention block: the tensor is reshaped back to $(B,H,W,C)$, with $C=M\times N$ (a shape-level sketch of these steps is given after this list).
  • The attention block can be placed anywhere in the network, since it doesn’t change the shape of its input. It could go at the start, where its ability to bring in global information could help the network learn abstract features, or at the end, where it can make use of strong abstract features to include long-range dependencies.
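Before worrying about causality, here is a shape-level sketch of the steps above in NumPy, with illustrative values for $N$, $M$, $H$ and $W$ (no learned projections, just random arrays to show the bookkeeping).

import numpy as np

N, M, H, W = 8, 16, 32, 32   # heads, projection dim, spatial size (assumed values)
HW = H * W

q = np.random.randn(N, HW, M)   # queries: one M-dim vector per location per head
k = np.random.randn(N, HW, M)   # keys
v = np.random.randn(N, HW, M)   # values

logits = q @ k.transpose(0, 2, 1) / np.sqrt(M)     # (N, HW, HW) similarity scores
logits -= logits.max(axis=-1, keepdims=True)       # for numerical stability
attn = np.exp(logits)
attn /= attn.sum(axis=-1, keepdims=True)           # softmax over key locations j

out = attn @ v                                     # (N, HW, M) attention-weighted values
out = out.transpose(1, 0, 2).reshape(HW, N * M)    # merge heads back into (HW, C)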

Causality in the Attention Block

As discussed in post 3, we must be careful to preserve causality, and so far we have broken it, since the attention maps link all spatial locations together. Luckily this is very easy to fix using a mask.

The attention map entry at coordinates $(i,j)$ links spatial location $i$ (the query) to spatial location $j$ (the key), where $i,j \in \{0, \cdots, H \cdot W - 1 \}$ and $H \cdot W=1024$ for the CIFAR-10 set. To preserve causality we must therefore ensure all $j > i$ entries of the attention map are zero, which we do by multiplying the attention maps with a mask.
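As a concrete example, here is one possible way such a mask could be built; the helper name get_causal_mask and its second argument mirror what the layer in the next section expects, but this is an assumed implementation rather than the original code.

import numpy as np
import tensorflow as tf

def get_causal_mask(canvas_size, downsample=1):
    # Lower-triangular (HW, HW) matrix: entry (i, j) is 1 if location i is
    # allowed to attend to location j (j <= i) and 0 otherwise. Only the
    # simple downsample=1 case is handled in this sketch.
    assert downsample == 1
    mask = np.tril(np.ones((canvas_size, canvas_size), dtype=np.float32))
    return tf.constant(mask)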

Implementation in Keras

We can implement the causal attention block as a Keras layer, although it doesn’t actually have any trainable weights in the way I have coded it, so it could just as easily be a regular function. I prefer it as a layer as it helps in debugging and keeps the attention operation clearly defined in the model.summary() and associated functionality:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Reshape

class CausalAttention(Layer):

    def __init__(self, **kwargs):
        super(CausalAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CausalAttention, self).build(input_shape)

    def call(self, x):
        key, query, value = x

        nr_chns = key.shape[-1]        # embedding dimension of the keys/queries
        mixin_chns = value.shape[-1]   # embedding dimension of the values

        # Number of (flattened) spatial locations, H * W
        canvas_size = int(np.prod(key.shape[1:-1]))
        canvas_size_q = int(np.prod(query.shape[1:-1]))
        # (HW, HW) mask with 1 where location i may attend to location j
        # (a possible implementation of this helper is sketched above)
        causal_mask = get_causal_mask(canvas_size_q, 1)

        q_m = Reshape((canvas_size_q, nr_chns))(tf.debugging.check_numerics(query, "badQ"))
        k_m = Reshape((canvas_size, nr_chns))(tf.debugging.check_numerics(key, "badK"))
        v_m = Reshape((canvas_size, mixin_chns))(tf.debugging.check_numerics(value, "badV"))

        # Scaled dot-product attention: push disallowed (j > i) entries to a
        # large negative value before the softmax, then zero them afterwards.
        dot = tf.matmul(q_m, k_m, transpose_b=True)
        dk = tf.cast(nr_chns, tf.float32)
        causal_probs = tf.nn.softmax(
            dot / tf.math.sqrt(dk) - 1e9 * (1.0 - causal_mask), axis=-1) * causal_mask
        mixed = tf.matmul(causal_probs, v_m)

        out = Reshape(tuple(query.shape[1:-1]) + (mixin_chns,))(mixed)
        return out

Network-in-Network Block

The architecture also uses network-in-network (NIN) blocks, which densely connect the channels of a layer. The spatial coordinates are not mixed, so this is equivalent to a 1×1 convolution (and so does not break causality). Without a non-linear activation function, this block also performs the linear mappings required in the attention block.

This can be implemented easily as a layer in Keras, which allows us to add some extra features like weight normalization (see later in this post):

from tensorflow.keras.layers import Layer, Dense, Activation
from tensorflow_addons.layers import WeightNormalization

class NetworkInNetwork(Layer):

    def __init__(self, filters, activation=None, weight_norm=True, **kwargs):
        super(NetworkInNetwork, self).__init__(**kwargs)
        self.filters = filters
        self.activation_fn = activation
        self.weight_norm = weight_norm

        # A Dense layer acting on the channel axis is equivalent to a
        # 1x1 convolution; optionally wrap it in weight normalization.
        if weight_norm:
            self.dense = WeightNormalization(Dense(self.filters))
        else:
            self.dense = Dense(self.filters)
        self.activation = Activation(activation) if activation is not None else None

    def build(self, input_shape):
        super(NetworkInNetwork, self).build(input_shape)

    def call(self, x):
        x = self.dense(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self.filters,)

    def get_config(self):
        config = super(NetworkInNetwork, self).get_config()
        config.update({"filters": self.filters,
                       "activation": self.activation_fn,
                       "weight_norm": self.weight_norm})
        return config
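As an example of how these pieces fit together (with illustrative channel sizes rather than the exact numbers from the paper), the NIN layers can produce the $K,Q,V$ projections that feed the causal attention layer defined earlier, with the $K,V$ taking the extra $h(X)$ features as described above:

from tensorflow.keras.layers import Input, Concatenate

x = Input(shape=(32, 32, 128))   # main feature maps (assumed shape)
h = Input(shape=(32, 32, 128))   # features h(X) from a stack of GRBs

kv_in = Concatenate(axis=-1)([x, h])
key   = NetworkInNetwork(16)(kv_in)    # linear (no activation) projections
query = NetworkInNetwork(16)(x)
value = NetworkInNetwork(128)(kv_in)

attended = CausalAttention()([key, query, value])   # shape (batch, 32, 32, 128)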

Putting the Blocks Together

pixelCNN Architecture

The pixelCNN model is one of a series of variants introduced in the pixelRNN paper. Its architecture in that paper is incredibly simple: just a series of simple residual blocks. Masked kernels are used, so this model suffers from a blind spot.

In the follow-up paper the idea of the residual block is expanded: the block has a horizontal and a vertical stream and also involves gating. These blocks are chained together, with the larger model having 20 such blocks.

The pixelCNN++ architecture introduces a U-Net style, where the image is spatially contracted and then expanded within the network. This approach will be covered elsewhere as it requires extra theory such as transposed convolutions.

pixelSNAIL Architecture

PixelSNAIL develops the key ideas of pixelCNN but introduces some new features via the attention mechanism. We bring together the ideas of:

  • Causality
  • Attention
  • Gating and Residual Units
  • Class Conditioning
  • Weight Normalization

In more detail, the residual block and attention block are combined into a single “unit” which can be called a pixel block; the pixel block is then repeated a number of times (12 in the paper). The code is available and shares its structure with that of pixelCNN++ (the code is not pretty at all… compare it in number of lines to the PyTorch implementation).

The two popular variants of the architecture are shown above: the original authors’ code (where the network is a generative model in its own right) and the version used in the VQ-VAE 2 model (where it has the different task of fitting a prior over the VQ-VAE latent space, and is written in PyTorch).
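To show how the earlier sketches could be combined, here is a rough and purely illustrative pixel block built from the gated_residual_block, NetworkInNetwork and CausalAttention pieces defined above; the layer counts, channel sizes and exact wiring are assumptions rather than the precise pixelSNAIL graph.

from tensorflow.keras.layers import Add, Concatenate

def pixel_block(x, n_res=4):
    # A stack of GRBs followed by causal attention, whose output is mixed
    # back into the stream. Assumes x has 128 channels.
    h = x
    for _ in range(n_res):
        h = gated_residual_block(h)
    kv_in = Concatenate(axis=-1)([x, h])
    key = NetworkInNetwork(16)(kv_in)
    query = NetworkInNetwork(16)(h)
    value = NetworkInNetwork(128)(kv_in)
    attended = CausalAttention()([key, query, value])
    return Add()([h, NetworkInNetwork(128)(attended)])

# The full model chains this block (12 times in the paper) before the final
# output layers.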

Efficient Training

We now cover a couple of extra features of the pixelCNN models which are generally focused on making the training as efficient as possible.

Because the pixelSNAIL model is very large and takes multiple GPUs to train, techniques to speed up convergence are essential. Beyond using a good optimizer (Adam), there are several additional techniques implemented which I will touch on briefly.

Learning Rate Decay

Learning rate decay is a standard technique and also very easy to implement (Keras optimizers have it as a standard parameter). The learning rate is decreased at each epoch by a factor $\gamma \lesssim 1$, so that as the model converges it takes smaller steps in parameter space, which is better when close to a minimum. For the paper, $\gamma = 0.999998$ with an initial learning rate of $0.001$. Bear in mind that this model likely ran for thousands of epochs (the learning rate decreases by only about 1% after 5000 epochs); if you have a smaller model which converges faster, you should reduce $\gamma$ so that by the time you reach your final epoch the learning rate has decreased noticeably.
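A minimal way to reproduce this per-epoch decay in Keras is with a LearningRateScheduler callback; the schedule and values follow the discussion above, and whether the original code decays per epoch or per update step is not something I am asserting here.

import tensorflow as tf

gamma = 0.999998

# Multiply the current learning rate by gamma at the start of every epoch.
lr_decay = tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lr * gamma)

# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), ...)
# model.fit(..., callbacks=[lr_decay])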

Dropout

Dropout is a method of regularising the network (i.e. preventing overfitting) by randomly dropping (setting to zero) a fixed fraction of a layer’s activations during training. It is a standard layer in Keras.

In the paper, a dropout rate of $50\%$ is used on each convolutional layer, which suggests overfitting was a significant problem.
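For reference, this is all it takes in Keras (purely illustrative layer sizes):

from tensorflow.keras.layers import Conv2D, Dropout

def conv_with_dropout(x, filters=128):
    # 50% of the activations are zeroed at training time only.
    x = Conv2D(filters, 3, padding="same", activation="elu")(x)
    return Dropout(0.5)(x)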

Polyak Averaging

The idea of Polyak averaging is that a separate, averaged copy $\hat{\theta}$ of all trainable model parameters (the weights) is maintained using an Exponential Moving Average (EMA). If the weights at step $t$ (a single batch) are $\theta^{(t)}$ then the averaged weights are updated via

$$\hat{\theta}^{(t)} = \alpha \hat{\theta}^{(t-1)} + (1-\alpha) \theta^{(t)} $$

So $\alpha$ controls the number of batches over which the weights are effectively averaged. In the paper $\alpha = 0.9995$, which averages the weights over a long window (of the order of $1/(1-\alpha) = 2000$ batches).
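A simple way to sketch this in Keras is with a callback that keeps its own EMA copy of the weights and swaps it in at the end of training (an illustrative implementation, not the one used in the original code):

import tensorflow as tf

class EMAWeights(tf.keras.callbacks.Callback):

    def __init__(self, alpha=0.9995):
        super().__init__()
        self.alpha = alpha
        self.ema_weights = None

    def on_train_batch_end(self, batch, logs=None):
        weights = self.model.get_weights()
        if self.ema_weights is None:
            self.ema_weights = [w.copy() for w in weights]
        else:
            self.ema_weights = [self.alpha * e + (1.0 - self.alpha) * w
                                for e, w in zip(self.ema_weights, weights)]

    def on_train_end(self, logs=None):
        # Use the averaged weights for evaluation and sampling.
        if self.ema_weights is not None:
            self.model.set_weights(self.ema_weights)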

Weight Normalization

Weight normalization is a reparameterization of the weights of a layer which separates the magnitude of the weights from their direction. It is used to speed up convergence.

It is an alternative to the more common batch normalization (which normalizes the activations over a batch of data rather than the weights). The weights of the network are parameterized as

$$ {\bf w} = \frac{g}{\| {\bf v} \|} {\bf v} $$

where $g$ and ${\bf v}$ are optimized in place of a regular kernel’s weights. It offers a good alternative to batch normalization, especially for RNNs, which is perhaps how it came to be favoured for the pixelCNN class of models, since they originate from pixelRNN.

The initialization of these parameters is performed using the training data (an added complexity for implementations).

In the original pixelSNAIL code, weight normalization was implemented by hand for the network-in-network and causal convolution layers. However, for TensorFlow 2 the addons library contains a WeightNormalization layer that works identically and wraps around a normal layer like

WeightNormalization(Conv2D(...), data_init=True)

which is how I have implemented it in Keras.
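For completeness, a minimal end-to-end snippet using the tensorflow_addons package (the wrapped layer and its sizes are just placeholders):

import tensorflow_addons as tfa
from tensorflow.keras.layers import Conv2D

# Weight-normalized convolution; with data_init=True the g and v parameters
# are initialized from the statistics of the first batch passed through.
wn_conv = tfa.layers.WeightNormalization(
    Conv2D(128, 3, padding="same"), data_init=True)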

Conclusions

This post may take a few reads and will definitely require a bit of reading around each topic if you want more detail. Hopefully it has offered a useful reference for how to create these networks and a starting point for generating new ideas.

The pixelSNAIL architecture is heavily influenced by the other models in the pixelCNN family, which are themselves heavily influenced by ideas originating in RNN research. This makes them somewhat unusual if, like me, you approach them having mostly seen object detection or segmentation algorithms beforehand.

In the next post I will finally move on to implementing and training the network on data and having a play about with it.