Autoregressive Generative Models in depth : Part 3
Hi and welcome to Part 3 of the series on Autoregressive Generative Models; last time we explained how to incorporate a probabilistic interpretation into a network to allow it to generate new image samples.
This time we cover another key principle underlying these models; causality. Let’s get on with it…
Causality
One of the fundamental properties of a generative model is causality: the requirement that predictions made for a single element of a sequence (for example a pixel of an image, a word of a sentence or a note of a piece of music) depend only on previously generated elements and not on future elements.
This is easiest to understand in the context of 1D or 2D convolutions where a regular convolution kernel will always include pixels belonging to the future of the current pixel; as depicted below
Causal Convolutions
In these posts we are dealing with networks involving 2D convolutions and these would break causality without modifications. To preserve causality;
- We must ensure that the result for a given pixel $(i,j)$ has no dependence on future pixels, throughout the full network.
- We must ensure that the pixel $(i,j)$ is connected to past pixels to enable it to be conditioned on them, throughout the full network.
Let’s deal with the first ambiguity here; using the terms “future” and “past”. Generally; when trying to associate these terms with an image we do so following how we read text, from left to right and top to bottom. That’s the same for this family of algorithms; but is ultimately an aesthetic choice.
The aim of the autoregressive model is to transform an image into a probability distribution on a pixel-by-pixel basis; so the output spatial size should be the same as the input spatial size. So we may want to make the following restriction
- The internal representations of the tensor throughout the network should maintain spatial size (H,W)
We would therefore allow only the channel number to change. Of course we can break this rule and have some sort of encoder-decoder structure where the spatial extent is compressed and then expanded within the network (and pixelCNN++ works exactly this way).
There is often a difference between how the first and subsequent layers are treated. This comes down to whether or not the current pixel should be included in the layer: on the first layer it must not be included, since otherwise causality is broken (the layer has access to what is already in the pixel and can use that information to decide what value to assign to it). In subsequent layers, however, the current pixel can be included, since by then the feature at that position no longer contains the original pixel value.
Regular Convolution : Why that won’t work
Regular convolution will not fulfil our goal here. As a simple example we will take an image of height and width $H=W=5$ and a kernel size $k=(3,3)$; to maintain the image size a padding of $(k-1)//2$ is used around the image (the $//$ here is the Python operator which divides and then rounds down to an integer value, e.g. $3//2=1$).
Now consider the very first evaluation; the kernel clearly includes pixels below and to the right of the current pixel. There’s no way to fix this with a different kernel size since the padding is always adjusted and the same problem reoccurs.
A nice term for the dependency on past pixels is causality; where future pixel choices are caused only by what happened before. These convolutions are often referred to as causal convolutions.
We can visualise causality nicely using the following graphic:
So the regular convolutions behave as above; the receptive field of the network is centred on the pixel and spreads out further the more layers we use (in this case only a few layers were used to illustrate).
Now let’s look at a series of causal convolutions;
Hopefully this makes the causality aspect clear (note that the current pixel is not red in the right-hand plot, since, as we have said, the current pixel should not influence the network output at its own location; this prevents the network just “predicting” the value it already knows is there).
Masked Convolutions
An intuitive way to fix the problem above is to use the same (3,3) kernel but explicitly remove the weights (set them to zero) at the locations that include future pixels.
As noted earlier, the first and subsequent layers are treated differently: the first layer must also exclude the current pixel, while subsequent layers may include it. These two variants are usually called mask A and mask B respectively.
We are forced to keep the masking all the way through the network (using mask B), since any layer using a regular convolution would re-introduce a dependence on future pixels.
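As a quick illustration, here is a minimal numpy sketch of what these masks look like for a $(3,3)$ kernel (the helper name `causal_mask` is my own, not taken from any of the implementations discussed here); in practice the mask is multiplied element-wise into the kernel weights before the convolution is applied:

```python
import numpy as np

def causal_mask(kernel_size=(3, 3), mask_type="B"):
    """Zero out the weights that would see future pixels.

    Mask A (first layer) also zeroes the centre weight so the current pixel is
    excluded; mask B (subsequent layers) keeps the centre weight.
    """
    kh, kw = kernel_size
    mask = np.ones((kh, kw), dtype=np.float32)
    ci, cj = kh // 2, kw // 2
    mask[ci, cj + 1:] = 0.0   # future pixels in the centre row
    mask[ci + 1:, :] = 0.0    # all rows below the centre
    if mask_type == "A":
        mask[ci, cj] = 0.0    # exclude the current pixel itself
    return mask

print(causal_mask(mask_type="A"))   # [[1,1,1],[1,0,0],[0,0,0]]
print(causal_mask(mask_type="B"))   # [[1,1,1],[1,1,0],[0,0,0]]
```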
There is also the blind spot issue, which explicitly breaks our requirements above: some pixels in the image cannot be connected to certain past pixels, namely those lying within the blind spot.
We can display the blind spot by creating an image of zeros and placing a $1$ in a particular location $(i,j)$; then, running a series of convolutional layers over the image, we can identify any locations which are non-zero (remember to switch the bias terms off) as those influenced by the non-zero pixel, i.e. those pixels have the current pixel in their past.
The blind spot can be seen as areas in the current pixel's future which are not influenced. This plot also serves as a good diagnostic when developing causal networks: it should never place non-zero values in the current pixel's past.
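As a rough sketch of this diagnostic (assuming TensorFlow/Keras; the function name `influence_map` is my own invention):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D

def influence_map(layers, image_size=11, i=5, j=5):
    """Place a single 1 at (i, j) in an otherwise zero image, run it through
    bias-free convolution layers, and mark every output location it influences."""
    x = np.zeros((1, image_size, image_size, 1), dtype=np.float32)
    x[0, i, j, 0] = 1.0
    y = tf.constant(x)
    for layer in layers:
        y = layer(y)
    return (np.abs(y.numpy()[0, :, :, 0]) > 0).astype(int)

# Three regular 3x3 convolutions: the influence spreads in every direction,
# including into the pixel's past, so this stack is not causal.
regular = [Conv2D(1, (3, 3), padding="same", use_bias=False,
                  kernel_initializer="ones") for _ in range(3)]
print(influence_map(regular))
```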
This blind spot issue severely limits the use of masked kernels in the way described above; it was noted and fixed in the follow-up paper to pixelCNN.
Shift-and-Crop with Regular Convolution
There is another way to implement causality that avoids masking kernels entirely, and I found it far from intuitive at first. Importantly though, it avoids the blind spot induced by masked convolutions and allows us to use regular un-masked convolutions.
The first step is to think about trying to split the masked kernel into two separate kernels; tackling the horizontal and vertical dependencies through the image separately.
Look at the formula for a regular convolution
$$ {\rm Conv2d}(x)_{ij} = \sum_{l=-k}^{k} \sum_{m=-k}^{k} w_{lm}\, x_{i+l,\, j+m} $$

for a kernel of size $(2k+1)\times(2k+1)$. We can easily let $w_{lm}=w^{(1)}_{lm}+w^{(2)}_{lm}$ and split the kernel into two pieces; each piece can then be applied as a separate convolution and the two results added together afterwards.
Then; why not have the two pieces perform separate functions; allowing one to handle the causality within the current row (horizontal stream) and one to handle the dependency on the previous row (vertical stream).
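For example, for a $3\times 3$ kernel the masked (mask B) kernel splits naturally into a vertical piece covering the row above and a horizontal piece covering the current row:

$$
\begin{pmatrix} w_{-1,-1} & w_{-1,0} & w_{-1,1} \\ w_{0,-1} & w_{0,0} & 0 \\ 0 & 0 & 0 \end{pmatrix}
=
\underbrace{\begin{pmatrix} w_{-1,-1} & w_{-1,0} & w_{-1,1} \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix}}_{\text{vertical}}
+
\underbrace{\begin{pmatrix} 0 & 0 & 0 \\ w_{0,-1} & w_{0,0} & 0 \\ 0 & 0 & 0 \end{pmatrix}}_{\text{horizontal}}
$$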
To make the algebra a bit more simple; let’s restrict the kernels used for causal convolutions to be odd dimensional.
So the key idea is to split the kernel into two; one which handles causality horizontally, and one which handles it vertically; and we want to avoid using masking. We will refer to these as the horizontal and vertical streams (sometimes also called stacks).
Horizontal Stream
Let’s start with the horizontal stream; there are two masked kernels we can think about, [1,1,0] and [1,0], which explicitly exclude the future pixel; only the first one includes a past pixel as well. The kernel [1,0] does absolutely nothing!
You notice from using [1,1,0] that the zero doesn’t do anything. We could just as easily use [1,1] (normal convolution) with the padding that was used for [1,1,0]. This would leave us with an extra point in the output (since we’ve left in the padding from the larger kernel). So we chop off the far end of the image.
Now if we had just begun with the [1,1] kernel we can see this as a pad+crop.
To generalize: for a kernel of size $(1, k)$ we pad with $(k-1)$ zeros on the left and none on the right; a valid convolution then keeps the output at the original width, and position $(i,j)$ never sees a column to the right of $j$. The pad-and-crop in the example above is just this left-only padding in disguise: the extra right-hand padding produces surplus columns which we then chop off.
So; the simplest (odd) kernel enforcing horizontal causality is [1,1,1] with a padding of 2 to the left and 0 to the right. We could also use [1,1] with a padding of 1 on the left.
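Here is a minimal sketch of this in TensorFlow/Keras, using a $(1,3)$ kernel of ones on a single row so the effect is easy to read off (the numbers and layer choices are just for illustration):

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, ZeroPadding2D

# One row of pixels: [0, 1, 2, 3, 4]
x = tf.reshape(tf.range(5, dtype=tf.float32), (1, 1, 5, 1))

# Pad 2 columns of zeros on the left, none on the right, then a VALID convolution.
pad = ZeroPadding2D(padding=((0, 0), (2, 0)))
conv = Conv2D(1, (1, 3), padding="valid", use_bias=False, kernel_initializer="ones")

y = conv(pad(x))
print(y.numpy().reshape(-1))   # [0. 1. 3. 6. 9.] -> each output only sums columns j-2..j
```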
Vertical Stream
The horizontal kernel deals with connections within a row; but we also need to include connections for the rows above.
The vertical stream proceeds in basically the same way as the horizontal. We want to avoid a masked kernel for the same reasons, so we simply remove the zero rows from the masked kernel and pad the image with zeros instead.
We can use the above examples for the kernel and apply padding rules as for the horizontal.
There is a need to apply an extra downward shift to the image; without it the vertical stream breaks causality along the current row.
Essentially, the image is padded at the top with rows of zeros, but the output at row $i$ then still includes connections to elements of row $i$ itself (including future columns), which we don't want; these connections are removed by the shift.
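A small sketch of the vertical stream for a “full” $3\times 3$ kernel (so a $(2,3)$ vertical kernel), again with illustrative ones-initialised weights, showing why the extra downward shift is needed on the first layer:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Cropping2D

x = np.zeros((1, 5, 5, 1), dtype=np.float32)
x[0, 2, 2, 0] = 1.0                                  # a single 1 at pixel (2, 2)

# (2, 3) kernel, one row of zeros at the top, "same" padding in width.
pad = ZeroPadding2D(padding=((1, 0), (1, 1)))
conv = Conv2D(1, (2, 3), padding="valid", use_bias=False, kernel_initializer="ones")
y = conv(pad(tf.constant(x)))

print((y.numpy()[0, :, :, 0] > 0).astype(int))        # rows 2 AND 3 light up: the
                                                      # pixel's own row is still connected
# The extra downward shift (pad one row on top, crop one row at the bottom) fixes this:
shifted = Cropping2D(((0, 1), (0, 0)))(ZeroPadding2D(((1, 0), (0, 0)))(y))
print((shifted.numpy()[0, :, :, 0] > 0).astype(int))  # only rows 3 and 4 remain
```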
Horizontal and Vertical Streams Together
So far we have split the masked kernel into two stacks, vertical and horizontal, which handle the respective causality without using a masked kernel, but instead by shifting and cropping. There is some level of choice in exactly what size kernels to use for the task, so for the sake of simplicity I propose the following split;
The two kernels handle the vertical and horizontal causality according to the rules outlined previously.
We have a choice in how to combine the two streams. We could in principle feed the output of one into the input of the other. This won't work in general though: if the vertical convolution is applied first, the shift will break causality in the horizontal stack. The horizontal stack output can, however, be fed into the vertical without breaking causality.
It makes sense to run the two convolutions and combine the outputs using addition, since this will piece back together the “full convolution” (minus the future pixels).
Recall that for masked kernels the first and subsequent layers are treated differently. That is still the case for the H/V stack: on the first layer an extra shift is applied to the image after the first conv layer to push the current pixel out of its own receptive field (for the horizontal stack the image is shifted right, and for the vertical stack the image is shifted down).
This is all explained in the diagram below, which shows the two options for handling the streams.
We can create a simple rule to create a pair of convolutions that cover vertical and horizontal regions in a pixel’s past (for a kernel size of $k=(k_0,k_1)$);
- Vertical kernel $k=(k_0-1,k_1)$:
  - pad height with $(k_0-2, 0)$ (i.e. the vertical kernel height minus one, all at the top)
  - pad width with $((k_1-1)//2, (k_1-1)//2)$ (normal “same” padding)
  - if this is the first layer, downshift the image
- Horizontal kernel $k=(1,k_1)$:
  - pad height with $(0, 0)$ (no height padding is needed for a height-1 kernel)
  - pad width with $(k_1-1, 0)$ (all on the left)
  - if this is the first layer, right-shift the image
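As a small sketch of this rule as a Python helper (the function is my own shorthand, not part of any of the reference implementations):

```python
def causal_paddings(k0, k1):
    """For a 'full' kernel of size (k0, k1), return the vertical and horizontal
    kernel sizes together with their ((top, bottom), (left, right)) paddings.
    Spatial size is preserved in both cases."""
    vertical = ((k0 - 1, k1), ((k0 - 2, 0), ((k1 - 1) // 2, (k1 - 1) // 2)))
    horizontal = ((1, k1), ((0, 0), (k1 - 1, 0)))
    return vertical, horizontal

print(causal_paddings(3, 3))
# (((2, 3), ((1, 0), (1, 1))), ((1, 3), ((0, 0), (2, 0))))
```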
Now the blind spot will be filled through the network, after some number of layers as shown below;
With the H/V split we have the same total number of parameters as a single $(3,3)$ kernel. It is also possible to merge the H-V streams after the first layer and just use the vertical stream from then on; so I think this represents a minimal implementation of causality.
Merging Horizontal and Vertical Streams
The question of how we can merge the streams is now partly answered: we keep H and V separate, but may optionally merge V into H. But how exactly should we merge them? It turns out each implementation has its own variant.
- PixelCNN : Uses basic masking (mask A and mask B), but suffers from the blind spot issue.
- Gated PixelCNN : Uses V/H streams, kept separate.
- pixelSNAIL : No masking; instead the kernels are reshaped and the padding modified.
- pixelSNAIL in VQ-VAE2 : H/V combined in the first layer, subsequent layers use a masked filter; the blind spot re-emerges.
The route taken by a couple of implementations, reduced to the elements which actually affect the spatial indices (i.e. not 1×1 convolutions or gating), looks like the diagram below.
The V and H streams are processed as outlined above; then the V stream is added to the H stream. An additional skip connection is added to the H stream (not V; the authors state, based on experimentation, that it doesn't affect results), allowing the H-conv to be bypassed.
What else is added? There are some 1×1 convolutions; these densely connect the channels together, allowing features to be modified without any mixing between spatial indices (i.e. no accidental mixing between present/past/future pixels).
The original pixelSNAIL architecture keeps the streams separate throughout the network; merging them only at the end and as a result the implementation is a bit more complex.
I have found that other implementations of pixelSNAIL (for example in the VQVAE 2.0) actually look more like the pixelCNN implementation.
Overall, I am not sure these options make a significant difference in performance; therefore the simplest would be the best, which for me is the pixelCNN++.
Gating the Streams
The gating mechanisms for pixelCNN first appeared in the gated pixelCNN architecture.
The authors don’t really provide a convincing argument in favour of gating the streams. Gates occur frequently in RNN architectures and are important there for controlling/allowing information flow over long distances; and presumably the idea is that they will do the same here.
Normal gating (GLU) can be defined via;
$${\rm GLU}(X) = {\rm Conv}(X) \odot \sigma({\rm Conv}(X))$$
The input is passed through two convolution layers in parallel; with a sigmoid activation for one convolution and none for the other. There is then an element-wise product of the two.
Then the $\sigma$ function maps the right-hand factor into a number $\in [0,1]$ which acts like a soft “switch” (gate), allowing the information from the left-hand factor to be passed through or not.
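A minimal sketch of this gating in Keras (with two separate convolutions, matching the formula above; `padding="same"` is just a placeholder here, in the causal network these would be the causal convolutions described earlier):

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D

def gated_conv(x, filters, kernel_size=(3, 3)):
    """GLU-style gating: one convolution provides the 'value', a second
    convolution passed through a sigmoid provides the gate."""
    value = Conv2D(filters, kernel_size, padding="same")(x)
    gate = Conv2D(filters, kernel_size, padding="same")(x)
    return value * tf.sigmoid(gate)

x = tf.random.normal((1, 32, 32, 16))
print(gated_conv(x, filters=32).shape)   # (1, 32, 32, 32)
```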
In highway networks (which are referenced in the paper) there is a transform and carry gate
$$H(X) \cdot T(X) + X \cdot C(X)$$
where $C=1-T$ decides how much information is carried through the gate, compared to how much is transformed by $H$. For CNNs, $T$ and $C$ are easily expressed as convolutions. The $\sigma({\rm Conv}(X))$ factor plays the role of the transform gate in pixelCNN, but there is no carry gate. I guess this gating mechanism gives the network the ability to “switch off” the transform, but not to carry the information through; that may be why the additional skip connection is needed for H. It may also selectively switch the transform on or off according to spatial region.
Another difference of the pixelSNAIL version from the vanilla gating is that the convolution is shared between the two functions. Imagine having many carry gates in sequence, $X \cdot C(X)$: there will always be a route to pass the information straight through the layers due to the bare $X$ (equivalently, there will always be a gradient path that does not vanish). Because this is missing in the gated pixelCNN, it is less clear what the gating mechanism is intended to do.
Code Implementation (TensorFlow/Keras)
Here I am going to focus on an implementation in Keras, since the pixelSNAIL code is old, difficult to read, and based on TensorFlow 1. Keras builds models using layers, and to implement causality in a simple way we can simply create a new layer called CausalConv2D, as shown below.
from tensorflow.keras.layers import Layer, Conv2D, ZeroPadding2D, Cropping2D, Input
# Weight-norm wrapper; assumed here to come from the tensorflow-addons package.
from tensorflow_addons.layers import WeightNormalization


class CausalConv2D(Layer):
"""
Basic causal convolution layer; implementing causality and weight normalization.
"""
def __init__(self, filters, kernel_size=[3, 3], weight_norm=True, shift=None, strides=1, activation="relu",
**kwargs):
self.output_dim = filters
super(CausalConv2D, self).__init__(**kwargs)
pad_h = ((kernel_size[1] - 1) // 2, (kernel_size[1] - 1) // 2)
pad_v = ((kernel_size[0] - 1) // 2, (kernel_size[0] - 1) // 2)
if shift == "down":
pad_h = ((kernel_size[1] - 1) // 2, (kernel_size[1] - 1) // 2)
pad_v = (kernel_size[0] - 1, 0)
elif shift == "right":
pad_h = (kernel_size[1] - 1, 0)
pad_v = ((kernel_size[0] - 1) // 2, (kernel_size[0] - 1) // 2)
elif shift == "downright":
pad_h = (kernel_size[1] - 1, 0)
pad_v = (kernel_size[0] - 1, 0)
self.padding = (pad_v, pad_h)
self.pad = ZeroPadding2D(padding=self.padding,
data_format="channels_last")
self.conv = Conv2D(filters=filters, kernel_size=kernel_size, padding="VALID", strides=strides,
activation=activation)
if weight_norm:
self.conv = WeightNormalization(self.conv, data_init=True)
def build(self, input_shape):
super(CausalConv2D, self).build(input_shape)
def call(self, x):
return self.conv(self.pad(x))
def compute_output_shape(self, input_shape):
return self.conv.compute_output_shape(input_shape)
def get_config(self):
config = super(CausalConv2D, self).get_config()
config.update({'padding': self.padding,
'output_dim': self.output_dim})
return config
The causality can now be implemented with a careful selection of the shift and kernel_size arguments. The shift can be “down”, “right” or “downright”, which determines how the padding is applied; for example with a “down” shift the padding in the vertical direction is applied entirely to the top of the image (rather than equally at top/bottom, as you would get with “same” padding).
On the first layer we must make additional shifts to the image after the CausalConv2D layer. I have put this operation in a layer as well:
class Shift(Layer):
"""
A layer to shift a tensor
"""
def __init__(self, direction, size=1, **kwargs):
self.size = size
self.direction = direction
super(Shift, self).__init__(**kwargs)
if self.direction == "down":
self.pad = ZeroPadding2D(padding=((self.size, 0), (0, 0)), data_format="channels_last")
self.crop = Cropping2D(((0, self.size), (0, 0)))
elif self.direction == "right":
self.pad = ZeroPadding2D(padding=((0, 0), (self.size, 0)), data_format="channels_last")
self.crop = Cropping2D(((0, 0), (0, self.size)))
def build(self, input_shape):
super(Shift, self).build(input_shape)
def call(self, x):
return self.crop(self.pad(x))
def compute_output_shape(self, input_shape):
return input_shape
def get_config(self):
config = super(Shift, self).get_config()
config.update({'direction': self.direction,
'size': self.size})
return config
We can now set up the first layer of the model (which just sets up the horizontal and vertical streams and so doesn’t need to change between implementations)
x_in = Input(shape=(32, 32, 3))
v_stream = Shift("down")(CausalConv2D(128, [2, 3], shift="down")(x_in))
h_stream = Shift("right")(CausalConv2D(128, [1, 3], shift="right")(x_in))
In all subsequent layers the convolutions can be applied identically, just without the Shift.
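As a sketch of what a subsequent layer might look like (my own illustration rather than a specific implementation), with the vertical stream optionally added into the horizontal one as discussed above:

```python
from tensorflow.keras.layers import Add

# Subsequent layers: the same causal convolutions, but no Shift.
v_stream = CausalConv2D(128, [2, 3], shift="down")(v_stream)
h_stream = CausalConv2D(128, [1, 3], shift="right")(h_stream)
h_stream = Add()([h_stream, v_stream])   # optionally merge V into H
```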
Conclusions
That’s all for now… In this third post of the series we went in depth on the idea of causality in autoregressive generative models; an essential requirement to allow the networks to be able to generate new images.
We showed how to build causality into a network architecture with as few modifications as possible; this can be done with a CausalConv2D operation in place of a normal Conv2D operation.
We will use this as the fundamental convolution operation for the pixelSNAIL network. In the next post we will have a look at how to construct the blocks which form the repeating structures of the pixelSNAIL network.