Bayesian inference in Deep Learning
10 May 2019
It is possible to perform full-on Bayesian inference in Deep Learning. One could train a network, assume a Gaussian distribution around the weights that were found, and take the likelihood directly from the loss function, since that loss is usually derived from maximum likelihood estimation (MLE). However, this is costly, and it is not clear to me how often people actually do it in practice.
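To make the link between the loss and the likelihood concrete, here is one way to write the recipe down. The notation is my own assumption and does not come from a specific paper: $\hat{w}$ denotes the trained weights, $\Sigma$ a covariance chosen by the user, $\mathcal{D}$ the training set, and $\mathcal{L}(w)$ a loss that equals the negative log-likelihood.

```latex
\begin{align}
  p(w \mid \mathcal{D}) &= \frac{p(\mathcal{D} \mid w)\, p(w)}{\int p(\mathcal{D} \mid w')\, p(w')\, dw'} , \\
  \mathcal{L}(w) = -\log p(\mathcal{D} \mid w)
    \quad &\Longrightarrow \quad
  p(w \mid \mathcal{D}) \propto \exp\!\big(-\mathcal{L}(w)\big)\, p(w) , \\
  p(w) &= \mathcal{N}\big(w ;\, \hat{w},\, \Sigma\big) .
\end{align}
```

The cost comes from the denominator: it is an integral over every weight of the network, so in practice it has to be approximated, typically by sampling, and each sample requires evaluating the loss.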
There have been some attempts at using variational inference, i.e., approximating the posterior with a simpler distribution (e.g., a Gaussian) by searching for the candidate distribution that minimizes the Kullback-Leibler divergence between the candidate and the posterior. This approach also seems to be computationally expensive.
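As a sketch of what is being optimized, the standard variational objective can be written as below. The symbols are assumptions made for illustration: $q$ is the candidate distribution, $\mathcal{Q}$ the family it is drawn from (e.g., Gaussians), $p(w)$ the prior and $\mathcal{D}$ the data.

```latex
\begin{align}
  q^{*} &= \arg\min_{q \in \mathcal{Q}} \; \mathrm{KL}\big(q(w) \,\|\, p(w \mid \mathcal{D})\big) , \\
  \mathrm{KL}\big(q(w) \,\|\, p(w \mid \mathcal{D})\big)
    &= \log p(\mathcal{D})
     - \Big( \mathbb{E}_{q}\big[\log p(\mathcal{D} \mid w)\big]
             - \mathrm{KL}\big(q(w) \,\|\, p(w)\big) \Big) .
\end{align}
```

Since $\log p(\mathcal{D})$ does not depend on $q$, minimizing the divergence is equivalent to maximizing the bracketed term (the evidence lower bound). The expectation over $q$ is usually estimated by sampling weights, which is one source of the extra cost.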
The most popular option is to use Monte-Carlo Dropout (MC Dropout) at inference time (see here and here). The idea is to generate samples of the network's output by randomly shutting down a certain number of units on each forward pass. The authors prove some interesting properties of their method.
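The sampling idea can be illustrated with a minimal, framework-free sketch. Everything here is an assumption made for illustration: the tiny one-hidden-layer ReLU network, its weights, and the names `forward` and `mc_dropout_predict` are hypothetical and do not come from the papers. In a real framework one would instead keep the dropout layers active at inference time and repeat the forward pass.

```python
import random
import statistics


def forward(x, w_hidden, w_out, drop_p, rng):
    """One stochastic forward pass of a toy 1-hidden-layer ReLU network."""
    keep = 1.0 - drop_p
    hidden = []
    for w, b in w_hidden:
        h = max(0.0, w * x + b)  # ReLU activation of one hidden unit
        # Inverted dropout: drop the unit with probability drop_p,
        # rescale the survivors so the expected activation is unchanged.
        mask = 1.0 if rng.random() < keep else 0.0
        hidden.append(h * mask / keep)
    return sum(v * h for v, h in zip(w_out, hidden))


def mc_dropout_predict(x, w_hidden, w_out, drop_p=0.5, n_samples=200, seed=0):
    """Repeat the stochastic forward pass and summarize the samples."""
    rng = random.Random(seed)
    samples = [forward(x, w_hidden, w_out, drop_p, rng) for _ in range(n_samples)]
    return statistics.mean(samples), statistics.stdev(samples)


if __name__ == "__main__":
    # Hypothetical (weight, bias) pairs for three hidden units, and output weights.
    w_hidden = [(1.0, 0.0), (-1.0, 0.5), (0.5, 0.1)]
    w_out = [0.7, -0.3, 1.2]
    mean, std = mc_dropout_predict(2.0, w_hidden, w_out)
    print(f"prediction: {mean:.3f} +/- {std:.3f}")
```

The spread of the samples is then read as a rough measure of predictive uncertainty; with `drop_p=0.0` every pass is identical and the spread collapses to zero.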
More recently, Stochastic Weight Averaging (SWA) was introduced (see paper and blog post). SWA is not, strictly speaking, a new optimization algorithm. It simply modifies the typical learning rate schedule (a high learning rate, followed by a continuously decreasing rate, down to a very small learning rate) so that the learning rate stays relatively large and the optimizer keeps exploring the loss function. After training, the latest weights, i.e., those visited during the exploration phase, are averaged. That average is designed to place the SWA weights in the middle of a large flat region (the authors assume that NN minima lie in large flat portions of the loss function) instead of close to its edge (as typical optimizers such as SGD tend to do), which leads to better generalization. In the exploration phase, you have the option of including only a regularly spaced subset of the weights in the average.
There is a potential issue when SWA is used in conjunction with batch normalization: batch normalization learns the statistics of the activations during training, but the SWA weights were never used during that phase. The solution proposed by the authors is to pass the training set through the network once more after training (i.e., with the SWA weights) and recompute the BN statistics at that point.
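The averaging step can be sketched on a toy problem. This is a minimal illustration under stated assumptions, not the authors' implementation: the loss is a 1-D quadratic with minimum at 3, the noisy gradient stands in for a minibatch gradient, and the names `noisy_grad`, `sgd_with_swa`, `swa_start` and `swa_freq` are hypothetical. The toy has no batch normalization, so the BN recomputation step is not shown.

```python
import random


def noisy_grad(w, rng, noise=1.0):
    """Stochastic gradient of the toy loss 0.5 * (w - 3)^2."""
    return (w - 3.0) + rng.gauss(0.0, noise)


def sgd_with_swa(w0=0.0, lr=0.1, n_steps=2000, swa_start=1000, swa_freq=10, seed=0):
    """Constant-learning-rate SGD with a running average of late iterates."""
    rng = random.Random(seed)
    w = w0
    w_swa, n_avg = 0.0, 0
    for step in range(1, n_steps + 1):
        # The learning rate is kept constant and fairly large, so the
        # iterate keeps bouncing around the minimum instead of settling.
        w -= lr * noisy_grad(w, rng)
        # Exploration phase: average only every swa_freq-th iterate.
        if step >= swa_start and (step - swa_start) % swa_freq == 0:
            w_swa = (n_avg * w_swa + w) / (n_avg + 1)  # running mean
            n_avg += 1
    return w, w_swa, n_avg


if __name__ == "__main__":
    w_last, w_swa, n_avg = sgd_with_swa()
    print(f"last iterate: {w_last:.3f}, SWA over {n_avg} snapshots: {w_swa:.3f}")
```

The running mean avoids storing every snapshot: only the current average and the snapshot count are kept, which is what makes the approach cheap for large networks.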
The authors also have an extension called SWA-Gaussian (SWAG) that enables uncertainty quantification by estimating the first two moments of the weight distribution, again from snapshots taken during the exploration phase.
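A minimal sketch of the moment estimation, assuming a diagonal covariance only (as far as I understand, the full method also adds a low-rank covariance term, which is omitted here). The class name `DiagonalSWAG` and its methods are hypothetical names chosen for this example.

```python
import math
import random


class DiagonalSWAG:
    """Running first and second moments of weight snapshots (diagonal only)."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim      # running mean of w
        self.sq_mean = [0.0] * dim   # running mean of w^2

    def collect(self, weights):
        """Fold one weight snapshot into the running moments."""
        n = self.n
        self.mean = [(n * m + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(n * s + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        self.n = n + 1

    def variance(self):
        """Per-weight variance E[w^2] - E[w]^2, clipped at zero for round-off."""
        return [max(s - m * m, 0.0) for m, s in zip(self.mean, self.sq_mean)]

    def sample(self, rng):
        """Draw one weight vector from N(mean, diag(variance))."""
        return [rng.gauss(m, math.sqrt(v)) for m, v in zip(self.mean, self.variance())]


if __name__ == "__main__":
    swag = DiagonalSWAG(dim=2)
    for snapshot in ([1.0, 2.0], [3.0, 2.0], [2.0, 2.5]):  # made-up snapshots
        swag.collect(snapshot)
    print("mean:", swag.mean, "variance:", swag.variance())
    print("sampled weights:", swag.sample(random.Random(0)))
```

Predictions made with several sampled weight vectors can then be averaged, and their spread gives an uncertainty estimate.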
[deeplearning, pytorch, optimization, bayesian]