My understanding is that the goal of using Gumbel softmax is to turn a vector of logits into a one-hot vector corresponding to the highest-probability choice (based on those logits).
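To make my understanding concrete, here is a minimal numpy sketch of what I believe the straight-through Gumbel-softmax forward pass computes (the function name, seed, temperature, and logits are just placeholders; in an autodiff framework the hard one-hot would be combined with the soft sample so gradients flow through the soft part):

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(logits, tau=1.0):
    # Sample Gumbel(0, 1) noise via the inverse-CDF: -log(-log(U)).
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    # Temperature softmax over the perturbed logits: the "soft" sample.
    y = np.exp((logits + g) / tau)
    soft = y / y.sum()
    # One-hot vector at the argmax: the "hard" sample used in the forward pass.
    hard = np.eye(len(logits))[np.argmax(soft)]
    return soft, hard

soft, hard = gumbel_softmax(np.array([2.0, 1.0, 0.1]))
```

In a framework like PyTorch the straight-through trick would return something like `hard + soft - soft.detach()`, so the forward value is `hard` but the gradient is that of `soft`.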
However, to me it seems like there are many simpler ways to do this:
- Use (soft)argmax on the logits directly. This is the simplest choice. We could also use the straight-through estimator, just as we could with Gumbel. The only reason I can see that this might not be the best option is that we could always get the same output, which would not be representative of the underlying probabilities. For example, if our logits are given by $(\mathcal{U}(1,2), \mathcal{U}(0,1))$, argmax will always return $[1, 0]$, so the distribution of those one-hot vectors will not match the softmax probabilities.
- Use softmax and then (soft)argmax on the logits. I saw this idea in the answer to this question. According to the author this would not be differentiable, but both softmax and (soft)argmax are differentiable, and again we could use the straight-through estimator just as we could with Gumbel.
- Use a similar function instead of Gumbel. Why can't we just add a sample from $\tan(\pi\,\mathcal{U}(0,1)+\frac{\pi}{2})$ (a Cauchy sample) instead of Gumbel noise (which corresponds to $-\log(-\log(\mathcal{U}(0,1)))$)? Both map a $\mathcal{U}(0,1)$ sample onto the whole real line, with inputs near the endpoints pushed toward $\pm\infty$. Why is Gumbel better here?
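To illustrate the last bullet, here is a small numpy experiment (the logits and sample size are arbitrary choices of mine): adding Gumbel noise and taking argmax should reproduce the softmax probabilities exactly, which is the Gumbel-max trick, while adding Cauchy noise, i.e. $\tan(\pi(\mathcal{U}-\frac{1}{2}))$, which equals $\tan(\pi\,\mathcal{U}+\frac{\pi}{2})$ by periodicity of $\tan$, gives a visibly different distribution:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([1.5, 0.5, 0.0])
target = np.exp(logits) / np.exp(logits).sum()  # softmax probabilities

n = 200_000
u = rng.uniform(size=(n, 3))
gumbel = -np.log(-np.log(u))        # Gumbel(0, 1) noise via inverse CDF
cauchy = np.tan(np.pi * (u - 0.5))  # standard Cauchy noise, same as tan(pi*U + pi/2)

# Empirical frequency of each category winning the perturbed argmax.
freq_g = np.bincount(np.argmax(logits + gumbel, axis=1), minlength=3) / n
freq_c = np.bincount(np.argmax(logits + cauchy, axis=1), minlength=3) / n

print("softmax:", target)
print("gumbel :", freq_g)  # matches softmax up to Monte Carlo error
print("cauchy :", freq_c)  # noticeably different distribution
```

If the Gumbel-max trick is the point, then the Cauchy version "works" in the sense of producing random one-hot vectors, but the resulting category distribution is not $\mathrm{softmax}(\text{logits})$.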
I looked at a few similar questions, but none of them gives a clear answer for me. For example, here the best answer states that it is common to just use softmax, so that clearly works, but then why is it not the most common approach?