Autoregressive Generative Models in Depth: Part 2
Hello, welcome to part 2 of the autoregressive model series. Part 1 was a quick introduction to the models and what they do (generate images). In this part we will go in depth on one of the key elements of these models: the probabilistic treatment of the network output.
Probability and Loss in Autoregressive Generative Models
Here I am going to discuss the loss function for autoregressive generative models and how it relates to probability.
We want a generative model to be able to produce a probability density function
\begin{aligned} p(x_{ij}) \ , \ x_{ij}\in [0,255]^3, \end{aligned}which describes the likelihood that the image pixel at position $(i,j)$ takes the value $x_{ij}$. In the full model, where the network uses previous pixel values to produce $p$, this is in fact a conditional probability distribution in which the current pixel value is conditioned on all previous pixels. For now, let's just forget the conditional nature of $p$.
Once we have this distribution we can sample it to obtain a value for each of the pixels in the image. We choose a parametric form for this distribution,
\begin{aligned} p(x_{ij}; {\bf \theta} ) \ , \ \theta_{i}\in \mathbb{R}, \end{aligned}and then let the network fit the parameters $\theta$.
The simplest non-trivial model is a normal distribution $\mathcal{N}(\mu, \sigma)$, which depends on a mean $\mu$ and standard deviation $\sigma$. By itself this would not have enough capacity to model complex relationships between pixel values, so we can simply add a series of Gaussians together to generate a more complex distribution.
We are following along with the pixelCNN++ and pixelSNAIL models. The authors of these models use a logistic distribution instead of a Gaussian, which is very similar but benefits from a more efficient sampling function. I hadn't come across the logistic distribution before, so I assume you haven't either. Similar to a normal distribution it has a mean $\mu$ and a scale $s$, and the pdf can be written as
$${\rm logistic}(x; \mu, s) = \frac{e^{-(x-\mu)/s}}{s \left( 1 + e^{-(x-\mu)/s} \right)^2 }$$
Then the model used is a sum of $L$ logistic functions (here $x$ refers to a single pixel in the generated or training image)
$$p(x; \theta) = \sum_{i=1}^L \pi_i \, p_i(x) = \sum_{i=1}^L \pi_i \ {\rm logistic}(x; \mu_i, s_i)$$
This mixture is correctly normalized if $\sum_i \pi_i = 1$, which can be enforced by applying a softmax to the network output logits that are used for the $\pi_i$.
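As a quick numerical sanity check, here is a minimal sketch (plain NumPy, with illustrative values and helper names that are not part of the pixelCNN++ code) which evaluates a small mixture of logistics and verifies that it integrates to one:

import numpy as np

def logistic_pdf(x, mu, s):
    # logistic(x; mu, s) = exp(-(x-mu)/s) / (s * (1 + exp(-(x-mu)/s))**2)
    z = np.exp(-(x - mu) / s)
    return z / (s * (1.0 + z) ** 2)

# a toy mixture of L=3 logistics with softmax-normalized weights
logits = np.array([0.5, -1.0, 2.0])
pi = np.exp(logits) / np.exp(logits).sum()    # softmax, so sum(pi) == 1
mu = np.array([-0.5, 0.0, 0.7])
s = np.array([0.1, 0.3, 0.05])

x = np.linspace(-5, 5, 20001)
p = sum(pi_i * logistic_pdf(x, mu_i, s_i) for pi_i, mu_i, s_i in zip(pi, mu, s))
print(np.trapz(p, x))    # ~1.0, i.e. the mixture is correctly normalized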
The training objective is to minimize the negative log of this probability:
$$\hat{\theta} = {\rm argmin}_\theta \left( - \log p(x; \theta) \right)$$
The equation for $p(x)$ above refers to the pixel $x$, which has 3 values, one for each of the RGB colours. To deal with the colours we could simply multiply three logistic functions together, one for each colour:
$$p_i(x) \equiv P(r) \cdot P(g) \cdot P(b)$$
with the $P(c) \equiv {\rm logistic}(c; \mu_c, s_c)$.
But it seems to be a universal feature of images that the colours within a single pixel are highly correlated, and the pixelCNN++ paper builds this into the loss by giving the colours a linear dependency, which allows the parameters to model the correlation. This is done by adjusting the means $\mu$ of the logistics as
$$\begin{aligned}\mu_r' &= \mu_r \\\mu_g' &= \mu_g + c_0 r \\\mu_b' &= \mu_b + c_1 r + c_2 g\end{aligned}$$
Here $c_{0,1,2}$ are additional parameters that can be output by the network (so that they capture the complex intra-pixel dependencies). The RGB values are therefore modelled using the following conditional factorization
$$p_i(x) \equiv P(r) \cdot P(g|r) \cdot P(b|r,g)$$
with
$$\begin{aligned}P(r) &= {\rm logistic}(r; \mu_r', s_r) \\P(g|r) &= {\rm logistic}(g; \mu_g', s_g) \\P(b|r,g) &= {\rm logistic}(b; \mu_b', s_b)\end{aligned}$$
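To make the intra-pixel dependency concrete, here is a minimal sketch of evaluating a single (continuous, non-discretized) mixture component for one pixel; the numbers are purely illustrative and the sub-pixel values are assumed to already be scaled to $[-1,1]$, as in the code further down:

import numpy as np

def logistic_pdf(x, mu, s):
    z = np.exp(-(x - mu) / s)
    return z / (s * (1.0 + z) ** 2)

# one mixture component's parameters for a single pixel (illustrative values)
mu_r, mu_g, mu_b = 0.2, -0.1, 0.4
s_r, s_g, s_b = 0.1, 0.1, 0.1
c0, c1, c2 = 0.3, -0.2, 0.1          # linear colour-dependency coefficients
r, g, b = 0.25, -0.05, 0.35          # observed sub-pixel values in [-1, 1]

# adjusted means: mu_g' = mu_g + c0*r, mu_b' = mu_b + c1*r + c2*g
p_i = (logistic_pdf(r, mu_r, s_r)
       * logistic_pdf(g, mu_g + c0 * r, s_g)
       * logistic_pdf(b, mu_b + c1 * r + c2 * g, s_b))
print(p_i)    # P(r) * P(g|r) * P(b|r,g) for this component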
With that extra step, we have finally defined the probability $p(x| \theta)$ of obtaining the training data $x$ given the parameters $\theta$.
Training (Loss)
For training we have only the training images $X$ (this is unsupervised learning and there is no target label); we let the network predict the parameters $\theta$ from each image and then build the log-likelihood (the log of the probability) from them as outlined above.
If the model parameters are good, then the probability $p(x)$ will be high, and training moves towards a situation where the parameters $\theta$ correctly capture the underlying distribution of the training data.
That being the case, we can sample from these distributions to create new examples which follow the training set well.
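In TensorFlow 2 this boils down to a standard gradient-descent step on the negative log-likelihood. A minimal sketch (the model, optimizer and batch x are assumptions here, e.g. a Keras pixelCNN and an Adam optimizer; discretized_mix_logistic_loss is the loss function given later in this post):

import tensorflow as tf

def train_step(model, optimizer, x):
    # one gradient-descent step on the negative log-likelihood (in bits per dimension)
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = discretized_mix_logistic_loss(x, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss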
In total, then, we have to optimize the following parameters (for an example where we choose $L=10$ logistic distributions) for a single image pixel:
- $\pi_i \in [0,1]$ the probability of each logistic distribution (10)
- $\mu_{i,j} \in (-\infty, \infty)$ the centre of the $i$th logistic distribution for the $j$th colour (30)
- $s_{i,j} \in (0, \infty)$ the width (scale) of the $i$th logistic distribution for the $j$th colour (30).
- $c_{k,i}$ the $k=0,1,2$ linear coefficients mapping $r\rightarrow g$ and $r,g \rightarrow b$ for each logistic function $i$ (30).
So a single image pixel requires $100$ parameters to model its probability distribution.
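Equivalently, in terms of the network output channels (a short sketch; this matches the channel layout used by the loss and sampling code below):

nr_mix = 10                              # L = 10 logistic distributions
n_channels = nr_mix * (1 + 3 + 3 + 3)    # pi_i, mu_i, s_i and c_i
print(n_channels)                        # 100 output channels per pixel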
There are some additional details in actually implementing the loss function, which require the edge cases $x=0$ and $x=255$ to be handled carefully; you can refer either to the implementation provided on my GitHub or to this excellent article.
Sampling New Images
Due to the causal nature of the network, sampling must be done with a full pass through the network for each pixel. What follows is the sampling process for a single pass of the network.
The network produces predictions of the 100 parameters from the previous section according to the portion of the sampled image in the “past” (above and to the left of the current pixel).
First a single distribution is chosen according to the $\pi_i \in [0,1]$, and then this distribution is used to sample the RGB values in order. The G (B) value is conditioned on the R (R and G) values, and each sampled value is clipped to the range $[-1,1]$.
The distribution $i$ is chosen from the $\pi_i$ using the Gumbel-max trick; which efficiently samples from a softmax distribution (so the logits are softmaxed within the sampler).
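For reference, a minimal standalone sketch of the Gumbel-max trick (illustrative names, not part of the GitHub code): adding Gumbel noise $-\log(-\log u)$ to the logits and taking the argmax selects index $i$ with probability ${\rm softmax}({\rm logits})_i$.

import tensorflow as tf

def gumbel_max_sample(logits):
    # argmax over (logits + Gumbel noise) selects index i with probability softmax(logits)[i]
    u = tf.random.uniform(tf.shape(logits), minval=1e-5, maxval=1. - 1e-5)
    gumbel = -tf.math.log(-tf.math.log(u))
    return tf.argmax(logits + gumbel, axis=-1)

# e.g. pick one of nr_mix=10 components per pixel from (B,H,W,10) mixture logits:
# idx = gumbel_max_sample(pi_logits)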
Then the logistic distribution ${\rm logistic}(x; \mu_i, s_i)$ is itself sampled, using the mean and scale fitted by the network; there will be a $\mu$ and an $s$ for each of the RGB colours, which we can call $\mu_{r,g,b}$ and $s_{r,g,b}$.
The values are sampled as
$$\begin{aligned}\hat{r} &\sim {\rm logistic}(\mu_r, s_r) \\\hat{g} &\sim {\rm logistic}(\mu_g + c_0 \hat{r}, s_g) \\\hat{b} &\sim {\rm logistic}(\mu_b + c_1 \hat{r} + c_2 \hat{g}, s_b)\end{aligned}$$
There is a nice simple way to sample the logistic distribution
$$\hat{x} = \mu + s (\log(u) - \log(1-u)), \ \ u \sim U(0,1)$$
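This is just inverse-CDF (inverse transform) sampling of the logistic; a short NumPy sketch (illustrative names, not the GitHub code):

import numpy as np

def sample_logistic(mu, s, size=None):
    # inverse-CDF sampling: x = mu + s * (log(u) - log(1 - u)), u ~ U(0, 1)
    u = np.random.uniform(1e-5, 1.0 - 1e-5, size=size)
    return mu + s * (np.log(u) - np.log(1.0 - u))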
So the modification of the means can actually be applied after the sampling (explicitly: you can either sample ${\rm logistic}(\mu + a, s)$ directly, or sample ${\rm logistic}(\mu, s)$ and shift the result by $a$). Hence we can also write the sampling as
$$\begin{aligned}\hat{r} &= \hat{\mu}_r \\\hat{g} &= \hat{\mu}_g + c_0 \hat{r} \\\hat{b} &= \hat{\mu}_b + c_1 \hat{r} + c_2 \hat{g} \\ \\\hat{\mu}_c &\sim {\rm logistic}(\mu_c, s_c)\end{aligned}$$
The sampling function only needs to be applied to a single pixel (the current pixel) of the model output (which, as we have seen, is a set of logits); however, in practice it is easier to just apply it across the entire image and then select only the pixel of interest. This is how it is implemented in all cases I have seen.
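Putting the pieces together, the outer generation loop might look like the following sketch (the model here is an assumption, standing in for any pixelCNN-style network that takes images scaled to $[-1,1]$; sample_from_discretized_mix_logistic is the function given below):

import numpy as np
import tensorflow as tf

def generate_images(model, batch=4, height=32, width=32, nr_mix=10):
    # start from a blank image and fill it in pixel by pixel, row by row
    x = np.zeros((batch, height, width, 3), dtype=np.float32)
    for i in range(height):
        for j in range(width):
            logits = model(x, training=False)        # one full forward pass per pixel
            samples = sample_from_discretized_mix_logistic(logits, nr_mix)
            # keep only the current pixel; everything in the "past" is already fixed
            x[:, i, j, :] = samples.numpy()[:, i, j, :]
    return x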
Implementing the Loss and Sampling Function in Tensorflow
The network returns logits only: at least some of the parameters need to take values outside $[0,1]$, and it is unwieldy to have the model apply different activations to different chunks of its output; it is far easier to do that in the loss.
When dealing with logarithms and exponentials there are a couple of edge cases where numbers may become undefined, such as $\log(0)$, or may run into overflow/underflow, such as $e^{\infty}, e^{-\infty}$; to ensure stable code, the built-in TensorFlow functions, which already handle these cases, should be used. It is fiddly, but most of the time implementing these things just requires patience.
When implementing the probability in code, the following reparameterization helps;
$$\begin{aligned}& \log \left( \sum_i \pi_i p_i(x) \right) \\=& \log \left( \sum_i e^{\log( \pi_i p_i(x))} \right) \end{aligned}$$
then we can use the efficient log-sum-exp (tf.reduce_logsumexp in TensorFlow), applied to the log probabilities over the mixture components. The individual probabilities are further expanded;
$$\begin{aligned}\log(\pi_i p_i(x)) &= \log(\pi_i) + \log(p_i(x)) \\& = \log(\pi_i) + \log(P_i(r)) + \log(P_i(g|r))+ \log(P_i(b|r,g))\end{aligned}$$
$\log(\pi_i)$ can be found using log_softmax (tf.nn.log_softmax), since the incoming values from the network are logits and this will correctly normalize them.
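Putting those two identities together, here is a minimal sketch of the numerically stable computation of $\log p(x)$ from the per-component log probabilities and the raw mixture logits (illustrative names; the full discretized version follows below):

import tensorflow as tf

# log_p_i: log(p_i(x)) per mixture component, shape (B, H, W, nr_mix)
# pi_logits: raw network outputs for the mixture weights, same shape
def stable_log_prob(log_p_i, pi_logits):
    log_pi = tf.nn.log_softmax(pi_logits, axis=-1)        # log(pi_i), correctly normalized
    # log(sum_i exp(log_pi + log_p_i)) == log(sum_i pi_i * p_i(x)), computed stably
    return tf.reduce_logsumexp(log_pi + log_p_i, axis=-1)

This is exactly the structure of the loss function below, with the additional detail that each $p_i(x)$ is discretized into pixel-value bins.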
Functions for the loss and also sampling according to the discussions here are provided on my GitHub and are adapted from the original authors’ code (I have essentially just made them more readable and compatible with TensorFlow 2.0).
Here is the loss function (rewritten from the pixelCNN++ codebase), which returns the negative log-likelihood in units of bits per dimension (the units used in the results of the pixelSNAIL and pixelCNN papers):
import numpy as np
import tensorflow as tf


def discretized_mix_logistic_loss(x, l):
    """
    Negative log-likelihood loss for a mixture of discretized logistics, in "bits per dimension" units:
        bits per dimension = NLL / (log(2) * B * H * W * C)
    with negative log-likelihood
    $$
    NLL(x) = -log( sum_i pi_i * logistic(x; mu_i, s_i) )
    $$
    For example, a batch of images of shape (B,H,W,C) = (1,32,32,3), as for cifar-10, with a likelihood built from a
    mixture of nr_mix=10 logistic distributions expects an output from the network of shape (1,32,32,100), where the
    output feature length N=100 is (nr_mix * (1 + 3 + 3 + 3)) corresponding to the pi_i (mixture indicator), mu_i,
    s_i and c_i.

    Parameters
    ----------
    x, Tensor (B,H,W,3) :
        The input RGB image, which must be scaled to the interval [-1,1]
    l, Tensor (B,H,W,N) :
        The output from a pixelCNN network with the same spatial size as the input, and where the output channels N
        is (nr_mix * (1 + 3 + 3 + 3)) corresponding to the pi_i (mixture indicator), mu_i, s_i and c_i.

    Returns
    -------
    loss, float
    """
    ls = l.shape
    # number of logistics in the mixture
    nr_mix = int(ls[-1] / 10)
    # unpack the params of the mixture of logistics
    split = [nr_mix, 3 * nr_mix, 3 * nr_mix, 3 * nr_mix]
    pi_i, mu_i, log_s_i, rgb_coefficients = tf.split(l, num_or_size_splits=split, axis=-1)
    log_s_i = tf.maximum(log_s_i, -7.)
    log_s_i = tf.concat(tf.split(tf.expand_dims(log_s_i, -2), 3, -1), -2)
    rgb_coefficients = tf.nn.tanh(rgb_coefficients)
    one_over_s_i = tf.exp(-log_s_i)
    # get mu_i and adjust based on preceding sub-pixels
    mu_r, mu_g, mu_b = tf.split(mu_i, num_or_size_splits=3, axis=-1)
    c0, c1, c2 = tf.split(rgb_coefficients, num_or_size_splits=3, axis=-1)
    x_r, x_g, x_b = tf.split(x, num_or_size_splits=3, axis=-1)
    mu_g += c0 * x_r
    mu_b += c1 * x_r + c2 * x_g
    mu_i = tf.concat([tf.expand_dims(mu_r, axis=-2),
                      tf.expand_dims(mu_g, axis=-2),
                      tf.expand_dims(mu_b, axis=-2)], axis=-2)
    x = tf.expand_dims(x, -1)
    x_minus_mu = tf.subtract(x, mu_i)
    # log probability for edge case of 0 (before scaling)
    plus_in = one_over_s_i * (x_minus_mu + 1. / 255.)
    cdf_plus = tf.nn.sigmoid(plus_in)
    log_cdf_plus = plus_in - tf.nn.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    min_in = one_over_s_i * (x_minus_mu - 1. / 255.)
    cdf_min = tf.nn.sigmoid(min_in)
    log_one_minus_cdf_min = -tf.nn.softplus(min_in)
    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min
    mid_in = one_over_s_i * x_minus_mu
    # log probability in the center of the bin, to be used in extreme cases
    log_pdf_mid = mid_in - log_s_i - 2. * tf.nn.softplus(mid_in)
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.math.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    # sum the log probs over the colour channels ==> multiply the probs
    log_probs = tf.reduce_sum(log_probs, 3)
    log_probs += tf.nn.log_softmax(pi_i - tf.reduce_max(pi_i, -1, keepdims=True))
    # log-sum-exp over the mixture components, then sum over batch and spatial dimensions
    loss = -tf.reduce_sum(tf.reduce_logsumexp(log_probs, axis=-1))
    n = tf.cast(tf.size(x), tf.float32)
    return tf.cast(loss, tf.float32) / (n * np.log(2))
and the sampling function
def sample_from_discretized_mix_logistic(l, nr_mix):
    """
    Sampling function for the pixelCNN family of algorithms which will generate an RGB image.

    Parameters
    ----------
    l, Tensor (B,H,W,N)
        The output from a pixelCNN network, where N is (nr_mix * (1 + 3 + 3 + 3)) corresponding to the pi_i (mixture
        indicator), mu_i, s_i and c_i.
    nr_mix, int
        The number of logistic distributions included in the network output. Usually 5 or 10

    Returns
    -------
    Tensor, (B,H,W,3) : The RGB values of the sampled pixels
    """
    ls = list(l.shape)
    xs = ls[:-1] + [3]
    # split the network output into its pieces
    split = [nr_mix, 3 * nr_mix, 3 * nr_mix, 3 * nr_mix]
    logit_probs, means, log_s, coeff = tf.split(l, num_or_size_splits=split, axis=-1)
    means = tf.reshape(means, shape=xs + [nr_mix])
    scale = tf.exp(tf.reshape(log_s, shape=xs + [nr_mix]))
    coeff = tf.reshape(tf.nn.tanh(coeff), shape=xs + [nr_mix])
    # the probabilities for each "mixture indicator"
    logit_probs = tf.nn.log_softmax(logit_probs - tf.reduce_max(logit_probs, -1, keepdims=True))
    # sample the "mixture indicator" from the softmax using the Gumbel-max trick
    rand_sample = -tf.math.log(tf.random.uniform(list(logit_probs.shape), minval=1e-5, maxval=1. - 1e-5))
    sel = tf.argmax(logit_probs - tf.math.log(rand_sample), 3)
    sel = tf.one_hot(sel, depth=nr_mix, dtype=tf.float32)
    sel = tf.reshape(sel, xs[:-1] + [1, nr_mix])
    # select the logistic parameters belonging to the sampled mixture indicator
    means = tf.reduce_sum(means * sel, 4)
    scale = tf.maximum(tf.reduce_sum(scale * sel, 4), -7.)
    coeff = tf.reduce_sum(coeff * sel, 4)
    # sample the RGB values (before adding the linear dependence)
    u = tf.random.uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
    sample_mu = means + scale * (tf.math.log(u) - tf.math.log(1. - u))
    mu_hat_r, mu_hat_g, mu_hat_b = tf.split(sample_mu, num_or_size_splits=3, axis=-1)
    # include the linear dependence of r->g and r,g->b
    c0, c1, c2 = tf.split(coeff, num_or_size_splits=3, axis=-1)
    x_r = tf.clip_by_value(mu_hat_r, -1.0, 1.0)
    x_g = tf.clip_by_value(mu_hat_g + c0 * x_r, -1.0, 1.0)
    x_b = tf.clip_by_value(mu_hat_b + c1 * x_r + c2 * x_g, -1.0, 1.0)
    return tf.concat([tf.reshape(x_r, xs[:-1] + [1]),
                      tf.reshape(x_g, xs[:-1] + [1]),
                      tf.reshape(x_b, xs[:-1] + [1])], 3)
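As a quick usage sketch (the random tensors below just stand in for a real image batch and a real network output), both functions expect the image scaled to $[-1,1]$ and a network output with nr_mix * 10 channels:

import tensorflow as tf

nr_mix = 10
x = tf.random.uniform((1, 32, 32, 3), minval=-1., maxval=1.)    # stand-in for a training batch
l = tf.random.normal((1, 32, 32, nr_mix * 10))                  # stand-in for the network output

bpd = discretized_mix_logistic_loss(x, l)
print(float(bpd))                                               # bits per dimension

samples = sample_from_discretized_mix_logistic(l, nr_mix)
print(samples.shape)                                            # (1, 32, 32, 3), values in [-1, 1]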
Typical Loss Values
The loss function implemented above outputs bits per dimension (BPD), which is a standard metric for comparing models; the state-of-the-art on the 32×32 cifar-10 image set is $2.85$. This is with a very deep model trained for days or weeks, so you should not expect to do this well. Smaller models achieve $\gtrsim 3.0$.
I generally found that anything above $3.5$ yields very poor results on generated images. I tried another dataset (gemstones) and found these cifar-10 values to be a good guide.
Conclusions
That is all for this part of the series. We have covered how to implement a loss function for autoregressive generative models like pixelCNN using a discretized mixture of logistics, and also how to consistently sample from the same distribution to generate new images with a trained model.
The code for this is available for TensorFlow 2.0 on my GitHub and you can use the loss/sampler for your own model (just ensure the model returns the correct number of logits).
In the next post I will cover causality in generative models based on CNNs.