Mixture Density Networks

Motivation

One of the coolest ideas I learned in this course is the probabilistic interpretation of neural networks. Instead of using the network to predict the targets directly, we use it to predict the parameters of a conditional distribution over the targets given the input. So for each input we learn a conditional distribution over the target; all of these distributions have the same form, but different parameters. For example, we can assume that the target has a Gaussian conditional distribution and let the network predict the mean of this distribution (we assume here that the variance is independent of the input, but this assumption can be relaxed). We saw in class that if we train the network with the mean-squared error loss function, we get the same solution as when we train it with the negative log-likelihood of this Gaussian.

One might hope to learn a more complex (multi-modal) conditional distribution for the target. This is exactly the goal of mixture density networks! We want to model the conditional distribution as a mixture of Gaussians, where the parameters of each Gaussian component depend on the input, i.e.:

P(y_n \mid x_n) = \sum_{k=1}^K \pi_k(x_n) \mathcal{N}_k(y_n \mid \mu_k(x_n), \sigma_k^2(x_n))

The network here has 3 types of outputs: the mixing coefficient \pi_k, the mean of the Gaussian component \mu_k, and its variance \sigma^2_k. One might think: we don’t have ground truths for those outputs, so how can we make the network learn them?! The answer is that we don’t need ground truths, because the loss function we’re going to use is the negative log-likelihood of the data, so we just update the parameters of the model to minimize this loss. For more discussion of this model, you can check chapter 5 of Bishop’s book.

Implementation

I have written a Theano implementation of mixture density networks (MDN), which you can find here.
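For readers who want something self-contained, here is a minimal NumPy sketch of the negative log-likelihood from the Motivation section. It is not the Theano code linked above: it assumes a one-dimensional target, and the function name `mdn_nll` and the array layout are my own choices for illustration. The sum over components is evaluated in log-sum-exp form, subtracting the largest exponent first, so that very unlikely targets do not underflow to log(0).

```python
import numpy as np

def mdn_nll(y, pi, mu, sigma2):
    """Mean negative log-likelihood of a 1-D Gaussian mixture.

    y:      (N,)   targets
    pi:     (N, K) mixing coefficients, each row sums to 1
    mu:     (N, K) component means
    sigma2: (N, K) component variances, strictly positive
    (Hypothetical layout, chosen for this sketch only.)
    """
    y = y[:, None]  # broadcast targets against the K components
    # log( pi_k * N(y | mu_k, sigma2_k) ) for every example and component
    log_comp = (np.log(pi)
                - 0.5 * np.log(2.0 * np.pi * sigma2)
                - 0.5 * (y - mu) ** 2 / sigma2)
    # log-sum-exp over components: subtract the row maximum before exp
    m = log_comp.max(axis=1, keepdims=True)
    log_lik = m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))
    return -log_lik.mean()
```

With a single component, \pi = 1, this reduces to the ordinary Gaussian negative log-likelihood, which is a handy sanity check when debugging.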
I wrote the implementation so that it can predict multiple output samples at once, so the Gaussian components are multivariate, and it also supports mini-batches of data. This made the implementation a little more interesting, since I have to deal with a 3D tensor for \mu. Instead of having one matrix for the output layer, as in a standard MLP, you have a tensor for \mu and two matrices for \sigma^2 and \pi. The activation function for \mu is whatever matches the desired output, and for \sigma^2 and \pi it’s a softplus and a softmax respectively.

Similar to what David has observed, the straightforward implementation of an MDN produces a lot of NaNs. A very important issue when implementing an MDN is that the log-likelihood contains a log-sum-exp expression, which can be numerically unstable. This can be fixed using this trick. I also had to use a smaller initial learning rate than the one I used in my previous MLP, otherwise I would get NaN in the likelihood. With these two tricks, I don’t get any more NaNs.

As for the trick from the RNADE paper, I tried multiplying the mean by the variance in the cost function, but this also changes the gradients of the variance and makes the performance worse; I didn’t find it helpful at all. Multiplying the gradient of \mu directly by \sigma is a little tricky when you’re using Theano’s automatic differentiation, and that’s probably why, when I checked the RNADE code, I found that they compute the gradients without using Theano’s T.grad.

Experiments

We would like to compare the MDN model with a similar MLP, and we can compare them in terms of the mean negative log-likelihood (Mean NLL) and the MSE on the same validation set. Computing the log-likelihood of the MLP is easy: it’s just the log of a Gaussian with the output of the network as the mean, and with the variance set to its maximum likelihood estimate from the data, which turns out to be the MSE.
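The MLP likelihood just described can be sketched in a few lines of NumPy. This is an illustration rather than the code used for the experiments; the function name `gaussian_mean_nll` and the optional `sigma2` argument are assumptions I introduce here. When the variance is estimated on the same data being scored, the Mean NLL collapses to 0.5 * log(2 * pi * MSE) + 0.5.

```python
import numpy as np

def gaussian_mean_nll(y_true, y_pred, sigma2=None):
    """Mean NLL of targets under N(y_pred, sigma2) with one shared variance.

    If sigma2 is None, the maximum likelihood estimate is used,
    which is the MSE of the predictions on this data.
    Returns (mean_nll, mse).
    """
    mse = np.mean((y_true - y_pred) ** 2)
    if sigma2 is None:
        sigma2 = mse  # MLE of the variance
    mean_nll = 0.5 * np.log(2.0 * np.pi * sigma2) + 0.5 * mse / sigma2
    return mean_nll, mse
```

Passing a `sigma2` estimated on the training set instead would score the validation set with a variance it did not get to choose, which is the stricter comparison.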
On the other hand, to compute the MSE for the MDN model, we need to sample from the conditional distribution of the target. For each input data point, we first sample a component from the multinomial distribution over the components (parametrized by the mixing coefficients), and then sample the prediction from the selected Gaussian component.

I ran the first set of experiments on the AA phone data set. I took 100 sequences for training and 10 for validation, and trained two models. Both the MLP and the MDN take as input a frame of 240 samples and output one sample. The resulting dataset has 162,891 training examples and 14,739 validation examples. The following plot shows the training and validation Mean NLL for the MDN with the following hyper-parameter configuration:

  • 2 hidden layers, each with 300 units and tanh activations
  • initial learning rate (MDN): 0.001
  • linear annealing of learning rate to 0 starting after 50 epochs
  • 128 samples per mini-batch
  • 3 components

[Figure: mdn]

We can see the Mean NLL decreasing, which means the model is learning. The validation Mean NLL stabilizes after almost 100 epochs. What I am mainly interested in, though, is comparing the same MLP architecture with the MDN. Therefore, I used pretty much the same hyper-parameters for both networks, to see if we gain anything just by having the mixture of Gaussians at the output layer. The following is a plot that shows results on the same validation set and using the following hyper-parameters:

[Figure: mdn_mlp_aa]

I was expecting the MDN to perform better than the MLP. However, we can see that the MLP is better than the MDN in terms of both MSE and Mean NLL. The minimum MSE is 0.0222 for the MLP and 0.0324 for the MDN, and the minimum Mean NLL is -1.29 for the MLP and -0.77 for the MDN. This is actually the typical performance pattern in pretty much all the experiments I did on this dataset. To investigate further, I tried varying the number of components, and found that performance improves only a little as we add components (for 10 components the minimum Mean NLL reaches -0.91). With neither model was I able to generate something that sounds like \aa\. The following waveform generated from the MDN model shows that it was able to capture the periodicity of the \aa\ sound, but it’s still more peaky than a natural signal:

[Figure: gen_aa]

We saw that the MDN doesn’t do better than the MLP on the \aa\ dataset, so it turns out we’re not benefiting from having a multi-modal predictive distribution. To check this further, I performed another set of experiments on a more complicated task, where I used full utterances of one speaker (FCLT0) with the phoneme information (the current and next phonemes, as in the previous experiment). I trained on 9 utterances and validated on 1. The dataset has 402,939 training examples and 70,621 validation examples.
Using the same hyper-parameter settings, I got the following results:

[Figure: mdn_mlp_user]

Here we see that the MDN beats the MLP in terms of the Mean NLL, but still doesn’t perform better on MSE. This is kind of surprising, as you might think that the MDN has a better model of the data, but it’s probably the variance of sampling from the MDN that’s increasing the error. This is something interesting to investigate further in the future.
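One way to probe the sampling-variance explanation is to compare the MSE of sampled predictions with the MSE of the mixture mean. The NumPy sketch below does both for a one-dimensional target. It is not part of the experiments above: the names `sample_mdn` and `mixture_mean` and the array layout are assumptions made for this illustration, and using the mixture mean as a point prediction is a suggestion, not something I tried.

```python
import numpy as np

def sample_mdn(pi, mu, sigma2, rng):
    """Draw one prediction per row from a 1-D Gaussian mixture.

    pi, mu, sigma2: (N, K) arrays of mixing coefficients, means, variances.
    rng: a numpy.random.Generator.
    """
    n, k = pi.shape
    # Pick a component per row by inverting the cumulative mixing weights
    cdf = np.cumsum(pi, axis=1)
    u = rng.random((n, 1))
    comp = (u > cdf).sum(axis=1)
    comp = np.minimum(comp, k - 1)  # guard against rounding in the cumsum
    rows = np.arange(n)
    # Then sample from the selected Gaussian component
    return rng.normal(mu[rows, comp], np.sqrt(sigma2[rows, comp]))

def mixture_mean(pi, mu):
    """Mean of the mixture, sum_k pi_k * mu_k, for each row."""
    return (pi * mu).sum(axis=1)

def mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)
```

For a prediction s sampled independently of the target y from a distribution with mean m, the expected squared error is (y - m)^2 + Var(s). So the sampled MSE carries the full predictive variance on top of the error of the mixture mean, even when the model fits the data well. Comparing `mse(y, sample_mdn(...))` with `mse(y, mixture_mean(pi, mu))` on the validation set would show how much of the gap to the MLP comes from this term.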
