From my understanding of VAEs, there is a step in the middle of training where, after the encoder produces a mean and a standard deviation, a random sample is drawn from that learned distribution to create the encoded vector that the decoder then works to decode. I understand how the KL divergence is used to force the learned distribution to be approximately a standard Gaussian, but I don't understand how the reconstruction loss can be backpropagated past this sampling step. Random sampling is not a differentiable operation, so how can gradients propagate through it? Is my understanding of VAEs wrong?
- Does this answer your question? How does the reparameterization trick for VAEs work and why is it important? – David Dao Jun 17 '20 at 22:51
1 Answer
The reparameterization trick. Sampling
$$x = \text{sample}(\mathcal{N}(\mu, \sigma^2))$$
is not backpropagable with respect to $\mu$ or $\sigma$, because the sampling operation itself has no gradient. However, we can rewrite it as
$$x = \mu + \sigma \cdot \text{sample}(\mathcal{N}(0, 1)),$$
which is equivalent (shifting and scaling a standard normal gives exactly $\mathcal{N}(\mu, \sigma^2)$) and backpropagable: the randomness is confined to a parameter-free noise term, while $\mu$ and $\sigma$ enter only through a differentiable affine transformation, so gradients of the reconstruction loss flow through them as usual.
– shimao
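To make this concrete, here is a minimal PyTorch sketch of the trick as it typically appears inside a VAE's forward pass. The helper name `reparameterize` and the convention that the encoder outputs $\log\sigma^2$ are assumptions chosen for illustration, not something stated in the answer:

```python
import torch

def reparameterize(mu, log_var):
    # Hypothetical helper: draw z ~ N(mu, sigma^2) in a differentiable way.
    # Assumes the encoder outputs log(sigma^2), a common choice for numerical stability.
    std = torch.exp(0.5 * log_var)   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)      # eps ~ N(0, 1); the only source of randomness
    return mu + std * eps            # differentiable in mu and std

# Gradients reach mu and log_var even though we "sampled":
mu = torch.zeros(4, requires_grad=True)
log_var = torch.zeros(4, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
print(mu.grad, log_var.grad)  # both populated; a naive sample() op would give no gradient
```

The sketch only isolates the sampling step; in a full VAE the loss that gets backpropagated is the reconstruction loss computed from the decoder's output on `z`, plus the KL term.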
- Does this mean that we can't build autoencoders that use a different distribution that can't be reparameterized this way? – enumaris Apr 26 '18 at 15:52
- @enumaris Most distributions can be reparameterized. For example, you can use a categorical latent space via the Gumbel-softmax trick (a sketch of this is given after these comments). – shimao Apr 26 '18 at 15:55
- But in theory the normal distribution is all you'll ever need, since a sufficiently powerful function approximator can always map the normal distribution to any arbitrary probability distribution. – shimao Apr 26 '18 at 15:56
- Hey, thank you for this answer, it really helped me! But I have a question: why is $\sigma$ squared when used as an argument of $\mathcal{N}$ but not squared when used as the factor outside? – blue-phoenix Aug 09 '18 at 07:37
- @blue-phoenix Scaling a random variable by a factor of $k$ scales the variance by a factor of $k^2$. – shimao Aug 09 '18 at 07:40
- So if I have a normal() implementation that takes the standard deviation as its argument rather than the variance (which is not backpropagable, hence my question), do I then need to use the square root of $\sigma$ as the factor outside? – blue-phoenix Aug 09 '18 at 07:44
- @blue-phoenix By convention, the two arguments to $\mathcal{N}$ are the mean and the variance respectively; $\sigma$ is the usual letter for the standard deviation. – shimao Aug 09 '18 at 07:45
- Concretely, I'm talking about the PyTorch implementation normal(mean, std, out=None), whose documentation says std is a tensor with the standard deviation of each output element's normal distribution. So I assumed this function takes $\mu$, $\sigma$ instead of $\mu$, $\sigma^2$, and I'm wondering whether the formula above then changes to $x = \mu + \sqrt{\sigma} \cdot \text{sample}(\mathcal{N}(0, 1))$? Or does it make no difference because the variance resp. standard deviation equals 1? – blue-phoenix Aug 09 '18 at 07:55
- @blue-phoenix There's no need to add the square root. The fact that a particular implementation of the normal distribution is parameterized by the standard deviation rather than the variance doesn't change any of the math (see the sanity check after these comments). – shimao Aug 09 '18 at 07:58
- @shimao Excuse me, why is it not backpropagable with respect to mu and/or std? – Hossein Oct 01 '19 at 10:57
- @Hossein It is, it's just that the elegant solution is to reparameterize (because we want unbiased gradients from an MC sample). – Firebug Sep 29 '23 at 12:51
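On the categorical case mentioned in the comments, here is a minimal sketch of the Gumbel-softmax trick in PyTorch; the helper name `gumbel_softmax_sample` and the small epsilon used for numerical stability are illustrative assumptions (PyTorch also ships a built-in `torch.nn.functional.gumbel_softmax`):

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau=1.0, eps=1e-20):
    # Hypothetical helper: differentiable "soft" sample from the categorical
    # distribution defined by `logits`, via the Gumbel-softmax trick.
    u = torch.rand_like(logits)                        # u ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(u + eps) + eps)     # Gumbel(0, 1) noise
    return F.softmax((logits + gumbel) / tau, dim=-1)  # relaxed one-hot; differentiable in logits

# As in the Gaussian case, gradients reach the distribution's parameters:
logits = torch.zeros(5, requires_grad=True)
y = gumbel_softmax_sample(logits, tau=0.5)
y[0].backward()
print(logits.grad)  # populated
```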
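And for the standard-deviation-versus-variance question in the comments, a quick sanity check (a sketch, not part of the original answer) showing that scaling standard normal noise by $\sigma$, with no square root, already yields samples with standard deviation $\sigma$, matching what torch.normal(mean, std) produces:

```python
import torch

torch.manual_seed(0)
mu, sigma = 2.0, 3.0  # sigma is the standard deviation; sigma**2 is the variance

# Direct sampling, parameterized by the standard deviation (not differentiable in mu/sigma):
direct = torch.normal(mu * torch.ones(100_000), sigma * torch.ones(100_000))

# Reparameterized sampling: same distribution, but differentiable in mu and sigma:
reparam = mu + sigma * torch.randn(100_000)

print(direct.mean().item(), direct.std().item())    # roughly 2 and 3
print(reparam.mean().item(), reparam.std().item())  # roughly 2 and 3: no square root needed
```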