Given a gamma distribution with unit scale and shape $\theta$, and given an arbitrary variate $x$, what is the derivative of the variate $x$ with respect to $\theta$?
In other words, I would like to produce a stochastic function: $$ \begin{align} g: \mathcal R &\to \mathcal R \times \mathcal R \\ \theta &\mapsto \left( z, \frac{dz}{d\theta} \right) \end{align} $$ where $z \sim \mathrm{Gamma}(\theta, 1)$.
This kind of question comes up in machine learning with variational autoencoders whenever we try to differentiate loss functions with respect to parameters that generate distributions. A common work-around is to use the reparametrization trick.
My goal with this question is to avoid the reparametrization trick and instead improve the machine learning library that I use by arming it with an appropriate so-called "JVP function" (Jacobian-vector product).
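For what it's worth, one candidate definition for such a JVP is the implicit reparameterization gradient (Figurnov, Mohamed, and Mnih, 2018). If $F(z; \theta)$ is the $\mathrm{Gamma}(\theta, 1)$ CDF and we hold fixed the uniform $u = F(z; \theta)$ that generated the sample, then differentiating $F(z(\theta); \theta) = u$ with respect to $\theta$ gives
$$ \frac{\partial F}{\partial z} \frac{dz}{d\theta} + \frac{\partial F}{\partial \theta} = 0 \quad\Longrightarrow\quad \frac{dz}{d\theta} = -\frac{\partial F / \partial \theta}{\partial F / \partial z} = -\frac{\partial_\theta F(z; \theta)}{p(z; \theta)}, $$
where $p(z; \theta) = \partial F / \partial z$ is the gamma density. I believe this is essentially what a library would compute in its JVP rule, though I haven't verified that this is exactly what Jax does internally.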
Here's some example Python code using the JAX library:

```python
import numpy as np
from jax import value_and_grad
from jax.random import gamma, PRNGKey

key = PRNGKey(123)

def f(theta):
    return gamma(key, theta)

print("theta z z_dot")
for theta in np.arange(0.1, 2.5, 0.4):
    z, z_dot = value_and_grad(f)(theta)
    print(f"{theta:.3f} {z:.3f} {z_dot:.3f}")
```
This prints:

```
theta z z_dot
0.100 0.001 0.071
0.500 0.425 1.227
0.900 1.013 1.266
1.300 1.489 1.213
1.700 1.973 1.186
2.100 2.446 1.167
```
Changing the random seed produces different variates and different derivatives.
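These numbers can be sanity-checked outside Jax by holding the underlying uniform fixed and finite-differencing the quantile function. Here is a minimal sketch using only scipy; the seed, shape value, and step size are arbitrary choices of mine, not taken from the code above.

```python
import numpy as np
from scipy.stats import gamma

rng = np.random.default_rng(0)   # arbitrary seed
theta = 0.9
z = rng.gamma(theta)             # z ~ Gamma(theta, 1)
u = gamma.cdf(z, theta)          # the fixed uniform "behind" z
h = 1e-5
# Hold u fixed and move theta: z(theta) = ppf(u, theta), so dz/dtheta
# is approximated by a central difference of the quantile function.
z_dot = (gamma.ppf(u, theta + h) - gamma.ppf(u, theta - h)) / (2 * h)
print(f"{z:.3f} {z_dot:.3f}")
```

Since the quantile function is increasing in the shape parameter for a fixed uniform, the resulting `z_dot` is always positive, consistent with the table above.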
Actually, looking at this more closely, it may not be the derivative I need, but rather the Hessian. It's probably too late to change this question, so I'll leave it up. If it's not too much of an imposition, and I get a good answer to this, I'll ask another question about the Hessian.
Edit to explain what we're trying to do. First, see the short section on "undifferentiable expectations" here. Now, let's keep all the same notation ($\theta, x, z, \epsilon, f, g, p$). Suppose that you're using a differentiable programming library.
Let's consider the "reverse mode" wherein we want to keep "primals" and "cotangents". In this case, that means that we sample $\epsilon$, calculate $z$, then $L \triangleq f(z)$, which we take to be our loss. Then, we calculate $\ddot z \triangleq \frac{dL}{dz} = f'(z)$, and then by backpropagation, we want $\ddot \theta \triangleq \frac{dL}{d\theta} = \frac{dL}{dz} \frac{dz}{d\theta} = \ddot z \frac{dz}{d\theta}$. It's this latter term that is the subject of this question: $\dot z \triangleq \frac{dz}{d\theta}$.
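To illustrate what "arming the library with a JVP function" might look like, here is a sketch using `jax.custom_jvp`. It assumes `jax.scipy.special.gammainc` (the $\mathrm{Gamma}(\theta, 1)$ CDF) is differentiable in both arguments; the factory function `make_gamma_sampler` is my own hypothetical name, not a Jax API.

```python
import jax
from jax.scipy.special import gammainc


def make_gamma_sampler(key):
    """Return a Gamma(theta, 1) sampler for a fixed key, armed with a JVP."""

    @jax.custom_jvp
    def sample(theta):
        return jax.random.gamma(key, theta)

    @sample.defjvp
    def sample_jvp(primals, tangents):
        (theta,) = primals
        (theta_dot,) = tangents
        z = jax.random.gamma(key, theta)  # primal output
        # Implicit gradient: u = F(z; theta) is held fixed, so
        #   dz/dtheta = -(dF/dtheta) / (dF/dz),
        # where F(z; theta) = gammainc(theta, z) is the CDF and
        # dF/dz is the gamma density.
        dF_dtheta = jax.grad(gammainc, argnums=0)(theta, z)
        dF_dz = jax.grad(gammainc, argnums=1)(theta, z)
        z_dot = -dF_dtheta / dF_dz
        return z, z_dot * theta_dot

    return sample
```

Calling `jax.value_and_grad(make_gamma_sampler(key))(theta)` should then agree with the values printed above, since Jax's built-in gamma sampler appears to carry an equivalent rule; because the JVP is linear in `theta_dot`, reverse mode (the $\ddot\theta = \ddot z \, \dot z$ backpropagation step) also works through it.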
I realize that this may not be well-defined in general, but it appears to be well-defined for the gamma distribution since Jax has no problem producing it.
You could also examine the code if that's more clear for you. I stochastically produce `x, x_dot` from `t`. – Neil G Aug 17 '23 at 03:04